Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9dab39fe
Unverified
Commit
9dab39fe
authored
Jul 21, 2020
by
Sam Shleifer
Committed by
GitHub
Jul 21, 2020
Browse files
seq2seq/run_eval.py can take decoder_start_token_id (#5949)
parent
5b193b39
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
3 deletions
+35
-3
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+1
-0
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+17
-1
src/transformers/tokenization_utils_base.py
src/transformers/tokenization_utils_base.py
+17
-2
No files found.
examples/seq2seq/finetune.py
View file @
9dab39fe
...
@@ -327,6 +327,7 @@ class TranslationModule(SummarizationModule):
...
@@ -327,6 +327,7 @@ class TranslationModule(SummarizationModule):
self
.
dataset_kwargs
[
"tgt_lang"
]
=
hparams
.
tgt_lang
self
.
dataset_kwargs
[
"tgt_lang"
]
=
hparams
.
tgt_lang
if
self
.
model
.
config
.
decoder_start_token_id
is
None
and
isinstance
(
self
.
tokenizer
,
MBartTokenizer
):
if
self
.
model
.
config
.
decoder_start_token_id
is
None
and
isinstance
(
self
.
tokenizer
,
MBartTokenizer
):
self
.
decoder_start_token_id
=
self
.
tokenizer
.
lang_code_to_id
[
hparams
.
tgt_lang
]
self
.
decoder_start_token_id
=
self
.
tokenizer
.
lang_code_to_id
[
hparams
.
tgt_lang
]
self
.
model
.
config
.
decoder_start_token_id
=
self
.
decoder_start_token_id
if
isinstance
(
self
.
tokenizer
,
MBartTokenizer
):
if
isinstance
(
self
.
tokenizer
,
MBartTokenizer
):
self
.
dataset_class
=
MBartDataset
self
.
dataset_class
=
MBartDataset
...
...
examples/seq2seq/run_eval.py
View file @
9dab39fe
...
@@ -30,6 +30,7 @@ def generate_summaries_or_translations(
...
@@ -30,6 +30,7 @@ def generate_summaries_or_translations(
device
:
str
=
DEFAULT_DEVICE
,
device
:
str
=
DEFAULT_DEVICE
,
fp16
=
False
,
fp16
=
False
,
task
=
"summarization"
,
task
=
"summarization"
,
decoder_start_token_id
=
None
,
**
gen_kwargs
,
**
gen_kwargs
,
)
->
None
:
)
->
None
:
fout
=
Path
(
out_file
).
open
(
"w"
,
encoding
=
"utf-8"
)
fout
=
Path
(
out_file
).
open
(
"w"
,
encoding
=
"utf-8"
)
...
@@ -37,6 +38,8 @@ def generate_summaries_or_translations(
...
@@ -37,6 +38,8 @@ def generate_summaries_or_translations(
model
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
model_name
).
to
(
device
)
model
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
model_name
).
to
(
device
)
if
fp16
:
if
fp16
:
model
=
model
.
half
()
model
=
model
.
half
()
if
decoder_start_token_id
is
None
:
decoder_start_token_id
=
gen_kwargs
.
pop
(
"decoder_start_token_id"
,
None
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
...
@@ -48,7 +51,12 @@ def generate_summaries_or_translations(
...
@@ -48,7 +51,12 @@ def generate_summaries_or_translations(
batch
=
[
model
.
config
.
prefix
+
text
for
text
in
batch
]
batch
=
[
model
.
config
.
prefix
+
text
for
text
in
batch
]
batch
=
tokenizer
(
batch
,
return_tensors
=
"pt"
,
truncation
=
True
,
padding
=
"max_length"
).
to
(
device
)
batch
=
tokenizer
(
batch
,
return_tensors
=
"pt"
,
truncation
=
True
,
padding
=
"max_length"
).
to
(
device
)
input_ids
,
attention_mask
=
trim_batch
(
**
batch
,
pad_token_id
=
tokenizer
.
pad_token_id
)
input_ids
,
attention_mask
=
trim_batch
(
**
batch
,
pad_token_id
=
tokenizer
.
pad_token_id
)
summaries
=
model
.
generate
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
**
gen_kwargs
)
summaries
=
model
.
generate
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
decoder_start_token_id
=
decoder_start_token_id
,
**
gen_kwargs
,
)
dec
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
dec
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
for
hypothesis
in
dec
:
for
hypothesis
in
dec
:
fout
.
write
(
hypothesis
+
"
\n
"
)
fout
.
write
(
hypothesis
+
"
\n
"
)
...
@@ -66,6 +74,13 @@ def run_generate():
...
@@ -66,6 +74,13 @@ def run_generate():
parser
.
add_argument
(
"--device"
,
type
=
str
,
required
=
False
,
default
=
DEFAULT_DEVICE
,
help
=
"cuda, cuda:1, cpu etc."
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
required
=
False
,
default
=
DEFAULT_DEVICE
,
help
=
"cuda, cuda:1, cpu etc."
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"typically translation or summarization"
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
default
=
"summarization"
,
help
=
"typically translation or summarization"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
parser
.
add_argument
(
"--bs"
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
"batch size"
)
parser
.
add_argument
(
"--decoder_start_token_id"
,
type
=
int
,
default
=
None
,
required
=
False
,
help
=
"decoder_start_token_id (otherwise will look at config)"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--n_obs"
,
type
=
int
,
default
=-
1
,
required
=
False
,
help
=
"How many observations. Defaults to all."
"--n_obs"
,
type
=
int
,
default
=-
1
,
required
=
False
,
help
=
"How many observations. Defaults to all."
)
)
...
@@ -83,6 +98,7 @@ def run_generate():
...
@@ -83,6 +98,7 @@ def run_generate():
device
=
args
.
device
,
device
=
args
.
device
,
fp16
=
args
.
fp16
,
fp16
=
args
.
fp16
,
task
=
args
.
task
,
task
=
args
.
task
,
decoder_start_token_id
=
args
.
decoder_start_token_id
,
)
)
if
args
.
reference_path
is
None
:
if
args
.
reference_path
is
None
:
return
return
...
...
src/transformers/tokenization_utils_base.py
View file @
9dab39fe
...
@@ -2255,8 +2255,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
...
@@ -2255,8 +2255,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
return
encoded_inputs
return
encoded_inputs
def
batch_decode
(
self
,
sequences
:
List
[
List
[
int
]],
**
kwargs
)
->
List
[
str
]:
def
batch_decode
(
return
[
self
.
decode
(
seq
,
**
kwargs
)
for
seq
in
sequences
]
self
,
sequences
:
List
[
List
[
int
]],
skip_special_tokens
:
bool
=
False
,
clean_up_tokenization_spaces
:
bool
=
True
)
->
List
[
str
]:
"""
Convert a list of lists of token ids into a list of strings by calling decode.
Args:
token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
skip_special_tokens: if set to True, will replace special tokens.
clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
"""
return
[
self
.
decode
(
seq
,
skip_special_tokens
=
skip_special_tokens
,
clean_up_tokenization_spaces
=
clean_up_tokenization_spaces
)
for
seq
in
sequences
]
def
decode
(
def
decode
(
self
,
token_ids
:
List
[
int
],
skip_special_tokens
:
bool
=
False
,
clean_up_tokenization_spaces
:
bool
=
True
self
,
token_ids
:
List
[
int
],
skip_special_tokens
:
bool
=
False
,
clean_up_tokenization_spaces
:
bool
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment