Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
92ce53aa
Unverified
Commit
92ce53aa
authored
Feb 01, 2023
by
Joao Gante
Committed by
GitHub
Feb 01, 2023
Browse files
Generate: decoder-only models can generate with `inputs_embeds` (#21405)
parent
e5db7051
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
68 additions
and
52 deletions
+68
-52
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+23
-30
src/transformers/models/gpt2/modeling_gpt2.py
src/transformers/models/gpt2/modeling_gpt2.py
+18
-9
tests/generation/test_utils.py
tests/generation/test_utils.py
+27
-13
No files found.
src/transformers/generation/utils.py
View file @
92ce53aa
...
@@ -519,47 +519,40 @@ class GenerationMixin:
...
@@ -519,47 +519,40 @@ class GenerationMixin:
inputs_kwarg
=
model_kwargs
.
pop
(
input_name
,
None
)
inputs_kwarg
=
model_kwargs
.
pop
(
input_name
,
None
)
if
inputs_kwarg
is
not
None
and
inputs
is
not
None
:
if
inputs_kwarg
is
not
None
and
inputs
is
not
None
:
raise
ValueError
(
raise
ValueError
(
f
"`inputs`:
{
inputs
}
` were passed alongside "
f
"`inputs`:
{
inputs
}
` were passed alongside
{
input_name
}
which is not allowed."
f
"
{
input_name
}
which is not allowed."
f
"Make sure to either pass
{
inputs
}
or
{
input_name
}
=..."
f
"Make sure to either pass
{
inputs
}
or
{
input_name
}
=..."
)
)
elif
inputs_kwarg
is
not
None
:
elif
inputs_kwarg
is
not
None
:
inputs
=
inputs_kwarg
inputs
=
inputs_kwarg
# 3. models with `input_ids` can also make use of `inputs_embeds`
# 3. In the presence of `inputs_embeds` for text models:
if
self
.
_can_retrieve_inputs_from_name
(
inputs
,
"inputs_embeds"
,
model_kwargs
):
# - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
inputs
,
input_name
=
model_kwargs
[
"inputs_embeds"
],
"inputs_embeds"
# doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
# input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
# 4. Only encoder-decoder models can have non `input_ids` input format
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
if
not
self
.
config
.
is_encoder_decoder
and
input_name
!=
"input_ids"
:
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
if
input_name
==
"input_ids"
and
"inputs_embeds"
in
model_kwargs
:
if
not
self
.
config
.
is_encoder_decoder
:
has_inputs_embeds_forwarding
=
"inputs_embeds"
in
set
(
inspect
.
signature
(
self
.
prepare_inputs_for_generation
).
parameters
.
keys
()
)
if
not
has_inputs_embeds_forwarding
:
raise
ValueError
(
raise
ValueError
(
f
"If
{
input_name
}
is passed as model-specific keyword
"
f
"You passed `inputs_embeds` to `.generate()`, but the model class
{
self
.
__class__
.
__name__
}
"
"input then model has to be an encoder-decoder and not a
"
"doesn't have its forwarding implemented. See the GPT2 implementation for an example
"
f
"
{
self
.
__class__
.
__name__
}
.
"
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!
"
)
)
else
:
if
inputs
is
not
None
:
raise
ValueError
(
"You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one."
)
inputs
,
input_name
=
model_kwargs
[
"inputs_embeds"
],
"inputs_embeds"
#
5
. if `inputs` is still None, try to create `input_ids` from BOS token
#
4
. if `inputs` is still None, try to create `input_ids` from BOS token
if
inputs
is
None
:
if
inputs
is
None
:
inputs
=
self
.
_prepare_input_ids_for_generation
(
bos_token_id
,
model_kwargs
.
get
(
"encoder_outputs"
))
inputs
=
self
.
_prepare_input_ids_for_generation
(
bos_token_id
,
model_kwargs
.
get
(
"encoder_outputs"
))
return
inputs
,
input_name
,
model_kwargs
return
inputs
,
input_name
,
model_kwargs
def
_can_retrieve_inputs_from_name
(
self
,
inputs
:
Optional
[
torch
.
Tensor
],
name
:
str
,
model_kwargs
:
Dict
[
str
,
torch
.
Tensor
]
)
->
torch
.
Tensor
:
"""
If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved
from name
"""
can_retrieve_inputs
=
model_kwargs
.
get
(
name
,
None
)
is
not
None
and
name
in
set
(
inspect
.
signature
(
self
.
forward
).
parameters
.
keys
()
)
if
can_retrieve_inputs
and
inputs
is
not
None
:
raise
ValueError
(
f
"Cannot only pass one of
{
name
}
and
{
self
.
main_input_name
}
"
)
return
can_retrieve_inputs
def
adjust_logits_during_generation
(
self
,
logits
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
FloatTensor
:
def
adjust_logits_during_generation
(
self
,
logits
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
FloatTensor
:
"""
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
...
...
src/transformers/models/gpt2/modeling_gpt2.py
View file @
92ce53aa
...
@@ -981,7 +981,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
...
@@ -981,7 +981,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
def
set_output_embeddings
(
self
,
new_embeddings
):
def
set_output_embeddings
(
self
,
new_embeddings
):
self
.
lm_head
=
new_embeddings
self
.
lm_head
=
new_embeddings
def
prepare_inputs_for_generation
(
self
,
input_ids
,
past_key_values
=
None
,
**
kwargs
):
def
prepare_inputs_for_generation
(
self
,
input_ids
,
past_key_values
=
None
,
inputs_embeds
=
None
,
**
kwargs
):
token_type_ids
=
kwargs
.
get
(
"token_type_ids"
,
None
)
token_type_ids
=
kwargs
.
get
(
"token_type_ids"
,
None
)
# only last token for inputs_ids if past is defined in kwargs
# only last token for inputs_ids if past is defined in kwargs
if
past_key_values
:
if
past_key_values
:
...
@@ -1000,14 +1000,23 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
...
@@ -1000,14 +1000,23 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
position_ids
=
position_ids
[:,
-
1
].
unsqueeze
(
-
1
)
position_ids
=
position_ids
[:,
-
1
].
unsqueeze
(
-
1
)
else
:
else
:
position_ids
=
None
position_ids
=
None
return
{
"input_ids"
:
input_ids
,
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if
inputs_embeds
is
not
None
and
past_key_values
is
None
:
model_inputs
=
{
"inputs_embeds"
:
inputs_embeds
}
else
:
model_inputs
=
{
"input_ids"
:
input_ids
}
model_inputs
.
update
(
{
"past_key_values"
:
past_key_values
,
"past_key_values"
:
past_key_values
,
"use_cache"
:
kwargs
.
get
(
"use_cache"
),
"use_cache"
:
kwargs
.
get
(
"use_cache"
),
"position_ids"
:
position_ids
,
"position_ids"
:
position_ids
,
"attention_mask"
:
attention_mask
,
"attention_mask"
:
attention_mask
,
"token_type_ids"
:
token_type_ids
,
"token_type_ids"
:
token_type_ids
,
}
}
)
return
model_inputs
@
add_start_docstrings_to_model_forward
(
GPT2_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_model_forward
(
GPT2_INPUTS_DOCSTRING
)
@
add_code_sample_docstrings
(
@
add_code_sample_docstrings
(
...
...
tests/generation/test_utils.py
View file @
92ce53aa
...
@@ -2359,17 +2359,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -2359,17 +2359,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertTrue
(
diff
<
1e-4
)
self
.
assertTrue
(
diff
<
1e-4
)
def
test_decoder_generate_with_inputs_embeds
(
self
):
article
=
"""I need input_ids to generate"""
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
model
=
GPT2LMHeadModel
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
,
max_length
=
5
).
to
(
torch_device
)
input_ids
=
tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
inputs_embeds
=
model
.
get_input_embeddings
()(
input_ids
)
# cannot generate from `inputs_embeds` for decoder only
with
self
.
assertRaises
(
ValueError
):
model
.
generate
(
inputs_embeds
=
inputs_embeds
)
def
test_generate_input_ids_as_kwarg
(
self
):
def
test_generate_input_ids_as_kwarg
(
self
):
article
=
"""I need input_ids to generate"""
article
=
"""I need input_ids to generate"""
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
...
@@ -2417,8 +2406,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -2417,8 +2406,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
def
test_generate_too_many_encoder_kwargs
(
self
):
def
test_generate_too_many_encoder_kwargs
(
self
):
article
=
"""I need input_ids to generate"""
article
=
"""I need input_ids to generate"""
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
tokenizer
=
BartTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
model
=
GPT2LMHeadModel
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
,
max_length
=
10
).
to
(
torch_device
)
model
=
BartForConditionalGeneration
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
,
max_length
=
10
).
to
(
torch_device
)
input_ids
=
tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
input_ids
=
tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
with
self
.
assertRaises
(
ValueError
):
with
self
.
assertRaises
(
ValueError
):
model
.
generate
(
input_ids
=
input_ids
,
inputs_embeds
=
input_ids
)
model
.
generate
(
input_ids
=
input_ids
,
inputs_embeds
=
input_ids
)
...
@@ -3128,3 +3119,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -3128,3 +3119,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
eos_token_id
=
[
873
]
eos_token_id
=
[
873
]
generated_tokens
=
model
.
generate
(
**
tokens
,
eos_token_id
=
eos_token_id
,
**
generation_kwargs
)
generated_tokens
=
model
.
generate
(
**
tokens
,
eos_token_id
=
eos_token_id
,
**
generation_kwargs
)
self
.
assertTrue
(
expectation
==
len
(
generated_tokens
[
0
]))
self
.
assertTrue
(
expectation
==
len
(
generated_tokens
[
0
]))
def
test_generate_from_input_embeds_decoder_only
(
self
):
# Note: the model must support generation from input embeddings
model
=
AutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
text
=
"Hello world"
input_ids
=
tokenizer
.
encode
(
text
,
return_tensors
=
"pt"
)
# Traditional way of generating text
outputs_from_ids
=
model
.
generate
(
input_ids
)
# Same thing, but from input embeddings
inputs_embeds
=
model
.
transformer
.
wte
(
input_ids
)
outputs_from_embeds
=
model
.
generate
(
input_ids
,
inputs_embeds
=
inputs_embeds
)
self
.
assertListEqual
(
outputs_from_ids
.
tolist
(),
outputs_from_embeds
.
tolist
())
# But if we pass different inputs_embeds, we should get different outputs
torch
.
manual_seed
(
0
)
random_embeds
=
torch
.
rand_like
(
inputs_embeds
)
outputs_from_rand_embeds
=
model
.
generate
(
input_ids
,
inputs_embeds
=
random_embeds
)
with
self
.
assertRaises
(
AssertionError
):
self
.
assertListEqual
(
outputs_from_rand_embeds
.
tolist
(),
outputs_from_embeds
.
tolist
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment