Project: chenpangpang/transformers
Commit: 79bbcc52 (unverified)
Authored: Jan 08, 2021 by Patrick von Platen
Committed: Jan 08, 2021 by GitHub
[Generation] Fix bug for manual decoder_input_ids + warning message (#9472)
* up
* improve style
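In words: before this commit, `decoder_input_ids` passed manually to `generate()` traveled through `**model_kwargs` into `_prepare_decoder_input_ids_for_generation`, where they could be mishandled; the fix pops them in `generate()` itself, and a new warning fires when the input is already at least `max_length` tokens long. Below is a minimal usage sketch of the fixed behavior; the checkpoint, prompt, and decoder prefix are illustrative, not taken from the commit.

```python
# Minimal sketch: manually passing decoder_input_ids to generate().
# Checkpoint and strings are illustrative, not from the commit.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer(
    "translate English to German: How old are you?", return_tensors="pt"
).input_ids

# Build a decoder prefix by hand; for T5 it conventionally starts with the
# decoder start token (the pad token).
start = torch.tensor([[model.config.decoder_start_token_id]])
prefix = tokenizer("Wie", return_tensors="pt", add_special_tokens=False).input_ids
decoder_input_ids = torch.cat([start, prefix], dim=-1)

# After this commit, generate() pops decoder_input_ids from model_kwargs and
# continues decoding from this prefix instead of from the bare start token.
outputs = model.generate(input_ids, decoder_input_ids=decoder_input_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```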
Parent: 9e1ea846
Changes: 1 changed file with 14 additions and 8 deletions

src/transformers/generation_utils.py (+14, -8)
src/transformers/generation_utils.py (view file @ 79bbcc52)
@@ -379,12 +379,8 @@ class GenerationMixin:
         return model_kwargs

     def _prepare_decoder_input_ids_for_generation(
-        self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None, **model_kwargs
+        self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None
     ) -> torch.LongTensor:
-
-        if "decoder_input_ids" in model_kwargs:
-            return model_kwargs["decoder_input_ids"]
-
         decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
         decoder_input_ids = (
             torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
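After this hunk, the helper only builds the default start sequence; responsibility for user-supplied `decoder_input_ids` moves to `generate()` (see the next hunk). Roughly, the simplified helper is equivalent to the standalone sketch below; the function name and the trailing multiplication by the start token id, which falls just past the hunk's context window, are my own framing.

```python
# Rough standalone equivalent of the simplified helper (my paraphrase; the
# multiplication by the start token id sits just past the hunk's context).
import torch

def default_decoder_input_ids(
    input_ids: torch.LongTensor, decoder_start_token_id: int
) -> torch.LongTensor:
    # One decoder start token per batch item, matching the encoder inputs'
    # dtype and device.
    return (
        torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
        * decoder_start_token_id
    )
```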
@@ -837,13 +833,23 @@ class GenerationMixin:
             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)

             # set input_ids as decoder_input_ids
-            input_ids = self._prepare_decoder_input_ids_for_generation(
-                input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs
-            )
+            if "decoder_input_ids" in model_kwargs:
+                input_ids = model_kwargs.pop("decoder_input_ids")
+            else:
+                input_ids = self._prepare_decoder_input_ids_for_generation(
+                    input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
+                )

             if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
                 raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")

+        if input_ids.shape[-1] >= max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}."
+                "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
+            )
+
         # determine generation mode
         is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
         is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
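The second half of this hunk adds the new length warning. A hypothetical way to trigger it, assuming any causal LM checkpoint (names illustrative, not from the commit):

```python
# Hypothetical trigger for the new warning: the prompt is already as long as
# max_length, so generation has no room left. Checkpoint is illustrative.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer(
    "a prompt that is comfortably longer than five tokens", return_tensors="pt"
).input_ids

# Logs: "Input length of input_ids is N, but ``max_length`` is set to 5. ..."
model.generate(input_ids, max_length=5)
```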