Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a4fc0c80
Unverified
Commit
a4fc0c80
authored
Sep 04, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 04, 2020
Browse files
[s2s] run_eval.py parses generate_kwargs (#6948)
parent
6078b120
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
24 deletions
+36
-24
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+11
-23
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+4
-0
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+21
-1
No files found.
examples/seq2seq/run_eval.py
View file @
a4fc0c80
...
@@ -15,9 +15,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
...
@@ -15,9 +15,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
logger
=
getLogger
(
__name__
)
logger
=
getLogger
(
__name__
)
try
:
try
:
from
.utils
import
calculate_bleu
,
calculate_rouge
,
use_task_specific_params
from
.utils
import
calculate_bleu
,
calculate_rouge
,
parse_numeric_cl_kwargs
,
use_task_specific_params
except
ImportError
:
except
ImportError
:
from
utils
import
calculate_bleu
,
calculate_rouge
,
use_task_specific_params
from
utils
import
calculate_bleu
,
calculate_rouge
,
parse_numeric_cl_kwargs
,
use_task_specific_params
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
@@ -36,7 +36,6 @@ def generate_summaries_or_translations(
...
@@ -36,7 +36,6 @@ def generate_summaries_or_translations(
device
:
str
=
DEFAULT_DEVICE
,
device
:
str
=
DEFAULT_DEVICE
,
fp16
=
False
,
fp16
=
False
,
task
=
"summarization"
,
task
=
"summarization"
,
decoder_start_token_id
=
None
,
**
generate_kwargs
,
**
generate_kwargs
,
)
->
Dict
:
)
->
Dict
:
"""Save model.generate results to <out_file>, and return how long it took."""
"""Save model.generate results to <out_file>, and return how long it took."""
...
@@ -59,7 +58,6 @@ def generate_summaries_or_translations(
...
@@ -59,7 +58,6 @@ def generate_summaries_or_translations(
summaries
=
model
.
generate
(
summaries
=
model
.
generate
(
input_ids
=
batch
.
input_ids
,
input_ids
=
batch
.
input_ids
,
attention_mask
=
batch
.
attention_mask
,
attention_mask
=
batch
.
attention_mask
,
decoder_start_token_id
=
decoder_start_token_id
,
**
generate_kwargs
,
**
generate_kwargs
,
)
)
dec
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
dec
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
...
@@ -77,30 +75,20 @@ def run_generate():
...
@@ -77,30 +75,20 @@ def run_generate():
parser
.
add_argument
(
"model_name"
,
type
=
str
,
help
=
"like facebook/bart-large-cnn,t5-base, etc."
)
parser
.
add_argument
(
"model_name"
,
type
=
str
,
help
=
"like facebook/bart-large-cnn,t5-base, etc."
)
parser
.
add_argument
(
"input_path"
,
type
=
str
,
help
=
"like cnn_dm/test.source"
)
parser
.
add_argument
(
"input_path"
,
type
=
str
,
help
=
"like cnn_dm/test.source"
)
parser
.
add_argument
(
"save_path"
,
type
=
str
,
help
=
"where to save summaries"
)
parser
.
add_argument
(
"save_path"
,
type
=
str
,
help
=
"where to save summaries"
)
parser
.
add_argument
(
"--reference_path"
,
type
=
str
,
required
=
False
,
help
=
"like cnn_dm/test.target"
)
parser
.
add_argument
(
"--reference_path"
,
type
=
str
,
required
=
False
,
help
=
"like cnn_dm/test_reference_summaries.txt"
)
parser
.
add_argument
(
"--score_path"
,
type
=
str
,
required
=
False
,
default
=
"metrics.json"
,
help
=
"where to save metrics"
)
parser
.
add_argument
(
"--score_path"
,
type
=
str
,
required
=
False
,
default
=
"metrics.json"
,
help
=
"where to save the rouge score in json format"
,
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
required
=
False
,
default
=
DEFAULT_DEVICE
,
help
=
"cuda, cuda:1, cpu etc."
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
required
=
False
,
default
=
DEFAULT_DEVICE
,
help
=
"cuda, cuda:1, cpu etc."
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"
typically translation or summarization
"
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"
used for task_specific_params + metrics
"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
parser
.
add_argument
(
"--decoder_start_token_id"
,
type
=
int
,
default
=
None
,
required
=
False
,
help
=
"Defaults to using config"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--n_obs"
,
type
=
int
,
default
=-
1
,
required
=
False
,
help
=
"How many observations. Defaults to all."
"--n_obs"
,
type
=
int
,
default
=-
1
,
required
=
False
,
help
=
"How many observations. Defaults to all."
)
)
parser
.
add_argument
(
"--fp16"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--fp16"
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args
,
rest
=
parser
.
parse_known_args
()
parsed
=
parse_numeric_cl_kwargs
(
rest
)
if
parsed
:
print
(
f
"parsed the following generate kwargs:
{
parsed
}
"
)
examples
=
[
" "
+
x
.
rstrip
()
if
"t5"
in
args
.
model_name
else
x
.
rstrip
()
for
x
in
open
(
args
.
input_path
).
readlines
()]
examples
=
[
" "
+
x
.
rstrip
()
if
"t5"
in
args
.
model_name
else
x
.
rstrip
()
for
x
in
open
(
args
.
input_path
).
readlines
()]
if
args
.
n_obs
>
0
:
if
args
.
n_obs
>
0
:
examples
=
examples
[:
args
.
n_obs
]
examples
=
examples
[:
args
.
n_obs
]
...
@@ -115,7 +103,7 @@ def run_generate():
...
@@ -115,7 +103,7 @@ def run_generate():
device
=
args
.
device
,
device
=
args
.
device
,
fp16
=
args
.
fp16
,
fp16
=
args
.
fp16
,
task
=
args
.
task
,
task
=
args
.
task
,
decoder_start_token_id
=
args
.
decoder_start_token_i
d
,
**
parse
d
,
)
)
if
args
.
reference_path
is
None
:
if
args
.
reference_path
is
None
:
return
return
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
a4fc0c80
...
@@ -300,6 +300,10 @@ def test_run_eval(model):
...
@@ -300,6 +300,10 @@ def test_run_eval(model):
score_path
,
score_path
,
"--task"
,
"--task"
,
task
,
task
,
"--num_beams"
,
"2"
,
"--length_penalty"
,
"2.0"
,
]
]
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
run_generate
()
run_generate
()
...
...
examples/seq2seq/utils.py
View file @
a4fc0c80
...
@@ -5,7 +5,7 @@ import os
...
@@ -5,7 +5,7 @@ import os
import
pickle
import
pickle
from
logging
import
getLogger
from
logging
import
getLogger
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Callable
,
Dict
,
Iterable
,
List
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
Union
import
git
import
git
import
numpy
as
np
import
numpy
as
np
...
@@ -309,3 +309,23 @@ def assert_not_all_frozen(model):
...
@@ -309,3 +309,23 @@ def assert_not_all_frozen(model):
model_grads
:
List
[
bool
]
=
list
(
grad_status
(
model
))
model_grads
:
List
[
bool
]
=
list
(
grad_status
(
model
))
npars
=
len
(
model_grads
)
npars
=
len
(
model_grads
)
assert
any
(
model_grads
),
f
"none of
{
npars
}
weights require grad"
assert
any
(
model_grads
),
f
"none of
{
npars
}
weights require grad"
# CLI Parsing utils
def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:
    """Parse an argv list of unspecified command line args to a dict. Assumes all values are numeric.

    Both ``--name value`` and ``--name=value`` spellings are accepted and may be mixed in one list.

    Args:
        unparsed_args: leftover argv tokens, e.g. ``["--num_beams", "2", "--length_penalty=2.0"]``.

    Returns:
        A dict mapping each option name (leading ``--`` removed) to its value. A value is an ``int``
        when the text parses as one, otherwise a ``float``.

    Raises:
        ValueError: if a token that should be an option does not start with ``--``, if an option has
            no value, or if a value is not numeric.
    """
    result = {}
    tokens = iter(unparsed_args)
    for flag in tokens:
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
        if not flag.startswith("--"):
            raise ValueError(f"expected an option starting with '--', got {flag!r} in unparsed args: {unparsed_args}")
        name, has_inline_value, raw_value = flag[2:].partition("=")
        if not has_inline_value:
            # `--name value` spelling: the value is the next token of the same iterator.
            try:
                raw_value = next(tokens)
            except StopIteration:
                raise ValueError(f"missing value for option {flag!r} in unparsed args: {unparsed_args}") from None
        try:
            value = int(raw_value)
        except ValueError:
            value = float(raw_value)  # this can raise another informative ValueError
        result[name] = value
    return result
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment