chenpangpang / transformers / Commits

Commit fa4bdb0a (unverified)
Authored Feb 13, 2023 by Joao Gante; committed by GitHub on Feb 13, 2023

Generate: correct default model input creation for decoder-only models (#21580)
Parent: edc1e734

Showing 4 changed files with 109 additions and 17 deletions (+109, -17)
src/transformers/generation/tf_utils.py      (+23, -8)
src/transformers/generation/utils.py         (+23, -9)
tests/generation/test_utils.py               (+35, -0)
tests/models/blip_2/test_modeling_blip_2.py  (+28, -0)
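In short, the commit makes decoder-only models build a sensible default `input_ids` when the caller passes only `inputs_embeds` to `generate()`. A rough usage sketch (not part of the commit; the tiny checkpoint and the `model.transformer.wte` embedding lookup follow the new test added below, so treat it as illustrative rather than canonical):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id

input_ids = tokenizer(["Hello world"], return_tensors="pt").input_ids
inputs_embeds = model.transformer.wte(input_ids)  # GPT-2's token embedding lookup

# With this change, `input_ids` can be omitted: a default row of BOS tokens is created
# internally (one per batch entry), so automations like attention-mask creation still work.
outputs = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=5)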
src/transformers/generation/tf_utils.py
...
@@ -845,8 +845,7 @@ class TFGenerationMixin:
                 model_kwargs=model_kwargs,
             )
         else:
-            # if decoder-only then inputs_tensor has to be `input_ids`
-            input_ids = inputs_tensor
+            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
         # 7. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = shape_list(input_ids)[-1]
...
@@ -1214,20 +1213,34 @@ class TFGenerationMixin:
                     "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
                     "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
                 )
+                # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
+                # the attention mask) can rely on the actual model input.
+                model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
+                    inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
+                )
             else:
                 if inputs is not None:
                     raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
-                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
+            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
 
         # 4. if `inputs` is still None, try to create `input_ids` from BOS token
-        if inputs is None:
-            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
+        inputs = self._maybe_initialize_input_ids_for_generation(
+            inputs, bos_token_id, model_kwargs.get("encoder_outputs")
+        )
 
         return inputs, input_name, model_kwargs
 
-    def _prepare_input_ids_for_generation(
-        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
+    def _maybe_initialize_input_ids_for_generation(
+        self,
+        inputs: Optional[tf.Tensor] = None,
+        bos_token_id: Optional[int] = None,
+        encoder_outputs: Optional[ModelOutput] = None,
+        batch_size: Optional[int] = None,
     ) -> tf.Tensor:
+        """Initializes input ids for generation, if necessary."""
+        if inputs is not None:
+            return inputs
+
         if self.config.is_encoder_decoder and encoder_outputs is not None:
             # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
             shape = encoder_outputs.last_hidden_state.shape[:-1]
...
@@ -1235,7 +1248,9 @@ class TFGenerationMixin:
         if bos_token_id is None:
             raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
-        return tf.ones((1, 1), dtype=tf.int32) * bos_token_id
+        batch_size = batch_size if batch_size is not None else 1
+        return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id
 
     @staticmethod
     def _extract_past_from_model_output(outputs: ModelOutput):
...
src/transformers/generation/utils.py
...
@@ -541,15 +541,20 @@ class GenerationMixin:
                     "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
                     "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
                 )
+                # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
+                # the attention mask) can rely on the actual model input.
+                model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
+                    inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
+                )
             else:
                 if inputs is not None:
                     raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
-                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
+            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
 
         # 4. if `inputs` is still None, try to create `input_ids` from BOS token
-        if inputs is None:
-            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
+        inputs = self._maybe_initialize_input_ids_for_generation(
+            inputs, bos_token_id, model_kwargs.get("encoder_outputs")
+        )
 
         return inputs, input_name, model_kwargs
 
     def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
...
@@ -558,9 +563,17 @@ class GenerationMixin:
         """
         return logits
 
-    def _prepare_input_ids_for_generation(
-        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
+    def _maybe_initialize_input_ids_for_generation(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        bos_token_id: Optional[int] = None,
+        encoder_outputs: Optional[ModelOutput] = None,
+        batch_size: Optional[int] = None,
     ) -> torch.LongTensor:
+        """Initializes input ids for generation, if necessary."""
+        if inputs is not None:
+            return inputs
+
         if self.config.is_encoder_decoder and encoder_outputs is not None:
             # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
             shape = encoder_outputs.last_hidden_state.size()[:-1]
...
@@ -568,7 +581,9 @@ class GenerationMixin:
         if bos_token_id is None:
             raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
-        return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
+        batch_size = batch_size if batch_size is not None else 1
+        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
 
     def _prepare_attention_mask_for_generation(
         self,
...
@@ -1258,8 +1273,7 @@ class GenerationMixin:
                 device=inputs_tensor.device,
             )
         else:
-            # if decoder-only then inputs_tensor has to be `input_ids`
-            input_ids = inputs_tensor
+            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
...
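To make the decoder-only `inputs_embeds` path concrete: when only embeddings are supplied, the diff fills `model_kwargs["input_ids"]` with one BOS token per batch entry via `_maybe_initialize_input_ids_for_generation`. A standalone sketch of that default (the helper name below is hypothetical and device placement is omitted):

import torch

def default_input_ids_for_embeds(inputs_embeds: torch.Tensor, bos_token_id: int) -> torch.Tensor:
    # One BOS token per batch row, matching the batch dimension of `inputs_embeds`, so later
    # steps (e.g. attention-mask creation and length bookkeeping) have real `input_ids` to inspect.
    batch_size = inputs_embeds.shape[0]
    return torch.ones((batch_size, 1), dtype=torch.long) * bos_token_id

dummy_embeds = torch.randn(2, 5, 8)  # (batch, sequence, hidden) stand-in embeddings
print(default_input_ids_for_embeds(dummy_embeds, bos_token_id=50256))  # tensor([[50256], [50256]])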
tests/generation/test_utils.py
...
@@ -2488,3 +2488,38 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         eos_token_id = [846, 198]
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         self.assertTrue(expectation == len(generated_tokens[0]))
+
+    def test_generate_from_inputs_embeds_decoder_only(self):
+        # Note: the model must support generation from input embeddings
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text, text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+        # Traditional way of generating text
+        outputs_from_ids = model.generate(input_ids)
+        self.assertEqual(outputs_from_ids.shape, (2, 20))
+
+        # Same thing, but from input embeddings
+        inputs_embeds = model.transformer.wte(input_ids)
+        outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
+        self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
+
+        # But if we pass different inputs_embeds, we should get different outputs
+        torch.manual_seed(0)
+        random_embeds = torch.rand_like(inputs_embeds)
+        outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
+        with self.assertRaises(AssertionError):
+            self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
+
+        # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
+        outputs_from_embeds_wo_ids = model.generate(
+            inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
+        )
+        self.assertListEqual(
+            outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
+            outputs_from_embeds_wo_ids[:, 1:].tolist(),
+        )
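The new test can be run in isolation; assuming an editable install of transformers with pytest available, a typical invocation would be:

python -m pytest tests/generation/test_utils.py -k test_generate_from_inputs_embeds_decoder_only -v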
tests/models/blip_2/test_modeling_blip_2.py
...
@@ -797,6 +797,20 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         )
         self.assertEqual(generated_text, "it's not a city, it's a beach")
 
+    def test_inference_opt_batched(self):
+        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+        model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
+
+        # prepare image
+        image = prepare_img()
+        inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
+
+        predictions = model.generate(**inputs)
+
+        # Test output
+        self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
+        self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
+
     def test_inference_t5(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
         model = Blip2ForConditionalGeneration.from_pretrained(
...
@@ -827,3 +841,17 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
             [0, 3, 7, 152, 67, 839, 1],
         )
         self.assertEqual(generated_text, "san diego")
+
+    def test_inference_t5_batched(self):
+        processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+        model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
+
+        # prepare image
+        image = prepare_img()
+        inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
+
+        predictions = model.generate(**inputs)
+
+        # Test output
+        self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
+        self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
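The two batched BLIP-2 checks are integration tests against full checkpoints, so they download large weights and, by this repository's convention, usually only run when slow tests are enabled. A plausible invocation:

RUN_SLOW=1 python -m pytest tests/models/blip_2/test_modeling_blip_2.py -k "batched" -v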