chenpangpang / transformers

Unverified commit 9dab39fe, authored Jul 21, 2020 by Sam Shleifer, committed by GitHub on Jul 21, 2020

seq2seq/run_eval.py can take decoder_start_token_id (#5949)

Parent: 5b193b39
Showing 3 changed files, with 35 additions and 3 deletions (+35 / -3):

  examples/seq2seq/finetune.py                  +1  / -0
  examples/seq2seq/run_eval.py                  +17 / -1
  src/transformers/tokenization_utils_base.py   +17 / -2
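In outline, the commit threads a new --decoder_start_token_id option through examples/seq2seq/run_eval.py and into model.generate(). The sketch below is not code from the commit; it is a condensed, self-contained illustration of the intended precedence (an explicit argument wins, then a value passed via gen_kwargs, and a remaining None lets generation fall back to the model config). The id 250020 is a placeholder.

# Condensed illustration of the flow added by this commit (not verbatim from the diff).
import argparse


def resolve_decoder_start_token_id(decoder_start_token_id=None, **gen_kwargs):
    # Mirrors run_eval.py: prefer the explicit argument, then a value passed via
    # gen_kwargs; a remaining None lets model.generate fall back to the model config.
    if decoder_start_token_id is None:
        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)
    return decoder_start_token_id


parser = argparse.ArgumentParser()
parser.add_argument("--decoder_start_token_id", type=int, default=None, required=False)
args = parser.parse_args(["--decoder_start_token_id", "250020"])  # placeholder id

print(resolve_decoder_start_token_id(decoder_start_token_id=args.decoder_start_token_id))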
examples/seq2seq/finetune.py

@@ -327,6 +327,7 @@ class TranslationModule(SummarizationModule):
         self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
         if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
             self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+            self.model.config.decoder_start_token_id = self.decoder_start_token_id
         if isinstance(self.tokenizer, MBartTokenizer):
             self.dataset_class = MBartDataset
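For context only (not part of the commit): a minimal sketch of what the added finetune.py line accomplishes for mBART, assuming the facebook/mbart-large-en-ro checkpoint and the "ro_RO" target language.

# Illustration: copy the tokenizer-derived start token onto the model config so that
# later generate() calls pick it up by default. Checkpoint and language are assumptions.
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro")

if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
    # mBART decoding must start with the target-language code token, e.g. "ro_RO".
    decoder_start_token_id = tokenizer.lang_code_to_id["ro_RO"]
    model.config.decoder_start_token_id = decoder_start_token_id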
examples/seq2seq/run_eval.py

@@ -30,6 +30,7 @@ def generate_summaries_or_translations(
     device: str = DEFAULT_DEVICE,
     fp16=False,
     task="summarization",
+    decoder_start_token_id=None,
     **gen_kwargs,
 ) -> None:
     fout = Path(out_file).open("w", encoding="utf-8")

@@ -37,6 +38,8 @@ def generate_summaries_or_translations(
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
     if fp16:
         model = model.half()
+    if decoder_start_token_id is None:
+        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)
     tokenizer = AutoTokenizer.from_pretrained(model_name)

@@ -48,7 +51,12 @@ def generate_summaries_or_translations(
             batch = [model.config.prefix + text for text in batch]
         batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
         input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
-        summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
+        summaries = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_start_token_id=decoder_start_token_id,
+            **gen_kwargs,
+        )
         dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         for hypothesis in dec:
             fout.write(hypothesis + "\n")

@@ -66,6 +74,13 @@ def run_generate():
     parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
     parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
     parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
+    parser.add_argument(
+        "--decoder_start_token_id",
+        type=int,
+        default=None,
+        required=False,
+        help="decoder_start_token_id (otherwise will look at config)",
+    )
     parser.add_argument("--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all.")

@@ -83,6 +98,7 @@ def run_generate():
         device=args.device,
         fp16=args.fp16,
         task=args.task,
+        decoder_start_token_id=args.decoder_start_token_id,
     )
     if args.reference_path is None:
         return
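Again for illustration only (checkpoint name and sample sentence are assumptions, not from the commit): the effect of forcing decoder_start_token_id at generation time, which is what run_eval.py now does when --decoder_start_token_id is passed.

# Illustration: explicitly choosing the decoder start token for an mBART translation model.
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer

model_name = "facebook/mbart-large-en-ro"  # assumed checkpoint
tokenizer = MBartTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

batch = tokenizer(["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt")
summaries = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"],  # target-language token for mBART
)
print(tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False))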
src/transformers/tokenization_utils_base.py

@@ -2255,8 +2255,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         return encoded_inputs

-    def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]:
-        return [self.decode(seq, **kwargs) for seq in sequences]
+    def batch_decode(
+        self, sequences: List[List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
+    ) -> List[str]:
+        """
+        Convert a list of lists of token ids into a list of strings by calling decode.
+
+        Args:
+            token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
+            skip_special_tokens: if set to True, will replace special tokens.
+            clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
+        """
+        return [
+            self.decode(
+                seq, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces
+            )
+            for seq in sequences
+        ]

     def decode(
         self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
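A brief usage note (illustrative; the checkpoint is an assumption): with the new signature, callers control special-token stripping and whitespace cleanup for a whole batch, matching how run_eval.py decodes its generations.

# Illustration of the new batch_decode keywords; the checkpoint is an assumption.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
ids = tokenizer(["Paris is the capital of France."])["input_ids"]  # list of lists of token ids
texts = tokenizer.batch_decode(
    ids,
    skip_special_tokens=True,            # drop <s>, </s>, <pad>, ...
    clean_up_tokenization_spaces=False,  # keep raw spacing, as run_eval.py now does
)
print(texts)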