chenpangpang / transformers · Commits

Unverified commit e53138a1, authored Sep 22, 2020 by Sam Shleifer, committed via GitHub on Sep 22, 2020.

[s2s] add src_lang kwarg for distributed eval (#7300)

Parent: a9c7849c
Showing 2 changed files with 17 additions and 1 deletion (+17 / -1):
  examples/seq2seq/run_distributed_eval.py   +16  -1
  src/transformers/tokenization_mbart.py      +1  -0
examples/seq2seq/run_distributed_eval.py (view file @ e53138a1)

@@ -38,6 +38,9 @@ def eval_data_dir(
     fp16=False,
     task="summarization",
     local_rank=None,
+    src_lang=None,
+    tgt_lang=None,
+    prefix="",
     **generate_kwargs,
 ) -> Dict:
     """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""

@@ -57,6 +60,8 @@ def eval_data_dir(
     use_task_specific_params(model, task)  # update config with task specific params
     if max_source_length is None:
         max_source_length = tokenizer.model_max_length
+    if prefix is None:
+        prefix = prefix or getattr(model.config, "prefix", "") or ""
     ds = Seq2SeqDataset(
         tokenizer,
         data_dir,
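Restating the fallback added just above on its own: everything except the getattr line is illustrative scaffolding, with a SimpleNamespace standing in for model.config and an assumed config prefix value.

    from types import SimpleNamespace

    # When --prefix is omitted, the script now falls back to whatever prefix the
    # model config defines, and finally to the empty string.
    model_config = SimpleNamespace(prefix="summarize: ")  # illustrative value only

    prefix = None  # i.e. --prefix was not passed
    if prefix is None:
        prefix = prefix or getattr(model_config, "prefix", "") or ""

    print(repr(prefix))  # -> 'summarize: '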
@@ -64,7 +69,9 @@ def eval_data_dir(
         max_target_length=1024,
         type_path=type_path,
         n_obs=n_obs,
-        prefix=model.config.prefix,
+        src_lang=src_lang,
+        tgt_lang=tgt_lang,
+        prefix=prefix,
     )
     # I set shuffle=True for a more accurate progress bar.
     # If all the longest samples are first, the prog bar estimate is too high at the beginning.

@@ -118,6 +125,11 @@ def run_generate():
         required=False,
         help="How long should master process wait for other processes to finish.",
     )
+    parser.add_argument("--src_lang", type=str, default=None, required=False)
+    parser.add_argument("--tgt_lang", type=str, default=None, required=False)
+    parser.add_argument(
+        "--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples"
+    )
     parser.add_argument("--fp16", action="store_true")
     parser.add_argument("--debug", action="store_true")
     start_time = time.time()
@@ -144,6 +156,9 @@ def run_generate():
         local_rank=args.local_rank,
         n_obs=args.n_obs,
         max_source_length=args.max_source_length,
+        prefix=args.prefix,
+        src_lang=args.src_lang,
+        tgt_lang=args.tgt_lang,
         **generate_kwargs,
     )
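With these changes, language codes given on the command line (--src_lang / --tgt_lang) travel from run_generate() into eval_data_dir and on to Seq2SeqDataset. Below is a minimal sketch of calling the helper directly, under the assumption that the parameters not touched by this commit (data_dir, save_dir, model_name, ...) keep the names used elsewhere in the script; the checkpoint, dataset directory, and mBART language codes are illustrative only. In practice the script is meant to be launched with one process per GPU (e.g. via torch.distributed.launch).

    from run_distributed_eval import eval_data_dir  # lives in examples/seq2seq/

    # Hypothetical direct call; only src_lang, tgt_lang and prefix are new in this
    # commit, and the remaining argument names are assumptions about the existing
    # interface rather than part of the patch.
    results = eval_data_dir(
        data_dir="wmt_en_ro",                     # assumed dataset directory
        save_dir="eval_output",                   # rank_{rank}_output.json lands here
        model_name="facebook/mbart-large-en-ro",  # illustrative checkpoint
        task="translation",
        local_rank=0,
        src_lang="en_XX",   # new: forwarded to Seq2SeqDataset
        tgt_lang="ro_RO",   # new: forwarded to Seq2SeqDataset
        num_beams=5,        # example pass-through generate kwarg
    )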
src/transformers/tokenization_mbart.py (view file @ e53138a1)

@@ -168,6 +168,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
         truncation: bool = True,
         padding: str = "longest",
         return_tensors: str = "pt",
+        add_prefix_space: bool = False,  # ignored
         **kwargs,
     ) -> BatchEncoding:
         """Prepare a batch that can be passed directly to an instance of MBartModel.
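On the tokenizer side the only change is that prepare_seq2seq_batch now accepts (and ignores) an add_prefix_space keyword, presumably so call sites shared with tokenizers that do use that flag keep working. A sketch of the call with this release's mBART API follows; the checkpoint name and sample sentence are illustrative, not part of the commit.

    from transformers import MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    batch = tokenizer.prepare_seq2seq_batch(
        src_texts=["UN Chief Says There Is No Military Solution in Syria"],
        src_lang="en_XX",
        tgt_lang="ro_RO",
        return_tensors="pt",
        add_prefix_space=False,  # accepted and ignored after this commit
    )
    print(batch["input_ids"].shape)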