Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
13e03e61
Unverified
Commit
13e03e61
authored
Feb 14, 2023
by
Joao Gante
Committed by
GitHub
Feb 14, 2023
Browse files
Generate: filter encoder inputs when its signature does not accept wildcards (#21603)
parent
41fa672d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
94 additions
and
28 deletions
+94
-28
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+9
-3
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+7
-1
tests/generation/test_tf_utils.py
tests/generation/test_tf_utils.py
+38
-0
tests/generation/test_utils.py
tests/generation/test_utils.py
+40
-24
No files found.
src/transformers/generation/tf_utils.py
View file @
13e03e61
...
@@ -1078,18 +1078,24 @@ class TFGenerationMixin:
...
@@ -1078,18 +1078,24 @@ class TFGenerationMixin:
def
_prepare_encoder_decoder_kwargs_for_generation
(
def
_prepare_encoder_decoder_kwargs_for_generation
(
self
,
inputs_tensor
:
tf
.
Tensor
,
model_kwargs
,
model_input_name
:
Optional
[
str
]
=
None
self
,
inputs_tensor
:
tf
.
Tensor
,
model_kwargs
,
model_input_name
:
Optional
[
str
]
=
None
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
# get encoder and store encoder outputs
#
1.
get encoder and store encoder outputs
encoder
=
self
.
get_encoder
()
encoder
=
self
.
get_encoder
()
# prepare encoder args and encoder kwargs from model kwargs
#
2.
prepare encoder args and encoder kwargs from model kwargs
irrelevant_prefix
=
[
"decoder_"
,
"cross_attn"
,
"use_cache"
]
irrelevant_prefix
=
[
"decoder_"
,
"cross_attn"
,
"use_cache"
]
encoder_kwargs
=
{
encoder_kwargs
=
{
argument
:
value
argument
:
value
for
argument
,
value
in
model_kwargs
.
items
()
for
argument
,
value
in
model_kwargs
.
items
()
if
not
any
(
argument
.
startswith
(
p
)
for
p
in
irrelevant_prefix
)
if
not
any
(
argument
.
startswith
(
p
)
for
p
in
irrelevant_prefix
)
}
}
encoder_signature
=
set
(
inspect
.
signature
(
encoder
.
call
).
parameters
)
encoder_accepts_wildcard
=
"kwargs"
in
encoder_signature
or
"model_kwargs"
in
encoder_signature
if
not
encoder_accepts_wildcard
:
encoder_kwargs
=
{
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
}
# vision models don't use `attention_mask`.
#
3.
vision models don't use `attention_mask`.
encoder_kwargs
[
"return_dict"
]
=
True
encoder_kwargs
[
"return_dict"
]
=
True
encoder_kwargs
[
model_input_name
]
=
inputs_tensor
encoder_kwargs
[
model_input_name
]
=
inputs_tensor
if
model_input_name
!=
self
.
main_input_name
:
# in Keras, the first input must always be passed
if
model_input_name
!=
self
.
main_input_name
:
# in Keras, the first input must always be passed
...
...
src/transformers/generation/utils.py
View file @
13e03e61
...
@@ -609,13 +609,19 @@ class GenerationMixin:
...
@@ -609,13 +609,19 @@ class GenerationMixin:
# 1. get encoder
# 1. get encoder
encoder
=
self
.
get_encoder
()
encoder
=
self
.
get_encoder
()
# 2.
p
repare encoder args and encoder kwargs from model kwargs
# 2.
P
repare encoder args and encoder kwargs from model kwargs
.
irrelevant_prefix
=
[
"decoder_"
,
"cross_attn"
,
"use_cache"
]
irrelevant_prefix
=
[
"decoder_"
,
"cross_attn"
,
"use_cache"
]
encoder_kwargs
=
{
encoder_kwargs
=
{
argument
:
value
argument
:
value
for
argument
,
value
in
model_kwargs
.
items
()
for
argument
,
value
in
model_kwargs
.
items
()
if
not
any
(
argument
.
startswith
(
p
)
for
p
in
irrelevant_prefix
)
if
not
any
(
argument
.
startswith
(
p
)
for
p
in
irrelevant_prefix
)
}
}
encoder_signature
=
set
(
inspect
.
signature
(
encoder
.
forward
).
parameters
)
encoder_accepts_wildcard
=
"kwargs"
in
encoder_signature
or
"model_kwargs"
in
encoder_signature
if
not
encoder_accepts_wildcard
:
encoder_kwargs
=
{
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
}
# 3. make sure that encoder returns `ModelOutput`
# 3. make sure that encoder returns `ModelOutput`
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
main_input_name
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
main_input_name
...
...
tests/generation/test_tf_utils.py
View file @
13e03e61
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
import
tempfile
import
tempfile
import
unittest
import
unittest
import
numpy
as
np
from
transformers
import
is_tf_available
from
transformers
import
is_tf_available
from
transformers.testing_utils
import
require_tf
,
slow
from
transformers.testing_utils
import
require_tf
,
slow
...
@@ -32,6 +34,7 @@ if is_tf_available():
...
@@ -32,6 +34,7 @@ if is_tf_available():
TFAutoModelForSeq2SeqLM
,
TFAutoModelForSeq2SeqLM
,
TFAutoModelForSpeechSeq2Seq
,
TFAutoModelForSpeechSeq2Seq
,
TFAutoModelForVision2Seq
,
TFAutoModelForVision2Seq
,
TFBartForConditionalGeneration
,
TFLogitsProcessorList
,
TFLogitsProcessorList
,
TFMinLengthLogitsProcessor
,
TFMinLengthLogitsProcessor
,
tf_top_k_top_p_filtering
,
tf_top_k_top_p_filtering
,
...
@@ -264,3 +267,38 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
...
@@ -264,3 +267,38 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
tf
.
random
.
set_seed
(
0
)
tf
.
random
.
set_seed
(
0
)
generated_tokens
=
model
.
generate
(
**
tokens
,
eos_token_id
=
eos_token_id
,
**
generation_kwargs
)
generated_tokens
=
model
.
generate
(
**
tokens
,
eos_token_id
=
eos_token_id
,
**
generation_kwargs
)
self
.
assertTrue
(
expectation
==
len
(
generated_tokens
[
0
]))
self
.
assertTrue
(
expectation
==
len
(
generated_tokens
[
0
]))
def
test_model_kwarg_encoder_signature_filtering
(
self
):
# Has PT equivalent: ample use of framework-specific code
bart_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
article
=
"""Hugging Face is a technology company based in New York and Paris."""
input_ids
=
bart_tokenizer
(
article
,
return_tensors
=
"tf"
).
input_ids
bart_model
=
TFBartForConditionalGeneration
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
output
=
bart_model
.
generate
(
input_ids
).
numpy
()
# Let's create a fake model that has a different signature. In particular, this fake model accepts "foo" as an
# argument. Because "foo" is not in the encoder signature and doesn't start with "decoder_", it will be part of
# the encoder kwargs prior to signature filtering, which would lead to an exception. But filtering kicks in and
# saves the day.
class
FakeBart
(
TFBartForConditionalGeneration
):
def
call
(
self
,
input_ids
,
foo
=
None
,
**
kwargs
):
return
super
().
call
(
input_ids
,
**
kwargs
)
bart_model
=
FakeBart
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
fake_output
=
bart_model
.
generate
(
input_ids
,
foo
=
"bar"
).
numpy
()
self
.
assertTrue
(
np
.
array_equal
(
output
,
fake_output
))
# Encoder signature filtering only kicks in if it doesn't accept wildcard kwargs. The following test will fail
# because it doesn't do signature filtering.
class
FakeEncoder
(
bart_model
.
model
.
encoder
.
__class__
):
def
call
(
self
,
input_ids
,
**
kwargs
):
return
super
().
call
(
input_ids
,
**
kwargs
)
fake_encoder
=
FakeEncoder
(
bart_model
.
config
,
bart_model
.
model
.
shared
)
bart_model
.
model
.
encoder
=
fake_encoder
# Normal generation still works (the output will be different because the encoder weights are different)
fake_output
=
bart_model
.
generate
(
input_ids
).
numpy
()
with
self
.
assertRaises
(
ValueError
):
# FakeEncoder.call() accepts **kwargs -> no filtering -> value error due to unexpected input "foo"
bart_model
.
generate
(
input_ids
,
foo
=
"bar"
)
tests/generation/test_utils.py
View file @
13e03e61
...
@@ -17,6 +17,8 @@
...
@@ -17,6 +17,8 @@
import
inspect
import
inspect
import
unittest
import
unittest
import
numpy
as
np
from
transformers
import
is_torch_available
,
pipeline
from
transformers
import
is_torch_available
,
pipeline
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
...
@@ -2439,30 +2441,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -2439,30 +2441,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_score_diff
=
(
output_sequences_batched
.
scores
[
0
][
1
]
-
output_sequences
.
scores
[
0
][
0
]).
abs
().
max
()
max_score_diff
=
(
output_sequences_batched
.
scores
[
0
][
1
]
-
output_sequences
.
scores
[
0
][
0
]).
abs
().
max
()
self
.
assertTrue
(
max_score_diff
<
1e-5
)
self
.
assertTrue
(
max_score_diff
<
1e-5
)
def
test_generate_from_input_embeds_decoder_only
(
self
):
# PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
# Note: the model must support generation from input embeddings
model
=
AutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
text
=
"Hello world"
input_ids
=
tokenizer
.
encode
(
text
,
return_tensors
=
"pt"
)
# Traditional way of generating text
outputs_from_ids
=
model
.
generate
(
input_ids
)
# Same thing, but from input embeddings
inputs_embeds
=
model
.
transformer
.
wte
(
input_ids
)
outputs_from_embeds
=
model
.
generate
(
input_ids
,
inputs_embeds
=
inputs_embeds
)
self
.
assertListEqual
(
outputs_from_ids
.
tolist
(),
outputs_from_embeds
.
tolist
())
# But if we pass different inputs_embeds, we should get different outputs
torch
.
manual_seed
(
0
)
random_embeds
=
torch
.
rand_like
(
inputs_embeds
)
outputs_from_rand_embeds
=
model
.
generate
(
input_ids
,
inputs_embeds
=
random_embeds
)
with
self
.
assertRaises
(
AssertionError
):
self
.
assertListEqual
(
outputs_from_rand_embeds
.
tolist
(),
outputs_from_embeds
.
tolist
())
def
test_eos_token_id_int_and_list_top_k_top_sampling
(
self
):
def
test_eos_token_id_int_and_list_top_k_top_sampling
(
self
):
# Has TF equivalent: this test relies on random sampling
# Has TF equivalent: this test relies on random sampling
generation_kwargs
=
{
generation_kwargs
=
{
...
@@ -2490,6 +2468,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -2490,6 +2468,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertTrue
(
expectation
==
len
(
generated_tokens
[
0
]))
self
.
assertTrue
(
expectation
==
len
(
generated_tokens
[
0
]))
def
test_generate_from_inputs_embeds_decoder_only
(
self
):
def
test_generate_from_inputs_embeds_decoder_only
(
self
):
# PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
# Note: the model must support generation from input embeddings
# Note: the model must support generation from input embeddings
model
=
AutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
).
to
(
torch_device
)
model
=
AutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
).
to
(
torch_device
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
...
@@ -2523,3 +2502,40 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -2523,3 +2502,40 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
outputs_from_embeds
[:,
inputs_embeds
.
shape
[
1
]
:].
tolist
(),
outputs_from_embeds
[:,
inputs_embeds
.
shape
[
1
]
:].
tolist
(),
outputs_from_embeds_wo_ids
[:,
1
:].
tolist
(),
outputs_from_embeds_wo_ids
[:,
1
:].
tolist
(),
)
)
def
test_model_kwarg_encoder_signature_filtering
(
self
):
# Has TF equivalent: ample use of framework-specific code
bart_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
)
article
=
"""Hugging Face is a technology company based in New York and Paris."""
input_ids
=
bart_tokenizer
(
article
,
return_tensors
=
"pt"
).
input_ids
.
to
(
torch_device
)
bart_model
=
BartForConditionalGeneration
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
).
to
(
torch_device
)
output
=
bart_model
.
generate
(
input_ids
).
cpu
().
numpy
()
# Let's create a fake model that has a different signature. In particular, this fake model accepts "foo" as an
# argument. Because "foo" is not in the encoder signature and doesn't start with "decoder_", it will be part of
# the encoder kwargs prior to signature filtering, which would lead to an exception. But filtering kicks in and
# saves the day.
class
FakeBart
(
BartForConditionalGeneration
):
def
forward
(
self
,
input_ids
,
foo
=
None
,
**
kwargs
):
return
super
().
forward
(
input_ids
,
**
kwargs
)
bart_model
=
FakeBart
.
from_pretrained
(
"hf-internal-testing/tiny-random-bart"
).
to
(
torch_device
)
fake_output
=
bart_model
.
generate
(
input_ids
,
foo
=
"bar"
).
cpu
().
numpy
()
self
.
assertTrue
(
np
.
array_equal
(
output
,
fake_output
))
# Encoder signature filtering only kicks in if it doesn't accept wildcard kwargs. The following test will fail
# because it doesn't do signature filtering.
class
FakeEncoder
(
bart_model
.
model
.
encoder
.
__class__
):
def
forward
(
self
,
input_ids
,
**
kwargs
):
return
super
().
forward
(
input_ids
,
**
kwargs
)
fake_encoder
=
FakeEncoder
(
bart_model
.
config
,
bart_model
.
model
.
shared
).
to
(
torch_device
)
bart_model
.
model
.
encoder
=
fake_encoder
# Normal generation still works (the output will be different because the encoder weights are different)
fake_output
=
bart_model
.
generate
(
input_ids
).
cpu
().
numpy
()
with
self
.
assertRaises
(
TypeError
):
# FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo"
bart_model
.
generate
(
input_ids
,
foo
=
"bar"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment