chenpangpang / transformers / Commits / 61518e2d

Unverified commit 61518e2d, authored Aug 26, 2020 by Sam Shleifer, committed via GitHub on Aug 26, 2020

[s2s] run_eval.py QOL improvements and cleanup (#6746)

parent 434936f3
Showing 2 changed files with 51 additions and 22 deletions (+51 / -22):

examples/seq2seq/run_eval.py               +38 / -20
examples/seq2seq/test_seq2seq_examples.py  +13 / -2
examples/seq2seq/run_eval.py (view file @ 61518e2d)
 import argparse
 import json
+import time
+import warnings
+from logging import getLogger
 from pathlib import Path
+from typing import Dict, List
 
 import torch
 from tqdm import tqdm
@@ -8,10 +12,12 @@ from tqdm import tqdm
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 
+logger = getLogger(__name__)
+
 try:
-    from .utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
+    from .utils import calculate_bleu, calculate_rouge, use_task_specific_params
 except ImportError:
-    from utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
+    from utils import calculate_bleu, calculate_rouge, use_task_specific_params
 
 DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -23,7 +29,7 @@ def chunks(lst, n):
 
 
 def generate_summaries_or_translations(
-    examples: list,
+    examples: List[str],
     out_file: str,
     model_name: str,
     batch_size: int = 8,
@@ -31,36 +37,39 @@ def generate_summaries_or_translations(
     fp16=False,
     task="summarization",
     decoder_start_token_id=None,
-    **gen_kwargs,
-) -> None:
+    **generate_kwargs,
+) -> Dict:
+    """Save model.generate results to <out_file>, and return how long it took."""
     fout = Path(out_file).open("w", encoding="utf-8")
     model_name = str(model_name)
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
     if fp16:
         model = model.half()
-    if decoder_start_token_id is None:
-        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
+    logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.
 
-    # update config with summarization specific params
+    start_time = time.time()
+    # update config with task specific params
     use_task_specific_params(model, task)
-
-    for batch in tqdm(list(chunks(examples, batch_size))):
+    for examples_chunk in tqdm(list(chunks(examples, batch_size))):
         if "t5" in model_name:
-            batch = [model.config.prefix + text for text in batch]
-        batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
-        input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
+            examples_chunk = [model.config.prefix + text for text in examples_chunk]
+        batch = tokenizer(examples_chunk, return_tensors="pt", truncation=True, padding="longest").to(device)
         summaries = model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
+            input_ids=batch.input_ids,
+            attention_mask=batch.attention_mask,
             decoder_start_token_id=decoder_start_token_id,
-            **gen_kwargs,
+            **generate_kwargs,
         )
         dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         for hypothesis in dec:
             fout.write(hypothesis + "\n")
             fout.flush()
+    fout.close()
+    runtime = time.time() - start_time
+    n_obs = len(examples)
+    return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
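The padding change in the hunk above is what lets trim_batch be dropped: padding="longest" pads each chunk only to its longest example, whereas the old padding="max_length" padded every example to the model maximum and then relied on trim_batch to cut the batch back down. A minimal sketch of the difference, assuming the public t5-small checkpoint as a readily available stand-in for whatever model run_eval.py is given:

from transformers import AutoTokenizer

# t5-small is used here only as a stand-in checkpoint for illustration.
tok = AutoTokenizer.from_pretrained("t5-small")
texts = ["translate English to German: short", "translate English to German: a somewhat longer example sentence"]

longest = tok(texts, return_tensors="pt", truncation=True, padding="longest")
maxlen = tok(texts, return_tensors="pt", truncation=True, padding="max_length")

# "longest" pads only to the longest example in the chunk ...
print(longest.input_ids.shape)   # e.g. (2, 14), depending on tokenization
# ... while "max_length" pads every example to tok.model_max_length (512 for t5-small),
# which is what the removed trim_batch(**batch, pad_token_id=...) call used to undo.
print(maxlen.input_ids.shape)    # (2, 512)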
@@ -70,7 +79,13 @@ def run_generate():
     parser.add_argument("save_path", type=str, help="where to save summaries")
     parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
-    parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
+    parser.add_argument(
+        "--score_path",
+        type=str,
+        required=False,
+        default="metrics.json",
+        help="where to save the rouge score in json format",
+    )
     parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
     parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
     parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
@@ -79,7 +94,7 @@ def run_generate():
         type=int,
         default=None,
         required=False,
-        help="decoder_start_token_id (otherwise will look at config)",
+        help="Defaults to using config",
     )
     parser.add_argument(
         "--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
@@ -90,7 +105,9 @@ def run_generate():
     if args.n_obs > 0:
         examples = examples[: args.n_obs]
     Path(args.save_path).parent.mkdir(exist_ok=True)
-    generate_summaries_or_translations(
+    if args.reference_path is None and Path(args.score_path).exists():
+        warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
+    runtime_metrics = generate_summaries_or_translations(
         examples,
         args.save_path,
         args.model_name,
@@ -107,9 +124,10 @@ def run_generate():
     output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
     reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
     scores: dict = score_fn(output_lns, reference_lns)
+    scores.update(runtime_metrics)
     print(scores)
     if args.score_path is not None:
-        json.dump(scores, open(args.score_path, "w+"))
+        json.dump(scores, open(args.score_path, "w"))
     return scores
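Taken together, the run_eval.py changes make generate_summaries_or_translations time itself and return a small metrics dict, which run_generate then merges into the ROUGE/BLEU scores before writing them to --score_path (now defaulting to metrics.json). A rough sketch of that flow, with a placeholder generator and score dict standing in for the real generate_summaries_or_translations and calculate_rouge/calculate_bleu from examples/seq2seq/utils.py:

import json
import time
from pathlib import Path

def generate_placeholder(examples, out_file):
    # Placeholder for generate_summaries_or_translations: write one output line per example
    # and return the same runtime-metrics dict shape that the patched function returns.
    start_time = time.time()
    Path(out_file).write_text("\n".join(e.upper() for e in examples) + "\n")
    runtime = time.time() - start_time
    n_obs = len(examples)
    return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))

examples = ["a tiny article", "another tiny article"]
runtime_metrics = generate_placeholder(examples, "utest_output.txt")

scores = {"rouge2": 0.0}        # stand-in for score_fn(output_lns, reference_lns)
scores.update(runtime_metrics)  # run_generate now folds the timing info into the score dict
json.dump(scores, open("metrics.json", "w"))  # --score_path defaults to metrics.json after this commit
print(scores)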

examples/seq2seq/test_seq2seq_examples.py (view file @ 61518e2d)
@@ -252,13 +252,24 @@ class TestSummarizationDistiller(unittest.TestCase):
 @pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
-def test_run_eval_bart(model):
+def test_run_eval(model):
     input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
     output_file_name = input_file_name.parent / "utest_output.txt"
     assert not output_file_name.exists()
     articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
     _dump_articles(input_file_name, articles)
-    testargs = ["run_eval.py", model, str(input_file_name), str(output_file_name)]  # TODO: test score_path
+    score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
+    task = "translation_en_to_de" if model == T5_TINY else "summarization"
+    testargs = [
+        "run_eval.py",
+        model,
+        str(input_file_name),
+        str(output_file_name),
+        "--score_path",
+        score_path,
+        "--task",
+        task,
+    ]
     with patch.object(sys, "argv", testargs):
         run_generate()
         assert Path(output_file_name).exists()
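For context, the ["model"] parametrization above runs the same test body once per tiny checkpoint; a self-contained sketch of that pytest pattern, using hypothetical checkpoint names in place of the module's T5_TINY/BART_TINY/MBART_TINY constants:

import pytest

# Hypothetical stand-ins for the tiny test checkpoints defined elsewhere in test_seq2seq_examples.py.
T5_TINY, BART_TINY, MBART_TINY = "t5-tiny", "bart-tiny", "mbart-tiny"

@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
def test_run_eval(model):
    # Mirrors the task selection added by this commit: the T5 checkpoint is exercised as a
    # translation model, the BART/mBART checkpoints as summarizers.
    task = "translation_en_to_de" if model == T5_TINY else "summarization"
    assert task in {"translation_en_to_de", "summarization"}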