chenpangpang / transformers · Commits

Commit efeab6a3, authored Sep 17, 2020 by Stas Bekman, committed by GitHub on Sep 17, 2020

[s2s] run_eval/run_eval_search tweaks (#7192)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

Parent: 9c5bcab5

Showing 2 changed files with 58 additions and 48 deletions:
    examples/seq2seq/run_eval_search.py        +8   -8
    examples/seq2seq/test_seq2seq_examples.py  +50  -40
examples/seq2seq/run_eval_search.py (+8, -8)
@@ -15,7 +15,6 @@ except ImportError:
 # To add a new task, simply list the score names that `run_eval.run_generate()` returns
 task_score_names = {
     "translation": ["bleu"],
-    "translation_en_to_de": ["bleu"],
     "summarization": ["rouge1", "rouge2", "rougeL"],
 }
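The comment above this dict is the extension point: a new task only needs an entry whose value lists the score names `run_eval.run_generate()` reports for it. A minimal sketch of such an extension (the `question_answering` key and its metric names are hypothetical, purely for illustration, and not part of this commit):

    # hypothetical extension of task_score_names; the task name and metrics
    # below are illustrative only
    task_score_names = {
        "translation": ["bleu"],
        "summarization": ["rouge1", "rouge2", "rougeL"],
        "question_answering": ["em", "f1"],  # assumed new task and score names
    }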
@@ -66,9 +65,7 @@ def run_search():
     parser.add_argument("--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)")
-    parser.add_argument("--task", type=str, help="used for task_specific_params + metrics", choices=task_score_names.keys())
+    parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
     parser.add_argument("--info", nargs="?",
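Dropping `choices=task_score_names.keys()` matters because argparse rejects any `--task` value that is not an exact member of `choices`; with `translation_en_to_de` no longer a dict key, only the relaxed argument definition lets such task variations through to `run_eval.py`. A small standalone sketch of that argparse behavior (the parser here is illustrative, not the one in `run_eval_search.py`):

    import argparse

    parser = argparse.ArgumentParser()
    # with choices=..., only the listed values are accepted
    parser.add_argument("--task", type=str, choices=["translation", "summarization"])

    parser.parse_args(["--task", "translation"])             # accepted
    # parser.parse_args(["--task", "translation_en_to_de"])  # error: invalid choice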
@@ -81,8 +78,11 @@ def run_search():
     args_main.extend(["--task", args.task])
     args_normal = [prog] + args_main

+    # to support variations like translation_en_to_de"
+    task = "translation" if "translation" in args.task else "summarization"
+
     matrix, col_names = parse_search_arg(args.search)
-    col_names[0:0] = task_score_names[args.task]  # score cols first
+    col_names[0:0] = task_score_names[task]  # score cols first
     col_widths = {col: len(str(col)) for col in col_names}
     results = []
     for r in matrix:
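The new `task` variable collapses task variations onto the base keys of `task_score_names`, so `--task translation_en_to_de` still resolves to the `bleu` score list while the unmodified `args.task` string keeps being forwarded to `run_eval.py` for `task_specific_params`. A minimal sketch of that lookup, with the conditional factored into a hypothetical helper for clarity:

    task_score_names = {
        "translation": ["bleu"],
        "summarization": ["rouge1", "rouge2", "rougeL"],
    }

    def base_task(task_arg: str) -> str:
        # mirrors the inline conditional added in run_search()
        return "translation" if "translation" in task_arg else "summarization"

    assert task_score_names[base_task("translation_en_to_de")] == ["bleu"]
    assert task_score_names[base_task("summarization")] == ["rouge1", "rouge2", "rougeL"]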
@@ -96,7 +96,7 @@ def run_search():
         scores = run_generate(verbose=False)
         # make sure scores are first in the table
         result = OrderedDict()
-        for score in task_score_names[args.task]:
+        for score in task_score_names[task]:
             result[score] = scores[score]
         result.update(hparams)
         results.append(result)
@@ -107,14 +107,14 @@ def run_search():
             if l > col_widths[k]:
                 col_widths[k] = l

-    results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[args.task]), reverse=True)
+    results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True)
     print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
     print(" | ".join([f"{'-' * col_widths[col]}" for col in col_names]))
     for row in results_sorted:
         print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))

     best = results_sorted[0]
-    for score in task_score_names[args.task]:
+    for score in task_score_names[task]:
         del best[score]
     best_args = [f"--{k} {v}" for k, v in best.items()]
     dyn_args = ["--bs", str(args.bs)]
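The reporting code above sorts rows by every score column at once via `operator.itemgetter(*task_score_names[task])` and pads each cell with a nested format spec, `f"{value:{width}}"`. A self-contained sketch of those two techniques with made-up rows (the values are illustrative, not real results):

    import operator
    from collections import OrderedDict

    rows = [
        OrderedDict(rouge1=40.1, rouge2=17.5, num_beams=2),
        OrderedDict(rouge1=41.3, rouge2=18.0, num_beams=4),
    ]
    score_names = ["rouge1", "rouge2"]
    col_names = score_names + ["num_beams"]

    # widest of header and values per column, like the col_widths bookkeeping
    col_widths = {c: max(len(str(c)), *(len(str(r[c])) for r in rows)) for c in col_names}

    # sort by all score columns at once, best first
    rows_sorted = sorted(rows, key=operator.itemgetter(*score_names), reverse=True)

    print(" | ".join(f"{c:{col_widths[c]}}" for c in col_names))
    print(" | ".join("-" * col_widths[c] for c in col_names))
    for row in rows_sorted:
        print(" | ".join(f"{row[c]:{col_widths[c]}}" for c in col_names))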
examples/seq2seq/test_seq2seq_examples.py (+50, -40)
@@ -106,6 +106,9 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY
=
"sshleifer/bart-tiny-random"
MBART_TINY
=
"sshleifer/tiny-mbart"
MARIAN_TINY
=
"sshleifer/tiny-marian-en-de"
BERT_BASE_CASED
=
"bert-base-cased"
PEGASUS_XSUM
=
"google/pegasus-xsum"
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
logger
.
addHandler
(
stream_handler
)
logging
.
disable
(
logging
.
CRITICAL
)
# remove noisy download output from tracebacks
...
...
@@ -284,8 +287,7 @@ class TestSummarizationDistiller(unittest.TestCase):
         return model


-@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
-def test_run_eval(model):
+def run_eval_tester(model):
     input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
     output_file_name = input_file_name.parent / "utest_output.txt"
     assert not output_file_name.exists()
@@ -293,28 +295,39 @@ def test_run_eval(model):
     _dump_articles(input_file_name, articles)
     score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
     task = "translation_en_to_de" if model == T5_TINY else "summarization"
-    testargs = [
-        "run_eval.py",
-        model,
-        str(input_file_name),
-        str(output_file_name),
-        "--score_path",
-        score_path,
-        "--task",
-        task,
-        "--num_beams",
-        "2",
-        "--length_penalty",
-        "2.0",
-    ]
+    testargs = f"""
+        run_eval_search.py
+        {model}
+        {input_file_name}
+        {output_file_name}
+        --score_path {score_path}
+        --task {task}
+        --num_beams 2
+        --length_penalty 2.0
+    """.split()

     with patch.object(sys, "argv", testargs):
         run_generate()
         assert Path(output_file_name).exists()
         os.remove(Path(output_file_name))


+# test one model to quickly (no-@slow) catch simple problems and do an
+# extensive testing of functionality with multiple models as @slow separately
+def test_run_eval():
+    run_eval_tester(T5_TINY)
+
+
+# any extra models should go into the list here - can be slow
+@slow
-@pytest.mark.parametrize("model", [pytest.param(T5_TINY)])
+@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
+def test_run_eval_slow(model):
+    run_eval_tester(model)
+
+
+# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
+@slow
+@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
 def test_run_eval_search(model):
     input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
     output_file_name = input_file_name.parent / "utest_output.txt"
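Both tests now build argv as a triple-quoted f-string that is `.split()` into a token list, which reads much closer to the real command line than the old explicit list; since splitting is on whitespace, no single argument value can contain spaces. A standalone sketch of the pattern (the `prog.py` arguments are made up for illustration):

    import sys
    from unittest.mock import patch

    model = "sshleifer/bart-tiny-random"
    testargs = f"""
        prog.py
        {model}
        --num_beams 2
        --length_penalty 2.0
    """.split()

    assert testargs == ["prog.py", model, "--num_beams", "2", "--length_penalty", "2.0"]

    with patch.object(sys, "argv", testargs):
        print(sys.argv)  # the patched argv an argparse-based main() would see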
@@ -335,20 +348,17 @@ def test_run_eval_search(model):
     _dump_articles(input_file_name, text["en"])
     _dump_articles(reference_path, text["de"])
     task = "translation_en_to_de" if model == T5_TINY else "summarization"
-    testargs = [
-        "run_eval_search.py",
-        model,
-        str(input_file_name),
-        str(output_file_name),
-        "--score_path",
-        score_path,
-        "--reference_path",
-        reference_path,
-        "--task",
-        task,
-        "--search",
-        "num_beams=1:2 length_penalty=0.9:1.0",
-    ]
+    testargs = f"""
+        run_eval_search.py
+        --model_name {model}
+        --data_dir {str(input_file_name)}
+        --save_dir {str(output_file_name)}
+        --score_path {score_path}
+        --reference_path {reference_path},
+        --task {task}
+        --search num_beams=1:2 length_penalty=0.9:1.0
+    """.split()

     with patch.object(sys, "argv", testargs):
         with CaptureStdout() as cs:
             run_search()
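The `--search "num_beams=1:2 length_penalty=0.9:1.0"` value describes a small hyperparameter grid: each space-separated entry is a parameter name with colon-separated candidate values, and `parse_search_arg()` turns it into the run matrix that `run_search()` iterates over. A rough, assumed illustration of that expansion (this is not the actual `parse_search_arg` implementation):

    from itertools import product

    def expand_search(search: str):
        # assumed format: "name=v1:v2 other=v1:v2" -> one dict per combination
        names, value_lists = [], []
        for entry in search.split():
            name, values = entry.split("=")
            names.append(name)
            value_lists.append(values.split(":"))
        return [dict(zip(names, combo)) for combo in product(*value_lists)]

    print(expand_search("num_beams=1:2 length_penalty=0.9:1.0"))
    # [{'num_beams': '1', 'length_penalty': '0.9'}, {'num_beams': '1', 'length_penalty': '1.0'},
    #  {'num_beams': '2', 'length_penalty': '0.9'}, {'num_beams': '2', 'length_penalty': '1.0'}]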
@@ -367,8 +377,8 @@ def test_run_eval_search(model):
 @pytest.mark.parametrize(
-    ["model"],
-    [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
+    "model",
+    [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY],
 )
 def test_finetune(model):
     args_d: dict = CHEAP_ARGS.copy()
@@ -541,13 +551,13 @@ def test_pack_dataset():
 @pytest.mark.parametrize(
-    ["tok_name"],
+    "tok_name",
     [
-        pytest.param(MBART_TINY),
-        pytest.param(MARIAN_TINY),
-        pytest.param(T5_TINY),
-        pytest.param(BART_TINY),
-        pytest.param("google/pegasus-xsum"),
+        MBART_TINY,
+        MARIAN_TINY,
+        T5_TINY,
+        BART_TINY,
+        PEGASUS_XSUM,
     ],
 )
 def test_seq2seq_dataset_truncation(tok_name):
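This cleanup swaps the `["tok_name"]` argnames list and per-case `pytest.param(...)` wrappers for the plain string-plus-values form; the two spellings parametrize a test identically, and `pytest.param` is only needed when a case carries marks or a custom id. A small standalone sketch (the test bodies are illustrative):

    import pytest

    MODELS = ["sshleifer/bart-tiny-random", "sshleifer/tiny-mbart"]

    # verbose form: argnames as a list, each case wrapped in pytest.param
    @pytest.mark.parametrize(["model"], [pytest.param(m) for m in MODELS])
    def test_verbose(model):
        assert model in MODELS

    # equivalent plain form used after this commit
    @pytest.mark.parametrize("model", MODELS)
    def test_plain(model):
        assert model in MODELS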
@@ -589,7 +599,7 @@ def test_seq2seq_dataset_truncation(tok_name):
         break  # No need to test every batch


-@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")])
+@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
 def test_legacy_dataset_truncation(tok):
     tokenizer = AutoTokenizer.from_pretrained(tok)
     tmp_dir = make_test_data_dir()