chenpangpang / transformers / Commits / 92ce53aa

Unverified commit 92ce53aa, authored Feb 01, 2023 by Joao Gante, committed by GitHub on Feb 01, 2023

Generate: decoder-only models can generate with `inputs_embeds` (#21405)
parent e5db7051

Showing 3 changed files with 68 additions and 52 deletions (+68, -52)
src/transformers/generation/utils.py            +23  -30
src/transformers/models/gpt2/modeling_gpt2.py   +18  -9
tests/generation/test_utils.py                  +27  -13
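For context, a minimal usage sketch of what this commit enables for decoder-only models. It mirrors the new test added in tests/generation/test_utils.py below; the tiny test checkpoint is used purely for illustration:

from transformers import AutoModelForCausalLM, AutoTokenizer

# After this commit, GPT-2 (and any decoder-only model whose `prepare_inputs_for_generation`
# forwards `inputs_embeds`) can start generation from embeddings instead of token ids.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# `inputs_embeds` is used in the first generation step; `input_ids` supplies the prompt tokens
# that are prepended to the returned sequence.
outputs = model.generate(input_ids, inputs_embeds=inputs_embeds, max_length=10)
print(tokenizer.decode(outputs[0]))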
src/transformers/generation/utils.py (view file @ 92ce53aa)
@@ -519,47 +519,40 @@ class GenerationMixin:
         inputs_kwarg = model_kwargs.pop(input_name, None)
         if inputs_kwarg is not None and inputs is not None:
             raise ValueError(
-                f"`inputs`: {inputs}` were passed alongside "
-                f"{input_name} which is not allowed."
+                f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."
                 f"Make sure to either pass {inputs} or {input_name}=..."
             )
         elif inputs_kwarg is not None:
             inputs = inputs_kwarg

-        # 3. models with `input_ids` can also make use of `inputs_embeds`
-        if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
-            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
-
-        # 4. Only encoder-decoder models can have non `input_ids` input format
-        if not self.config.is_encoder_decoder and input_name != "input_ids":
-            raise ValueError(
-                f"If {input_name} is passed as model-specific keyword "
-                "input then model has to be an encoder-decoder and not a "
-                f"{self.__class__.__name__}."
-            )
+        # 3. In the presence of `inputs_embeds` for text models:
+        # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
+        # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
+        # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
+        # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
+        # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
+        if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
+            if not self.config.is_encoder_decoder:
+                has_inputs_embeds_forwarding = "inputs_embeds" in set(
+                    inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
+                )
+                if not has_inputs_embeds_forwarding:
+                    raise ValueError(
+                        f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
+                        "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
+                        "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
+                    )
+            else:
+                if inputs is not None:
+                    raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
+                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

-        # 5. if `inputs` is still None, try to create `input_ids` from BOS token
+        # 4. if `inputs` is still None, try to create `input_ids` from BOS token
         if inputs is None:
             inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

         return inputs, input_name, model_kwargs

-    def _can_retrieve_inputs_from_name(
-        self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
-    ) -> torch.Tensor:
-        """
-        If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved
-        from name
-        """
-        can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
-            inspect.signature(self.forward).parameters.keys()
-        )
-
-        if can_retrieve_inputs and inputs is not None:
-            raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")
-
-        return can_retrieve_inputs
-
     def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         """
         Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
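For readers adapting other decoder-only architectures, a standalone sketch of the capability check added above; the helper name `supports_inputs_embeds_forwarding` is illustrative and not part of the library:

import inspect

def supports_inputs_embeds_forwarding(model) -> bool:
    # Mirrors the check performed by `generate()`: a decoder-only model may receive
    # `inputs_embeds` only if its `prepare_inputs_for_generation` signature accepts it.
    params = inspect.signature(model.prepare_inputs_for_generation).parameters
    return "inputs_embeds" in set(params.keys())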
src/transformers/models/gpt2/modeling_gpt2.py (view file @ 92ce53aa)
@@ -981,7 +981,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past_key_values:
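With `inputs_embeds` added to the signature, the capability check in generation/utils.py (sketched earlier) now succeeds for GPT-2. A quick, illustrative verification:

import inspect
from transformers import GPT2LMHeadModel

# The widened signature is exactly what `generate()` inspects before accepting `inputs_embeds`.
params = inspect.signature(GPT2LMHeadModel.prepare_inputs_for_generation).parameters
print("inputs_embeds" in set(params.keys()))  # True after this commit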
@@ -1000,14 +1000,23 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
                 position_ids = position_ids[:, -1].unsqueeze(-1)
         else:
             position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs

     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
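To make the "first generation step only" behavior concrete, here is a hedged trace of what the updated method returns before and after the cache exists. This is a sketch assuming the tiny test checkpoint; the token ids are arbitrary:

import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = torch.tensor([[10, 20, 30]])
inputs_embeds = model.transformer.wte(input_ids)

# Step 1: no cache yet, so the embeddings drive the forward pass and `input_ids` is dropped.
step1 = model.prepare_inputs_for_generation(
    input_ids, past_key_values=None, inputs_embeds=inputs_embeds, use_cache=True
)
assert "inputs_embeds" in step1 and "input_ids" not in step1

# Later steps: once `past_key_values` exists, only the freshly generated token ids are fed.
out = model(**step1)
step2 = model.prepare_inputs_for_generation(
    torch.tensor([[42]]), past_key_values=out.past_key_values, inputs_embeds=inputs_embeds
)
assert "input_ids" in step2 and "inputs_embeds" not in step2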
tests/generation/test_utils.py (view file @ 92ce53aa)
@@ -2359,17 +2359,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
         self.assertTrue(diff < 1e-4)

-    def test_decoder_generate_with_inputs_embeds(self):
-        article = """I need input_ids to generate"""
-        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=5).to(torch_device)
-        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-        inputs_embeds = model.get_input_embeddings()(input_ids)
-
-        # cannot generate from `inputs_embeds` for decoder only
-        with self.assertRaises(ValueError):
-            model.generate(inputs_embeds=inputs_embeds)
-
     def test_generate_input_ids_as_kwarg(self):
         article = """I need input_ids to generate"""
         tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
@@ -2417,8 +2406,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
     def test_generate_too_many_encoder_kwargs(self):
         article = """I need input_ids to generate"""
-        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
+        tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=10).to(
+            torch_device
+        )
         input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
         with self.assertRaises(ValueError):
             model.generate(input_ids=input_ids, inputs_embeds=input_ids)
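For encoder-decoder models, the rule this test exercises is the other branch of the new logic in generation/utils.py: `inputs_embeds` replaces `input_ids` as the encoder input, so passing both is rejected. A sketch, assuming the same tiny BART checkpoint used in the test:

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")

input_ids = tokenizer("I need input_ids to generate", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# Either input form works on its own; the embeddings feed the encoder in place of token ids.
model.generate(inputs_embeds=inputs_embeds, max_length=10)

# Passing both raises: "You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one."
# model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds)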
@@ -3128,3 +3119,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
         eos_token_id = [873]
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         self.assertTrue(expectation == len(generated_tokens[0]))
+
+    def test_generate_from_input_embeds_decoder_only(self):
+        # Note: the model must support generation from input embeddings
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+
+        text = "Hello world"
+        input_ids = tokenizer.encode(text, return_tensors="pt")
+
+        # Traditional way of generating text
+        outputs_from_ids = model.generate(input_ids)
+
+        # Same thing, but from input embeddings
+        inputs_embeds = model.transformer.wte(input_ids)
+        outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
+        self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
+
+        # But if we pass different inputs_embeds, we should get different outputs
+        torch.manual_seed(0)
+        random_embeds = torch.rand_like(inputs_embeds)
+        outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
+        with self.assertRaises(AssertionError):
+            self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
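The new test round-trips the model's own embedding table, but the first-step embeddings need not correspond to real tokens at all, as long as the shape is (batch, sequence_length, hidden_size). A hedged sketch of feeding perturbed embeddings, in the spirit of the test's `rand_like` check:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

input_ids = tokenizer.encode("Hello world", return_tensors="pt")
token_embeds = model.get_input_embeddings()(input_ids)

# Same shape as the real embeddings, but perturbed: generation generally diverges from the
# unperturbed continuation, which is what the test above asserts with random embeddings.
noisy_embeds = token_embeds + 0.1 * torch.randn_like(token_embeds)
outputs = model.generate(input_ids, inputs_embeds=noisy_embeds, max_length=10)
print(tokenizer.decode(outputs[0]))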