Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
79bbcc52
Unverified
Commit
79bbcc52
authored
Jan 08, 2021
by
Patrick von Platen
Committed by
GitHub
Jan 08, 2021
Browse files
[Generation] Fix bug for manual decoder_input_ids + warning message (#9472)
* up * improve style
parent
9e1ea846
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
8 deletions
+14
-8
src/transformers/generation_utils.py
src/transformers/generation_utils.py
+14
-8
No files found.
src/transformers/generation_utils.py
View file @
79bbcc52
...
@@ -379,12 +379,8 @@ class GenerationMixin:
...
@@ -379,12 +379,8 @@ class GenerationMixin:
return
model_kwargs
return
model_kwargs
def
_prepare_decoder_input_ids_for_generation
(
def
_prepare_decoder_input_ids_for_generation
(
self
,
input_ids
:
torch
.
LongTensor
,
decoder_start_token_id
:
int
=
None
,
bos_token_id
:
int
=
None
,
**
model_kwargs
self
,
input_ids
:
torch
.
LongTensor
,
decoder_start_token_id
:
int
=
None
,
bos_token_id
:
int
=
None
)
->
torch
.
LongTensor
:
)
->
torch
.
LongTensor
:
if
"decoder_input_ids"
in
model_kwargs
:
return
model_kwargs
[
"decoder_input_ids"
]
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
(
decoder_start_token_id
,
bos_token_id
)
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
(
decoder_start_token_id
,
bos_token_id
)
decoder_input_ids
=
(
decoder_input_ids
=
(
torch
.
ones
((
input_ids
.
shape
[
0
],
1
),
dtype
=
input_ids
.
dtype
,
device
=
input_ids
.
device
)
torch
.
ones
((
input_ids
.
shape
[
0
],
1
),
dtype
=
input_ids
.
dtype
,
device
=
input_ids
.
device
)
...
@@ -837,13 +833,23 @@ class GenerationMixin:
...
@@ -837,13 +833,23 @@ class GenerationMixin:
model_kwargs
=
self
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
model_kwargs
)
model_kwargs
=
self
.
_prepare_encoder_decoder_kwargs_for_generation
(
input_ids
,
model_kwargs
)
# set input_ids as decoder_input_ids
# set input_ids as decoder_input_ids
input_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
if
"decoder_input_ids"
in
model_kwargs
:
input_ids
,
decoder_start_token_id
=
decoder_start_token_id
,
bos_token_id
=
bos_token_id
,
**
model_kwargs
input_ids
=
model_kwargs
.
pop
(
"decoder_input_ids"
)
)
else
:
input_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
input_ids
,
decoder_start_token_id
=
decoder_start_token_id
,
bos_token_id
=
bos_token_id
)
if
"encoder_outputs"
not
in
model_kwargs
or
not
isinstance
(
model_kwargs
[
"encoder_outputs"
],
ModelOutput
):
if
"encoder_outputs"
not
in
model_kwargs
or
not
isinstance
(
model_kwargs
[
"encoder_outputs"
],
ModelOutput
):
raise
ValueError
(
"Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`."
)
raise
ValueError
(
"Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`."
)
if
input_ids
.
shape
[
-
1
]
>=
max_length
:
input_ids_string
=
"decoder_input_ids"
if
self
.
config
.
is_encoder_decoder
else
"input_ids"
logger
.
warning
(
f
"Input length of
{
input_ids_string
}
is
{
input_ids
.
shape
[
-
1
]
}
, but ``max_length`` is set to
{
max_length
}
."
"This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
)
# determine generation mode
# determine generation mode
is_greedy_gen_mode
=
(
num_beams
==
1
)
and
(
num_beam_groups
==
1
)
and
do_sample
is
False
is_greedy_gen_mode
=
(
num_beams
==
1
)
and
(
num_beam_groups
==
1
)
and
do_sample
is
False
is_sample_gen_mode
=
(
num_beams
==
1
)
and
(
num_beam_groups
==
1
)
and
do_sample
is
True
is_sample_gen_mode
=
(
num_beams
==
1
)
and
(
num_beam_groups
==
1
)
and
do_sample
is
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment