Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
61518e2d
Unverified
Commit
61518e2d
authored
Aug 26, 2020
by
Sam Shleifer
Committed by
GitHub
Aug 26, 2020
Browse files
[s2s] run_eval.py QOL improvements and cleanup(#6746)
parent
434936f3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
22 deletions
+51
-22
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+38
-20
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+13
-2
No files found.
examples/seq2seq/run_eval.py
View file @
61518e2d
import
argparse
import
json
import
time
import
warnings
from
logging
import
getLogger
from
pathlib
import
Path
from
typing
import
Dict
,
List
import
torch
from
tqdm
import
tqdm
...
...
@@ -8,10 +12,12 @@ from tqdm import tqdm
from
transformers
import
AutoModelForSeq2SeqLM
,
AutoTokenizer
logger
=
getLogger
(
__name__
)
try
:
from
.utils
import
calculate_bleu
,
calculate_rouge
,
trim_batch
,
use_task_specific_params
from
.utils
import
calculate_bleu
,
calculate_rouge
,
use_task_specific_params
except
ImportError
:
from
utils
import
calculate_bleu
,
calculate_rouge
,
trim_batch
,
use_task_specific_params
from
utils
import
calculate_bleu
,
calculate_rouge
,
use_task_specific_params
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
...
@@ -23,7 +29,7 @@ def chunks(lst, n):
def
generate_summaries_or_translations
(
examples
:
l
ist
,
examples
:
L
ist
[
str
]
,
out_file
:
str
,
model_name
:
str
,
batch_size
:
int
=
8
,
...
...
@@ -31,36 +37,39 @@ def generate_summaries_or_translations(
fp16
=
False
,
task
=
"summarization"
,
decoder_start_token_id
=
None
,
**
gen_kwargs
,
)
->
None
:
**
generate_kwargs
,
)
->
Dict
:
"""Save model.generate results to <out_file>, and return how long it took."""
fout
=
Path
(
out_file
).
open
(
"w"
,
encoding
=
"utf-8"
)
model_name
=
str
(
model_name
)
model
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
model_name
).
to
(
device
)
if
fp16
:
model
=
model
.
half
()
if
decoder_start_token_id
is
None
:
decoder_start_token_id
=
gen_kwargs
.
pop
(
"decoder_start_token_id"
,
None
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
logger
.
info
(
f
"Inferred tokenizer type:
{
tokenizer
.
__class__
}
"
)
# if this is wrong, check config.model_type.
# update config with summarization specific params
start_time
=
time
.
time
()
# update config with task specific params
use_task_specific_params
(
model
,
task
)
for
batch
in
tqdm
(
list
(
chunks
(
examples
,
batch_size
))):
for
examples_chunk
in
tqdm
(
list
(
chunks
(
examples
,
batch_size
))):
if
"t5"
in
model_name
:
batch
=
[
model
.
config
.
prefix
+
text
for
text
in
batch
]
batch
=
tokenizer
(
batch
,
return_tensors
=
"pt"
,
truncation
=
True
,
padding
=
"max_length"
).
to
(
device
)
input_ids
,
attention_mask
=
trim_batch
(
**
batch
,
pad_token_id
=
tokenizer
.
pad_token_id
)
examples_chunk
=
[
model
.
config
.
prefix
+
text
for
text
in
examples_chunk
]
batch
=
tokenizer
(
examples_chunk
,
return_tensors
=
"pt"
,
truncation
=
True
,
padding
=
"longest"
).
to
(
device
)
summaries
=
model
.
generate
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
batch
.
input_ids
,
attention_mask
=
batch
.
attention_mask
,
decoder_start_token_id
=
decoder_start_token_id
,
**
gen_kwargs
,
**
gen
erate
_kwargs
,
)
dec
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
for
hypothesis
in
dec
:
fout
.
write
(
hypothesis
+
"
\n
"
)
fout
.
flush
()
fout
.
close
()
runtime
=
time
.
time
()
-
start_time
n_obs
=
len
(
examples
)
return
dict
(
n_obs
=
n_obs
,
runtime
=
runtime
,
seconds_per_sample
=
round
(
runtime
/
n_obs
,
4
))
def
run_generate
():
...
...
@@ -70,7 +79,13 @@ def run_generate():
parser
.
add_argument
(
"save_path"
,
type
=
str
,
help
=
"where to save summaries"
)
parser
.
add_argument
(
"--reference_path"
,
type
=
str
,
required
=
False
,
help
=
"like cnn_dm/test_reference_summaries.txt"
)
parser
.
add_argument
(
"--score_path"
,
type
=
str
,
required
=
False
,
help
=
"where to save the rouge score in json format"
)
parser
.
add_argument
(
"--score_path"
,
type
=
str
,
required
=
False
,
default
=
"metrics.json"
,
help
=
"where to save the rouge score in json format"
,
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
required
=
False
,
default
=
DEFAULT_DEVICE
,
help
=
"cuda, cuda:1, cpu etc."
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"typically translation or summarization"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
...
...
@@ -79,7 +94,7 @@ def run_generate():
type
=
int
,
default
=
None
,
required
=
False
,
help
=
"
decoder_start_token_id (otherwise will look at
config
)
"
,
help
=
"
Defaults to using
config"
,
)
parser
.
add_argument
(
"--n_obs"
,
type
=
int
,
default
=-
1
,
required
=
False
,
help
=
"How many observations. Defaults to all."
...
...
@@ -90,7 +105,9 @@ def run_generate():
if
args
.
n_obs
>
0
:
examples
=
examples
[:
args
.
n_obs
]
Path
(
args
.
save_path
).
parent
.
mkdir
(
exist_ok
=
True
)
generate_summaries_or_translations
(
if
args
.
reference_path
is
None
and
Path
(
args
.
score_path
).
exists
():
warnings
.
warn
(
f
"score_path
{
args
.
score_path
}
will be overwritten unless you type ctrl-c."
)
runtime_metrics
=
generate_summaries_or_translations
(
examples
,
args
.
save_path
,
args
.
model_name
,
...
...
@@ -107,9 +124,10 @@ def run_generate():
output_lns
=
[
x
.
rstrip
()
for
x
in
open
(
args
.
save_path
).
readlines
()]
reference_lns
=
[
x
.
rstrip
()
for
x
in
open
(
args
.
reference_path
).
readlines
()][:
len
(
output_lns
)]
scores
:
dict
=
score_fn
(
output_lns
,
reference_lns
)
scores
.
update
(
runtime_metrics
)
print
(
scores
)
if
args
.
score_path
is
not
None
:
json
.
dump
(
scores
,
open
(
args
.
score_path
,
"w
+
"
))
json
.
dump
(
scores
,
open
(
args
.
score_path
,
"w"
))
return
scores
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
61518e2d
...
...
@@ -252,13 +252,24 @@ class TestSummarizationDistiller(unittest.TestCase):
@
pytest
.
mark
.
parametrize
([
"model"
],
[
pytest
.
param
(
T5_TINY
),
pytest
.
param
(
BART_TINY
),
pytest
.
param
(
MBART_TINY
)])
def
test_run_eval
_bart
(
model
):
def
test_run_eval
(
model
):
input_file_name
=
Path
(
tempfile
.
mkdtemp
())
/
"utest_input.source"
output_file_name
=
input_file_name
.
parent
/
"utest_output.txt"
assert
not
output_file_name
.
exists
()
articles
=
[
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."
]
_dump_articles
(
input_file_name
,
articles
)
testargs
=
[
"run_eval.py"
,
model
,
str
(
input_file_name
),
str
(
output_file_name
)]
# TODO: test score_path
score_path
=
str
(
Path
(
tempfile
.
mkdtemp
())
/
"scores.json"
)
task
=
"translation_en_to_de"
if
model
==
T5_TINY
else
"summarization"
testargs
=
[
"run_eval.py"
,
model
,
str
(
input_file_name
),
str
(
output_file_name
),
"--score_path"
,
score_path
,
"--task"
,
task
,
]
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
run_generate
()
assert
Path
(
output_file_name
).
exists
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment