Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
de9e2979
Unverified
Commit
de9e2979
authored
Sep 13, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 13, 2020
Browse files
[s2s] distributed eval cleanup (#7110)
parent
54395d87
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
11 deletions
+15
-11
examples/seq2seq/run_distributed_eval.py
examples/seq2seq/run_distributed_eval.py
+13
-10
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+2
-1
No files found.
examples/seq2seq/run_distributed_eval.py
View file @
de9e2979
import
argparse
import
argparse
import
warnings
from
logging
import
getLogger
from
logging
import
getLogger
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
from
typing
import
Dict
...
@@ -18,6 +17,7 @@ try:
...
@@ -18,6 +17,7 @@ try:
except
ImportError
:
except
ImportError
:
from
utils
import
Seq2SeqDataset
,
parse_numeric_cl_kwargs
,
save_json
,
use_task_specific_params
from
utils
import
Seq2SeqDataset
,
parse_numeric_cl_kwargs
,
save_json
,
use_task_specific_params
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
@@ -51,6 +51,8 @@ def eval_data_dir(
...
@@ -51,6 +51,8 @@ def eval_data_dir(
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
logger
.
info
(
f
"Inferred tokenizer type:
{
tokenizer
.
__class__
}
"
)
# if this is wrong, check config.model_type.
logger
.
info
(
f
"Inferred tokenizer type:
{
tokenizer
.
__class__
}
"
)
# if this is wrong, check config.model_type.
use_task_specific_params
(
model
,
task
)
# update config with task specific params
use_task_specific_params
(
model
,
task
)
# update config with task specific params
if
max_source_length
is
None
:
max_source_length
=
tokenizer
.
model_max_length
ds
=
Seq2SeqDataset
(
ds
=
Seq2SeqDataset
(
tokenizer
,
tokenizer
,
data_dir
,
data_dir
,
...
@@ -97,9 +99,11 @@ def run_generate():
...
@@ -97,9 +99,11 @@ def run_generate():
default
=
"sshleifer/distilbart-xsum-12-3"
,
default
=
"sshleifer/distilbart-xsum-12-3"
,
)
)
parser
.
add_argument
(
"--save_dir"
,
type
=
str
,
help
=
"where to save"
,
default
=
"tmp_gen"
)
parser
.
add_argument
(
"--save_dir"
,
type
=
str
,
help
=
"where to save"
,
default
=
"tmp_gen"
)
parser
.
add_argument
(
"--prefix"
,
type
=
str
,
default
=
"test"
,
help
=
"which subset to evaluate typically train/val/test"
)
parser
.
add_argument
(
"--max_source_length"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--type_path"
,
type
=
str
,
default
=
"test"
,
help
=
"which subset to evaluate typically train/val/test"
)
parser
.
add_argument
(
"--reference_path"
,
type
=
str
,
required
=
False
,
help
=
"like cnn_dm/test.target"
)
parser
.
add_argument
(
"--reference_path"
,
type
=
str
,
required
=
False
,
help
=
"like cnn_dm/test.target"
)
parser
.
add_argument
(
"--score_path"
,
type
=
str
,
required
=
False
,
default
=
"metrics.json"
,
help
=
"where to save metrics"
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"used for task_specific_params + metrics"
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"used for task_specific_params + metrics"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -113,24 +117,23 @@ def run_generate():
...
@@ -113,24 +117,23 @@ def run_generate():
parser
.
add_argument
(
"--save_source"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--save_source"
,
action
=
"store_true"
)
args
,
rest
=
parser
.
parse_known_args
()
args
,
rest
=
parser
.
parse_known_args
()
parsed
=
parse_numeric_cl_kwargs
(
rest
)
generate_kwargs
=
parse_numeric_cl_kwargs
(
rest
)
if
parsed
:
if
generate_kwargs
:
print
(
f
"parsed the following generate kwargs:
{
parsed
}
"
)
print
(
f
"parsed the following generate kwargs:
{
generate_kwargs
}
"
)
Path
(
args
.
save_dir
).
mkdir
(
exist_ok
=
True
)
Path
(
args
.
save_dir
).
mkdir
(
exist_ok
=
True
)
if
args
.
reference_path
is
None
and
Path
(
args
.
score_path
).
exists
():
warnings
.
warn
(
f
"score_path
{
args
.
score_path
}
will be overwritten unless you type ctrl-c."
)
eval_data_dir
(
eval_data_dir
(
args
.
input_path
,
args
.
input_path
,
args
.
save_dir
,
args
.
save_dir
,
args
.
model_name
,
args
.
model_name
,
prefix
=
args
.
prefix
,
type_path
=
args
.
type_path
,
batch_size
=
args
.
bs
,
batch_size
=
args
.
bs
,
fp16
=
args
.
fp16
,
fp16
=
args
.
fp16
,
task
=
args
.
task
,
task
=
args
.
task
,
local_rank
=
args
.
local_rank
,
local_rank
=
args
.
local_rank
,
n_obs
=
args
.
n_obs
,
n_obs
=
args
.
n_obs
,
save_source
=
args
.
save_source
,
save_source
=
args
.
save_source
,
**
parsed
,
max_source_length
=
args
.
max_source_length
,
**
generate_kwargs
,
)
)
...
...
examples/seq2seq/utils.py
View file @
de9e2979
...
@@ -98,7 +98,8 @@ class AbstractSeq2SeqDataset(Dataset):
...
@@ -98,7 +98,8 @@ class AbstractSeq2SeqDataset(Dataset):
self
.
max_target_length
=
max_target_length
self
.
max_target_length
=
max_target_length
assert
min
(
self
.
src_lens
)
>
0
,
f
"found empty line in
{
self
.
src_file
}
"
assert
min
(
self
.
src_lens
)
>
0
,
f
"found empty line in
{
self
.
src_file
}
"
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
prefix
=
prefix
self
.
prefix
=
prefix
if
prefix
is
not
None
else
""
if
n_obs
is
not
None
:
if
n_obs
is
not
None
:
self
.
src_lens
=
self
.
src_lens
[:
n_obs
]
self
.
src_lens
=
self
.
src_lens
[:
n_obs
]
self
.
pad_token_id
=
self
.
tokenizer
.
pad_token_id
self
.
pad_token_id
=
self
.
tokenizer
.
pad_token_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment