Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
de9e2979
Unverified
Commit
de9e2979
authored
Sep 13, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 13, 2020
Browse files
[s2s] distributed eval cleanup (#7110)
parent
54395d87
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
11 deletions
+15
-11
examples/seq2seq/run_distributed_eval.py
examples/seq2seq/run_distributed_eval.py
+13
-10
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+2
-1
No files found.
examples/seq2seq/run_distributed_eval.py
View file @
de9e2979
import
argparse
import
warnings
from
logging
import
getLogger
from
pathlib
import
Path
from
typing
import
Dict
...
...
@@ -18,6 +17,7 @@ try:
except
ImportError
:
from
utils
import
Seq2SeqDataset
,
parse_numeric_cl_kwargs
,
save_json
,
use_task_specific_params
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
...
@@ -51,6 +51,8 @@ def eval_data_dir(
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
logger
.
info
(
f
"Inferred tokenizer type:
{
tokenizer
.
__class__
}
"
)
# if this is wrong, check config.model_type.
use_task_specific_params
(
model
,
task
)
# update config with task specific params
if
max_source_length
is
None
:
max_source_length
=
tokenizer
.
model_max_length
ds
=
Seq2SeqDataset
(
tokenizer
,
data_dir
,
...
...
@@ -97,9 +99,11 @@ def run_generate():
default
=
"sshleifer/distilbart-xsum-12-3"
,
)
parser
.
add_argument
(
"--save_dir"
,
type
=
str
,
help
=
"where to save"
,
default
=
"tmp_gen"
)
parser
.
add_argument
(
"--prefix"
,
type
=
str
,
default
=
"test"
,
help
=
"which subset to evaluate typically train/val/test"
)
parser
.
add_argument
(
"--max_source_length"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--type_path"
,
type
=
str
,
default
=
"test"
,
help
=
"which subset to evaluate typically train/val/test"
)
parser
.
add_argument
(
"--reference_path"
,
type
=
str
,
required
=
False
,
help
=
"like cnn_dm/test.target"
)
parser
.
add_argument
(
"--score_path"
,
type
=
str
,
required
=
False
,
default
=
"metrics.json"
,
help
=
"where to save metrics"
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"used for task_specific_params + metrics"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
parser
.
add_argument
(
...
...
@@ -113,24 +117,23 @@ def run_generate():
parser
.
add_argument
(
"--save_source"
,
action
=
"store_true"
)
args
,
rest
=
parser
.
parse_known_args
()
parsed
=
parse_numeric_cl_kwargs
(
rest
)
if
parsed
:
print
(
f
"parsed the following generate kwargs:
{
parsed
}
"
)
generate_kwargs
=
parse_numeric_cl_kwargs
(
rest
)
if
generate_kwargs
:
print
(
f
"parsed the following generate kwargs:
{
generate_kwargs
}
"
)
Path
(
args
.
save_dir
).
mkdir
(
exist_ok
=
True
)
if
args
.
reference_path
is
None
and
Path
(
args
.
score_path
).
exists
():
warnings
.
warn
(
f
"score_path
{
args
.
score_path
}
will be overwritten unless you type ctrl-c."
)
eval_data_dir
(
args
.
input_path
,
args
.
save_dir
,
args
.
model_name
,
prefix
=
args
.
prefix
,
type_path
=
args
.
type_path
,
batch_size
=
args
.
bs
,
fp16
=
args
.
fp16
,
task
=
args
.
task
,
local_rank
=
args
.
local_rank
,
n_obs
=
args
.
n_obs
,
save_source
=
args
.
save_source
,
**
parsed
,
max_source_length
=
args
.
max_source_length
,
**
generate_kwargs
,
)
...
...
examples/seq2seq/utils.py
View file @
de9e2979
...
...
@@ -98,7 +98,8 @@ class AbstractSeq2SeqDataset(Dataset):
self
.
max_target_length
=
max_target_length
assert
min
(
self
.
src_lens
)
>
0
,
f
"found empty line in
{
self
.
src_file
}
"
self
.
tokenizer
=
tokenizer
self
.
prefix
=
prefix
self
.
prefix
=
prefix
if
prefix
is
not
None
else
""
if
n_obs
is
not
None
:
self
.
src_lens
=
self
.
src_lens
[:
n_obs
]
self
.
pad_token_id
=
self
.
tokenizer
.
pad_token_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment