Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a4fc0c80
Unverified
Commit
a4fc0c80
authored
Sep 04, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 04, 2020
Browse files
[s2s] run_eval.py parses generate_kwargs (#6948)
parent
6078b120
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
24 deletions
+36
-24
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+11
-23
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+4
-0
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+21
-1
No files found.
examples/seq2seq/run_eval.py
View file @
a4fc0c80
...
@@ -15,9 +15,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
...
@@ -15,9 +15,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
logger
=
getLogger
(
__name__
)
logger
=
getLogger
(
__name__
)
try
:
try
:
from
.utils
import
calculate_bleu
,
calculate_rouge
,
use_task_specific_params
from
.utils
import
calculate_bleu
,
calculate_rouge
,
parse_numeric_cl_kwargs
,
use_task_specific_params
except
ImportError
:
except
ImportError
:
from
utils
import
calculate_bleu
,
calculate_rouge
,
use_task_specific_params
from
utils
import
calculate_bleu
,
calculate_rouge
,
parse_numeric_cl_kwargs
,
use_task_specific_params
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
@@ -36,7 +36,6 @@ def generate_summaries_or_translations(
...
@@ -36,7 +36,6 @@ def generate_summaries_or_translations(
device
:
str
=
DEFAULT_DEVICE
,
device
:
str
=
DEFAULT_DEVICE
,
fp16
=
False
,
fp16
=
False
,
task
=
"summarization"
,
task
=
"summarization"
,
decoder_start_token_id
=
None
,
**
generate_kwargs
,
**
generate_kwargs
,
)
->
Dict
:
)
->
Dict
:
"""Save model.generate results to <out_file>, and return how long it took."""
"""Save model.generate results to <out_file>, and return how long it took."""
...
@@ -59,7 +58,6 @@ def generate_summaries_or_translations(
...
@@ -59,7 +58,6 @@ def generate_summaries_or_translations(
summaries
=
model
.
generate
(
summaries
=
model
.
generate
(
input_ids
=
batch
.
input_ids
,
input_ids
=
batch
.
input_ids
,
attention_mask
=
batch
.
attention_mask
,
attention_mask
=
batch
.
attention_mask
,
decoder_start_token_id
=
decoder_start_token_id
,
**
generate_kwargs
,
**
generate_kwargs
,
)
)
dec
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
dec
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
...
@@ -77,30 +75,20 @@ def run_generate():
...
@@ -77,30 +75,20 @@ def run_generate():
parser
.
add_argument
(
"model_name"
,
type
=
str
,
help
=
"like facebook/bart-large-cnn,t5-base, etc."
)
parser
.
add_argument
(
"model_name"
,
type
=
str
,
help
=
"like facebook/bart-large-cnn,t5-base, etc."
)
parser
.
add_argument
(
"input_path"
,
type
=
str
,
help
=
"like cnn_dm/test.source"
)
parser
.
add_argument
(
"input_path"
,
type
=
str
,
help
=
"like cnn_dm/test.source"
)
parser
.
add_argument
(
"save_path"
,
type
=
str
,
help
=
"where to save summaries"
)
parser
.
add_argument
(
"save_path"
,
type
=
str
,
help
=
"where to save summaries"
)
parser
.
add_argument
(
"--reference_path"
,
type
=
str
,
required
=
False
,
help
=
"like cnn_dm/test.target"
)
parser
.
add_argument
(
"--reference_path"
,
type
=
str
,
required
=
False
,
help
=
"like cnn_dm/test_reference_summaries.txt"
)
parser
.
add_argument
(
"--score_path"
,
type
=
str
,
required
=
False
,
default
=
"metrics.json"
,
help
=
"where to save metrics"
)
parser
.
add_argument
(
"--score_path"
,
type
=
str
,
required
=
False
,
default
=
"metrics.json"
,
help
=
"where to save the rouge score in json format"
,
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
required
=
False
,
default
=
DEFAULT_DEVICE
,
help
=
"cuda, cuda:1, cpu etc."
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
required
=
False
,
default
=
DEFAULT_DEVICE
,
help
=
"cuda, cuda:1, cpu etc."
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"
typically translation or summarization
"
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"
used for task_specific_params + metrics
"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
parser
.
add_argument
(
"--decoder_start_token_id"
,
type
=
int
,
default
=
None
,
required
=
False
,
help
=
"Defaults to using config"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--n_obs"
,
type
=
int
,
default
=-
1
,
required
=
False
,
help
=
"How many observations. Defaults to all."
"--n_obs"
,
type
=
int
,
default
=-
1
,
required
=
False
,
help
=
"How many observations. Defaults to all."
)
)
parser
.
add_argument
(
"--fp16"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--fp16"
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args
,
rest
=
parser
.
parse_known_args
()
parsed
=
parse_numeric_cl_kwargs
(
rest
)
if
parsed
:
print
(
f
"parsed the following generate kwargs:
{
parsed
}
"
)
examples
=
[
" "
+
x
.
rstrip
()
if
"t5"
in
args
.
model_name
else
x
.
rstrip
()
for
x
in
open
(
args
.
input_path
).
readlines
()]
examples
=
[
" "
+
x
.
rstrip
()
if
"t5"
in
args
.
model_name
else
x
.
rstrip
()
for
x
in
open
(
args
.
input_path
).
readlines
()]
if
args
.
n_obs
>
0
:
if
args
.
n_obs
>
0
:
examples
=
examples
[:
args
.
n_obs
]
examples
=
examples
[:
args
.
n_obs
]
...
@@ -115,7 +103,7 @@ def run_generate():
...
@@ -115,7 +103,7 @@ def run_generate():
device
=
args
.
device
,
device
=
args
.
device
,
fp16
=
args
.
fp16
,
fp16
=
args
.
fp16
,
task
=
args
.
task
,
task
=
args
.
task
,
decoder_start_token_id
=
args
.
decoder_start_token_i
d
,
**
parse
d
,
)
)
if
args
.
reference_path
is
None
:
if
args
.
reference_path
is
None
:
return
return
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
a4fc0c80
...
@@ -300,6 +300,10 @@ def test_run_eval(model):
...
@@ -300,6 +300,10 @@ def test_run_eval(model):
score_path
,
score_path
,
"--task"
,
"--task"
,
task
,
task
,
"--num_beams"
,
"2"
,
"--length_penalty"
,
"2.0"
,
]
]
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
run_generate
()
run_generate
()
...
...
examples/seq2seq/utils.py
View file @
a4fc0c80
...
@@ -5,7 +5,7 @@ import os
...
@@ -5,7 +5,7 @@ import os
import
pickle
import
pickle
from
logging
import
getLogger
from
logging
import
getLogger
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Callable
,
Dict
,
Iterable
,
List
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
Union
import
git
import
git
import
numpy
as
np
import
numpy
as
np
...
@@ -309,3 +309,23 @@ def assert_not_all_frozen(model):
...
@@ -309,3 +309,23 @@ def assert_not_all_frozen(model):
model_grads
:
List
[
bool
]
=
list
(
grad_status
(
model
))
model_grads
:
List
[
bool
]
=
list
(
grad_status
(
model
))
npars
=
len
(
model_grads
)
npars
=
len
(
model_grads
)
assert
any
(
model_grads
),
f
"none of
{
npars
}
weights require grad"
assert
any
(
model_grads
),
f
"none of
{
npars
}
weights require grad"
# CLI Parsing utils
def
parse_numeric_cl_kwargs
(
unparsed_args
:
List
[
str
])
->
Dict
[
str
,
Union
[
int
,
float
]]:
"""Parse an argv list of unspecified command line args to a dict. Assumes all values are numeric."""
result
=
{}
assert
len
(
unparsed_args
)
%
2
==
0
,
f
"got odd number of unparsed args:
{
unparsed_args
}
"
num_pairs
=
len
(
unparsed_args
)
//
2
for
pair_num
in
range
(
num_pairs
):
i
=
2
*
pair_num
assert
unparsed_args
[
i
].
startswith
(
"--"
)
try
:
value
=
int
(
unparsed_args
[
i
+
1
])
except
ValueError
:
value
=
float
(
unparsed_args
[
i
+
1
])
# this can raise another informative ValueError
result
[
unparsed_args
[
i
][
2
:]]
=
value
return
result
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment