chenpangpang / transformers · Commit 13e03e61 (unverified)
Authored Feb 14, 2023 by Joao Gante, committed by GitHub on Feb 14, 2023

Generate: filter encoder inputs when its signature does not accept wildcards (#21603)

parent 41fa672d
Showing 4 changed files with 94 additions and 28 deletions (+94, -28):

    src/transformers/generation/tf_utils.py    +9   -3
    src/transformers/generation/utils.py       +7   -1
    tests/generation/test_tf_utils.py          +38  -0
    tests/generation/test_utils.py             +40  -24
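The change adds the same guard to both the TensorFlow and PyTorch `generate()` code paths: before the encoder is called, the remaining model kwargs are filtered against the encoder's signature, unless that signature accepts wildcard kwargs. Below is a minimal, standalone sketch of the idea; the toy `encoder_forward` function and the example kwargs are illustrative only, not code from the repository:

```python
import inspect


def encoder_forward(input_ids, attention_mask=None, return_dict=True):
    """Toy encoder whose signature does NOT accept wildcard **kwargs."""
    return {"last_hidden_state": input_ids, "attention_mask": attention_mask}


model_kwargs = {
    "attention_mask": [1, 1, 1],
    "decoder_input_ids": [2],  # decoder-specific: removed by the prefix filter
    "foo": "bar",              # unexpected: removed by the new signature filter
}

# 1. drop arguments that are clearly not meant for the encoder
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
    argument: value
    for argument, value in model_kwargs.items()
    if not any(argument.startswith(p) for p in irrelevant_prefix)
}

# 2. if the encoder does not accept wildcard kwargs, keep only parameters it names
encoder_signature = set(inspect.signature(encoder_forward).parameters)
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
if not encoder_accepts_wildcard:
    encoder_kwargs = {k: v for k, v in encoder_kwargs.items() if k in encoder_signature}

print(encoder_kwargs)  # {'attention_mask': [1, 1, 1]}
encoder_forward([101, 102], **encoder_kwargs)  # no TypeError from the unexpected "foo"
```

When the encoder does declare `**kwargs` (or `model_kwargs`), the second step is skipped and all surviving kwargs are forwarded unchanged, which corresponds to the `encoder_accepts_wildcard` branch in the diffs below.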
src/transformers/generation/tf_utils.py

@@ -1078,18 +1078,24 @@ class TFGenerationMixin:
     def _prepare_encoder_decoder_kwargs_for_generation(
         self, inputs_tensor: tf.Tensor, model_kwargs, model_input_name: Optional[str] = None
     ) -> Dict[str, Any]:
-        # get encoder and store encoder outputs
+        # 1. get encoder and store encoder outputs
         encoder = self.get_encoder()
 
-        # prepare encoder args and encoder kwargs from model kwargs
+        # 2. prepare encoder args and encoder kwargs from model kwargs
         irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
         encoder_kwargs = {
             argument: value
             for argument, value in model_kwargs.items()
             if not any(argument.startswith(p) for p in irrelevant_prefix)
         }
+        encoder_signature = set(inspect.signature(encoder.call).parameters)
+        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+        if not encoder_accepts_wildcard:
+            encoder_kwargs = {
+                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+            }
 
-        # vision models don't use `attention_mask`.
+        # 3. vision models don't use `attention_mask`.
         encoder_kwargs["return_dict"] = True
         encoder_kwargs[model_input_name] = inputs_tensor
         if model_input_name != self.main_input_name:  # in Keras, the first input must always be passed
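To see which parameters a given TF encoder actually exposes (and therefore what survives the filter), the same inspection can be reproduced by hand. A small sketch, assuming TensorFlow is installed and using the tiny BART checkpoint that the tests below rely on:

```python
import inspect

from transformers import TFBartForConditionalGeneration

model = TFBartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")
encoder = model.get_encoder()

# the TF code path inspects `encoder.call`, since Keras models implement `call` rather than `forward`
encoder_signature = set(inspect.signature(encoder.call).parameters)
print(sorted(encoder_signature))
print("accepts wildcard:", "kwargs" in encoder_signature or "model_kwargs" in encoder_signature)
```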
src/transformers/generation/utils.py

@@ -609,13 +609,19 @@ class GenerationMixin:
         # 1. get encoder
         encoder = self.get_encoder()
 
-        # 2. prepare encoder args and encoder kwargs from model kwargs
+        # 2. Prepare encoder args and encoder kwargs from model kwargs.
         irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
         encoder_kwargs = {
             argument: value
             for argument, value in model_kwargs.items()
             if not any(argument.startswith(p) for p in irrelevant_prefix)
         }
+        encoder_signature = set(inspect.signature(encoder.forward).parameters)
+        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+        if not encoder_accepts_wildcard:
+            encoder_kwargs = {
+                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+            }
 
         # 3. make sure that encoder returns `ModelOutput`
         model_input_name = model_input_name if model_input_name is not None else self.main_input_name
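The PyTorch path is the same except that it inspects `encoder.forward` instead of `encoder.call`. A matching sketch, again assuming the tiny test checkpoint as an example model:

```python
import inspect

from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")
encoder = model.get_encoder()

# the PyTorch code path inspects `encoder.forward`
encoder_signature = set(inspect.signature(encoder.forward).parameters)
print(sorted(encoder_signature))
print("accepts wildcard:", "kwargs" in encoder_signature or "model_kwargs" in encoder_signature)
```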
tests/generation/test_tf_utils.py

@@ -16,6 +16,8 @@
 import tempfile
 import unittest
 
+import numpy as np
+
 from transformers import is_tf_available
 from transformers.testing_utils import require_tf, slow

@@ -32,6 +34,7 @@ if is_tf_available():
         TFAutoModelForSeq2SeqLM,
         TFAutoModelForSpeechSeq2Seq,
         TFAutoModelForVision2Seq,
+        TFBartForConditionalGeneration,
         TFLogitsProcessorList,
         TFMinLengthLogitsProcessor,
         tf_top_k_top_p_filtering,

@@ -264,3 +267,38 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
             tf.random.set_seed(0)
             generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
             self.assertTrue(expectation == len(generated_tokens[0]))
+
+    def test_model_kwarg_encoder_signature_filtering(self):
+        # Has PT equivalent: ample use of framework-specific code
+        bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        article = """Hugging Face is a technology company based in New York and Paris."""
+        input_ids = bart_tokenizer(article, return_tensors="tf").input_ids
+        bart_model = TFBartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")
+        output = bart_model.generate(input_ids).numpy()
+
+        # Let's create a fake model that has a different signature. In particular, this fake model accepts "foo" as an
+        # argument. Because "foo" is not in the encoder signature and doesn't start with "decoder_", it will be part of
+        # the encoder kwargs prior to signature filtering, which would lead to an exception. But filtering kicks in and
+        # saves the day.
+        class FakeBart(TFBartForConditionalGeneration):
+            def call(self, input_ids, foo=None, **kwargs):
+                return super().call(input_ids, **kwargs)
+
+        bart_model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-bart")
+        fake_output = bart_model.generate(input_ids, foo="bar").numpy()
+        self.assertTrue(np.array_equal(output, fake_output))
+
+        # Encoder signature filtering only kicks in if it doesn't accept wildcard kwargs. The following test will fail
+        # because it doesn't do signature filtering.
+        class FakeEncoder(bart_model.model.encoder.__class__):
+            def call(self, input_ids, **kwargs):
+                return super().call(input_ids, **kwargs)
+
+        fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared)
+        bart_model.model.encoder = fake_encoder
+
+        # Normal generation still works (the output will be different because the encoder weights are different)
+        fake_output = bart_model.generate(input_ids).numpy()
+        with self.assertRaises(ValueError):
+            # FakeEncoder.call() accepts **kwargs -> no filtering -> value error due to unexpected input "foo"
+            bart_model.generate(input_ids, foo="bar")
tests/generation/test_utils.py

@@ -17,6 +17,8 @@
 import inspect
 import unittest
 
+import numpy as np
+
 from transformers import is_torch_available, pipeline
 from transformers.testing_utils import require_torch, slow, torch_device

@@ -2439,30 +2441,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
         self.assertTrue(max_score_diff < 1e-5)
 
-    def test_generate_from_input_embeds_decoder_only(self):
-        # PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
-        # Note: the model must support generation from input embeddings
-        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-
-        text = "Hello world"
-        input_ids = tokenizer.encode(text, return_tensors="pt")
-
-        # Traditional way of generating text
-        outputs_from_ids = model.generate(input_ids)
-
-        # Same thing, but from input embeddings
-        inputs_embeds = model.transformer.wte(input_ids)
-        outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
-        self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
-
-        # But if we pass different inputs_embeds, we should get different outputs
-        torch.manual_seed(0)
-        random_embeds = torch.rand_like(inputs_embeds)
-        outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
-        with self.assertRaises(AssertionError):
-            self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
-
     def test_eos_token_id_int_and_list_top_k_top_sampling(self):
         # Has TF equivalent: this test relies on random sampling
         generation_kwargs = {

@@ -2490,6 +2468,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertTrue(expectation == len(generated_tokens[0]))
 
     def test_generate_from_inputs_embeds_decoder_only(self):
+        # PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
         # Note: the model must support generation from input embeddings
         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")

@@ -2523,3 +2502,40 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
             outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
             outputs_from_embeds_wo_ids[:, 1:].tolist(),
         )
+
+    def test_model_kwarg_encoder_signature_filtering(self):
+        # Has TF equivalent: ample use of framework-specific code
+        bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        article = """Hugging Face is a technology company based in New York and Paris."""
+        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
+            torch_device
+        )
+        output = bart_model.generate(input_ids).cpu().numpy()
+
+        # Let's create a fake model that has a different signature. In particular, this fake model accepts "foo" as an
+        # argument. Because "foo" is not in the encoder signature and doesn't start with "decoder_", it will be part of
+        # the encoder kwargs prior to signature filtering, which would lead to an exception. But filtering kicks in and
+        # saves the day.
+        class FakeBart(BartForConditionalGeneration):
+            def forward(self, input_ids, foo=None, **kwargs):
+                return super().forward(input_ids, **kwargs)
+
+        bart_model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
+        fake_output = bart_model.generate(input_ids, foo="bar").cpu().numpy()
+        self.assertTrue(np.array_equal(output, fake_output))
+
+        # Encoder signature filtering only kicks in if it doesn't accept wildcard kwargs. The following test will fail
+        # because it doesn't do signature filtering.
+        class FakeEncoder(bart_model.model.encoder.__class__):
+            def forward(self, input_ids, **kwargs):
+                return super().forward(input_ids, **kwargs)
+
+        fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared).to(torch_device)
+        bart_model.model.encoder = fake_encoder
+
+        # Normal generation still works (the output will be different because the encoder weights are different)
+        fake_output = bart_model.generate(input_ids).cpu().numpy()
+        with self.assertRaises(TypeError):
+            # FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo"
+            bart_model.generate(input_ids, foo="bar")