Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
efeab6a3
Unverified
Commit
efeab6a3
authored
Sep 17, 2020
by
Stas Bekman
Committed by
GitHub
Sep 17, 2020
Browse files
[s2s] run_eval/run_eval_search tweaks (#7192)
Co-authored-by:
Sam Shleifer
<
sshleifer@gmail.com
>
parent
9c5bcab5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
48 deletions
+58
-48
examples/seq2seq/run_eval_search.py
examples/seq2seq/run_eval_search.py
+8
-8
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+50
-40
No files found.
examples/seq2seq/run_eval_search.py
View file @
efeab6a3
...
...
@@ -15,7 +15,6 @@ except ImportError:
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
task_score_names
=
{
"translation"
:
[
"bleu"
],
"translation_en_to_de"
:
[
"bleu"
],
"summarization"
:
[
"rouge1"
,
"rouge2"
,
"rougeL"
],
}
...
...
@@ -66,9 +65,7 @@ def run_search():
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"initial batch size (may get reduced if it's too big)"
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
help
=
"used for task_specific_params + metrics"
,
choices
=
task_score_names
.
keys
()
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
help
=
"used for task_specific_params + metrics"
)
parser
.
add_argument
(
"--info"
,
nargs
=
"?"
,
...
...
@@ -81,8 +78,11 @@ def run_search():
args_main
.
extend
([
"--task"
,
args
.
task
])
args_normal
=
[
prog
]
+
args_main
# to support variations like "translation_en_to_de"
task
=
"translation"
if
"translation"
in
args
.
task
else
"summarization"
matrix
,
col_names
=
parse_search_arg
(
args
.
search
)
col_names
[
0
:
0
]
=
task_score_names
[
args
.
task
]
# score cols first
col_names
[
0
:
0
]
=
task_score_names
[
task
]
# score cols first
col_widths
=
{
col
:
len
(
str
(
col
))
for
col
in
col_names
}
results
=
[]
for
r
in
matrix
:
...
...
@@ -96,7 +96,7 @@ def run_search():
scores
=
run_generate
(
verbose
=
False
)
# make sure scores are first in the table
result
=
OrderedDict
()
for
score
in
task_score_names
[
args
.
task
]:
for
score
in
task_score_names
[
task
]:
result
[
score
]
=
scores
[
score
]
result
.
update
(
hparams
)
results
.
append
(
result
)
...
...
@@ -107,14 +107,14 @@ def run_search():
if
l
>
col_widths
[
k
]:
col_widths
[
k
]
=
l
results_sorted
=
sorted
(
results
,
key
=
operator
.
itemgetter
(
*
task_score_names
[
args
.
task
]),
reverse
=
True
)
results_sorted
=
sorted
(
results
,
key
=
operator
.
itemgetter
(
*
task_score_names
[
task
]),
reverse
=
True
)
print
(
" | "
.
join
([
f
"
{
col
:
{
col_widths
[
col
]
}}
"
for
col
in
col_names
]))
print
(
" | "
.
join
([
f
"
{
'-'
*
col_widths
[
col
]
}
"
for
col
in
col_names
]))
for
row
in
results_sorted
:
print
(
" | "
.
join
([
f
"
{
row
[
col
]:
{
col_widths
[
col
]
}}
"
for
col
in
col_names
]))
best
=
results_sorted
[
0
]
for
score
in
task_score_names
[
args
.
task
]:
for
score
in
task_score_names
[
task
]:
del
best
[
score
]
best_args
=
[
f
"--
{
k
}
{
v
}
"
for
k
,
v
in
best
.
items
()]
dyn_args
=
[
"--bs"
,
str
(
args
.
bs
)]
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
efeab6a3
...
...
@@ -106,6 +106,9 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY
=
"sshleifer/bart-tiny-random"
MBART_TINY
=
"sshleifer/tiny-mbart"
MARIAN_TINY
=
"sshleifer/tiny-marian-en-de"
BERT_BASE_CASED
=
"bert-base-cased"
PEGASUS_XSUM
=
"google/pegasus-xsum"
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
logger
.
addHandler
(
stream_handler
)
logging
.
disable
(
logging
.
CRITICAL
)
# remove noisy download output from tracebacks
...
...
@@ -284,8 +287,7 @@ class TestSummarizationDistiller(unittest.TestCase):
return
model
@
pytest
.
mark
.
parametrize
(
"model"
,
[
pytest
.
param
(
T5_TINY
),
pytest
.
param
(
BART_TINY
),
pytest
.
param
(
MBART_TINY
)])
def
test_run_eval
(
model
):
def
run_eval_tester
(
model
):
input_file_name
=
Path
(
tempfile
.
mkdtemp
())
/
"utest_input.source"
output_file_name
=
input_file_name
.
parent
/
"utest_output.txt"
assert
not
output_file_name
.
exists
()
...
...
@@ -293,28 +295,39 @@ def test_run_eval(model):
_dump_articles
(
input_file_name
,
articles
)
score_path
=
str
(
Path
(
tempfile
.
mkdtemp
())
/
"scores.json"
)
task
=
"translation_en_to_de"
if
model
==
T5_TINY
else
"summarization"
testargs
=
[
"run_eval.py"
,
model
,
str
(
input_file_name
),
str
(
output_file_name
),
"--score_path"
,
score_path
,
"--task"
,
task
,
"--num_beams"
,
"2"
,
"--length_penalty"
,
"2.0"
,
]
testargs
=
f
"""
run_eval_search.py
{
model
}
{
input_file_name
}
{
output_file_name
}
--score_path
{
score_path
}
--task
{
task
}
--num_beams 2
--length_penalty 2.0
"""
.
split
()
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
run_generate
()
assert
Path
(
output_file_name
).
exists
()
os
.
remove
(
Path
(
output_file_name
))
# test one model quickly (no @slow) to catch simple problems, and do
# extensive testing of functionality with multiple models separately as @slow
def
test_run_eval
():
run_eval_tester
(
T5_TINY
)
# any extra models should go into the list here - can be slow
@
slow
@
pytest
.
mark
.
parametrize
(
"model"
,
[
pytest
.
param
(
T5_TINY
)])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
BART_TINY
,
MBART_TINY
])
def
test_run_eval_slow
(
model
):
run_eval_tester
(
model
)
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
@
slow
@
pytest
.
mark
.
parametrize
(
"model"
,
[
T5_TINY
,
MBART_TINY
])
def
test_run_eval_search
(
model
):
input_file_name
=
Path
(
tempfile
.
mkdtemp
())
/
"utest_input.source"
output_file_name
=
input_file_name
.
parent
/
"utest_output.txt"
...
...
@@ -335,20 +348,17 @@ def test_run_eval_search(model):
_dump_articles
(
input_file_name
,
text
[
"en"
])
_dump_articles
(
reference_path
,
text
[
"de"
])
task
=
"translation_en_to_de"
if
model
==
T5_TINY
else
"summarization"
testargs
=
[
"run_eval_search.py"
,
model
,
str
(
input_file_name
),
str
(
output_file_name
),
"--score_path"
,
score_path
,
"--reference_path"
,
reference_path
,
"--task"
,
task
,
"--search"
,
"num_beams=1:2 length_penalty=0.9:1.0"
,
]
testargs
=
f
"""
run_eval_search.py
--model_name
{
model
}
--data_dir
{
str
(
input_file_name
)
}
--save_dir
{
str
(
output_file_name
)
}
--score_path
{
score_path
}
--reference_path
{
reference_path
}
,
--task
{
task
}
--search num_beams=1:2 length_penalty=0.9:1.0
"""
.
split
()
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
CaptureStdout
()
as
cs
:
run_search
()
...
...
@@ -367,8 +377,8 @@ def test_run_eval_search(model):
@
pytest
.
mark
.
parametrize
(
[
"model"
]
,
[
pytest
.
param
(
T5_TINY
)
,
pytest
.
param
(
BART_TINY
)
,
pytest
.
param
(
MBART_TINY
)
,
pytest
.
param
(
MARIAN_TINY
)
],
"model"
,
[
T5_TINY
,
BART_TINY
,
MBART_TINY
,
MARIAN_TINY
],
)
def
test_finetune
(
model
):
args_d
:
dict
=
CHEAP_ARGS
.
copy
()
...
...
@@ -541,13 +551,13 @@ def test_pack_dataset():
@
pytest
.
mark
.
parametrize
(
[
"tok_name"
]
,
"tok_name"
,
[
pytest
.
param
(
MBART_TINY
)
,
pytest
.
param
(
MARIAN_TINY
)
,
pytest
.
param
(
T5_TINY
)
,
pytest
.
param
(
BART_TINY
)
,
pytest
.
param
(
"google/pegasus-xsum"
)
,
MBART_TINY
,
MARIAN_TINY
,
T5_TINY
,
BART_TINY
,
PEGASUS_XSUM
,
],
)
def
test_seq2seq_dataset_truncation
(
tok_name
):
...
...
@@ -589,7 +599,7 @@ def test_seq2seq_dataset_truncation(tok_name):
break
# No need to test every batch
@
pytest
.
mark
.
parametrize
(
[
"tok"
]
,
[
pytest
.
param
(
BART_TINY
)
,
pytest
.
param
(
"bert-base-cased"
)
])
@
pytest
.
mark
.
parametrize
(
"tok"
,
[
BART_TINY
,
BERT_BASE_CASED
])
def
test_legacy_dataset_truncation
(
tok
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tok
)
tmp_dir
=
make_test_data_dir
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment