Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
95119ad7
Unverified
Commit
95119ad7
authored
Dec 17, 2021
by
Patrick von Platen
Committed by
GitHub
Dec 17, 2021
Browse files
[Generate] Correct input_ids detection (#14815)
* [Generate] Correct input_ids detection * correct
parent
bdbe3df8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
1 deletion
+26
-1
src/transformers/generation_utils.py
src/transformers/generation_utils.py
+1
-1
tests/test_generation_utils.py
tests/test_generation_utils.py
+25
-0
No files found.
src/transformers/generation_utils.py
View file @
95119ad7
...
...
@@ -457,7 +457,7 @@ class GenerationMixin:
pad_token_id
:
int
,
eos_token_id
:
int
,
)
->
torch
.
LongTensor
:
is_input_ids
=
isinstance
(
inputs
,
torch
.
LongTensor
)
and
len
(
inputs
.
shape
)
==
2
is_input_ids
=
len
(
inputs
.
shape
)
==
2
and
inputs
.
dtype
in
[
torch
.
int
,
torch
.
long
]
is_pad_token_in_inputs
=
(
pad_token_id
is
not
None
)
and
(
pad_token_id
in
inputs
)
is_pad_token_not_equal_to_eos_token_id
=
(
eos_token_id
is
None
)
or
(
(
eos_token_id
is
not
None
)
and
(
pad_token_id
!=
eos_token_id
)
...
...
tests/test_generation_utils.py
View file @
95119ad7
...
...
@@ -1719,6 +1719,31 @@ class GenerationIntegrationTests(unittest.TestCase):
# make sure model generated correctly until `max_length`
self
.
assertEqual
(
output_sequences
.
shape
,
(
1
,
5
))
def test_encoder_decoder_generate_attention_mask(self):
    """Check that `generate` builds the attention mask correctly for padded batches.

    Runs beam search on a single article and on a padded batch containing that
    same article, then asserts the beam `sequences_scores` for the shared input
    agree to within 1e-4. If padding positions were not masked out, the batched
    scores would diverge from the unbatched ones.
    """
    # Two articles of very different lengths so that batching forces padding.
    articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"]
    tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
    # need extreme generation values here to force this test
    # to fail when `attention_mask` is not correctly treated in generate
    model = BartForConditionalGeneration.from_pretrained(
        "hf-internal-testing/tiny-random-bart", max_length=50, num_beams=5, num_return_sequences=5
    ).to(torch_device)
    # Disable EOS so generation always runs to `max_length`, keeping the
    # batched and unbatched runs directly comparable.
    model.config.eos_token_id = None
    # Unbatched input: first article only.
    input_ids = tokenizer(articles[0], return_tensors="pt").input_ids.to(torch_device)
    # Batched input: both articles, padded to a common length.
    input_ids_batched = tokenizer(articles, padding=True, return_tensors="pt").input_ids.to(torch_device)
    output_sequences_batched = model.generate(
        input_ids=input_ids_batched, return_dict_in_generate=True, output_scores=True
    )
    output_sequences = model.generate(input_ids=input_ids, return_dict_in_generate=True, output_scores=True)
    batched_out = output_sequences_batched.sequences_scores
    out = output_sequences.sequences_scores
    # The first `num_return_sequences` (5) scores of the batch belong to the
    # first article; they must match the unbatched run's scores.
    diff = (batched_out[:5].sum() - out.sum()).abs()
    self.assertTrue(diff < 1e-4)
def
test_decoder_generate_with_inputs_embeds
(
self
):
article
=
"""I need input_ids to generate"""
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment