Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b76cb1c3
Unverified
Commit
b76cb1c3
authored
Sep 12, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 12, 2020
Browse files
[s2s] run_eval supports --prefix clarg. (#6953)
parent
563ffb3d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
3 deletions
+9
-3
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+8
-2
examples/seq2seq/test_bash_script.py
examples/seq2seq/test_bash_script.py
+1
-1
No files found.
examples/seq2seq/run_eval.py
View file @
b76cb1c3
...
@@ -36,6 +36,7 @@ def generate_summaries_or_translations(
...
@@ -36,6 +36,7 @@ def generate_summaries_or_translations(
device
:
str
=
DEFAULT_DEVICE
,
device
:
str
=
DEFAULT_DEVICE
,
fp16
=
False
,
fp16
=
False
,
task
=
"summarization"
,
task
=
"summarization"
,
prefix
=
None
,
**
generate_kwargs
,
**
generate_kwargs
,
)
->
Dict
:
)
->
Dict
:
"""Save model.generate results to <out_file>, and return how long it took."""
"""Save model.generate results to <out_file>, and return how long it took."""
...
@@ -51,9 +52,10 @@ def generate_summaries_or_translations(
...
@@ -51,9 +52,10 @@ def generate_summaries_or_translations(
start_time
=
time
.
time
()
start_time
=
time
.
time
()
# update config with task specific params
# update config with task specific params
use_task_specific_params
(
model
,
task
)
use_task_specific_params
(
model
,
task
)
if
prefix
is
None
:
prefix
=
prefix
or
getattr
(
model
.
config
,
"prefix"
,
""
)
or
""
for
examples_chunk
in
tqdm
(
list
(
chunks
(
examples
,
batch_size
))):
for
examples_chunk
in
tqdm
(
list
(
chunks
(
examples
,
batch_size
))):
if
"t5"
in
model_name
:
examples_chunk
=
[
prefix
+
text
for
text
in
examples_chunk
]
examples_chunk
=
[
model
.
config
.
prefix
+
text
for
text
in
examples_chunk
]
batch
=
tokenizer
(
examples_chunk
,
return_tensors
=
"pt"
,
truncation
=
True
,
padding
=
"longest"
).
to
(
device
)
batch
=
tokenizer
(
examples_chunk
,
return_tensors
=
"pt"
,
truncation
=
True
,
padding
=
"longest"
).
to
(
device
)
summaries
=
model
.
generate
(
summaries
=
model
.
generate
(
input_ids
=
batch
.
input_ids
,
input_ids
=
batch
.
input_ids
,
...
@@ -78,6 +80,9 @@ def run_generate():
...
@@ -78,6 +80,9 @@ def run_generate():
parser
.
add_argument
(
"--reference_path"
,
type
=
str
,
required
=
False
,
help
=
"like cnn_dm/test.target"
)
parser
.
add_argument
(
"--reference_path"
,
type
=
str
,
required
=
False
,
help
=
"like cnn_dm/test.target"
)
parser
.
add_argument
(
"--score_path"
,
type
=
str
,
required
=
False
,
default
=
"metrics.json"
,
help
=
"where to save metrics"
)
parser
.
add_argument
(
"--score_path"
,
type
=
str
,
required
=
False
,
default
=
"metrics.json"
,
help
=
"where to save metrics"
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
required
=
False
,
default
=
DEFAULT_DEVICE
,
help
=
"cuda, cuda:1, cpu etc."
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
required
=
False
,
default
=
DEFAULT_DEVICE
,
help
=
"cuda, cuda:1, cpu etc."
)
parser
.
add_argument
(
"--prefix"
,
type
=
str
,
required
=
False
,
default
=
None
,
help
=
"will be added to the begininng of src examples"
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"used for task_specific_params + metrics"
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"used for task_specific_params + metrics"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -103,6 +108,7 @@ def run_generate():
...
@@ -103,6 +108,7 @@ def run_generate():
device
=
args
.
device
,
device
=
args
.
device
,
fp16
=
args
.
fp16
,
fp16
=
args
.
fp16
,
task
=
args
.
task
,
task
=
args
.
task
,
prefix
=
args
.
prefix
,
**
parsed
,
**
parsed
,
)
)
if
args
.
reference_path
is
None
:
if
args
.
reference_path
is
None
:
...
...
examples/seq2seq/test_bash_script.py
View file @
b76cb1c3
...
@@ -160,7 +160,7 @@ def test_opus_mt_distill_script():
...
@@ -160,7 +160,7 @@ def test_opus_mt_distill_script():
metrics
=
load_json
(
model
.
metrics_save_path
)
metrics
=
load_json
(
model
.
metrics_save_path
)
first_step_stats
=
metrics
[
"val"
][
0
]
first_step_stats
=
metrics
[
"val"
][
0
]
last_step_stats
=
metrics
[
"val"
][
-
1
]
last_step_stats
=
metrics
[
"val"
][
-
1
]
assert
len
(
metrics
[
"val"
])
=
=
(
args
.
max_epochs
/
args
.
val_check_interval
)
+
1
# +1 accounts for val_sanity_check
assert
len
(
metrics
[
"val"
])
>
=
(
args
.
max_epochs
/
args
.
val_check_interval
)
# +1 accounts for val_sanity_check
assert
last_step_stats
[
"val_avg_gen_time"
]
>=
0.01
assert
last_step_stats
[
"val_avg_gen_time"
]
>=
0.01
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment