Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9196f48b
Unverified
Commit
9196f48b
authored
Sep 02, 2022
by
Joao Gante
Committed by
GitHub
Sep 02, 2022
Browse files
Generate: validate `model_kwargs` on TF (and catch typos in generate arguments) (#18651)
parent
c5be7cae
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
214 additions
and
139 deletions
+214
-139
src/transformers/generation_tf_utils.py
src/transformers/generation_tf_utils.py
+29
-1
tests/generation/test_generation_tf_utils.py
tests/generation/test_generation_tf_utils.py
+183
-0
tests/generation/test_generation_utils.py
tests/generation/test_generation_utils.py
+2
-2
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+0
-136
No files found.
src/transformers/generation_tf_utils.py
View file @
9196f48b
...
@@ -579,6 +579,7 @@ class TFGenerationMixin:
...
@@ -579,6 +579,7 @@ class TFGenerationMixin:
do_sample
=
do_sample
if
do_sample
is
not
None
else
self
.
config
.
do_sample
do_sample
=
do_sample
if
do_sample
is
not
None
else
self
.
config
.
do_sample
if
do_sample
is
False
or
num_beams
==
1
:
if
do_sample
is
False
or
num_beams
==
1
:
seed
=
model_kwargs
.
pop
(
"seed"
,
None
)
return
self
.
_generate
(
return
self
.
_generate
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
max_length
=
max_length
,
max_length
=
max_length
,
...
@@ -601,13 +602,14 @@ class TFGenerationMixin:
...
@@ -601,13 +602,14 @@ class TFGenerationMixin:
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
decoder_start_token_id
=
decoder_start_token_id
,
decoder_start_token_id
=
decoder_start_token_id
,
use_cache
=
use_cache
,
use_cache
=
use_cache
,
seed
=
model_kwargs
.
pop
(
"seed"
,
None
)
,
seed
=
seed
,
output_scores
=
output_scores
,
output_scores
=
output_scores
,
output_attentions
=
output_attentions
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
output_hidden_states
=
output_hidden_states
,
return_dict_in_generate
=
return_dict_in_generate
,
return_dict_in_generate
=
return_dict_in_generate
,
forced_bos_token_id
=
forced_bos_token_id
,
forced_bos_token_id
=
forced_bos_token_id
,
forced_eos_token_id
=
forced_eos_token_id
,
forced_eos_token_id
=
forced_eos_token_id
,
**
model_kwargs
,
)
)
# We cannot generate if the model does not have a LM head
# We cannot generate if the model does not have a LM head
...
@@ -1288,6 +1290,29 @@ class TFGenerationMixin:
...
@@ -1288,6 +1290,29 @@ class TFGenerationMixin:
else
:
else
:
return
logits
return
logits
def
_validate_model_kwargs
(
self
,
model_kwargs
:
Dict
[
str
,
Any
]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
# Excludes arguments that are handled before calling any model function
if
self
.
config
.
is_encoder_decoder
:
for
key
in
[
"decoder_input_ids"
]:
model_kwargs
.
pop
(
key
,
None
)
unused_model_args
=
[]
model_args
=
set
(
inspect
.
signature
(
self
.
prepare_inputs_for_generation
).
parameters
)
# `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if
"kwargs"
in
model_args
:
model_args
|=
set
(
inspect
.
signature
(
self
.
call
).
parameters
)
for
key
,
value
in
model_kwargs
.
items
():
if
value
is
not
None
and
key
not
in
model_args
:
unused_model_args
.
append
(
key
)
if
unused_model_args
:
raise
ValueError
(
f
"The following `model_kwargs` are not used by the model:
{
unused_model_args
}
(note: typos in the"
" generate arguments will also show up in this list)"
)
def
_generate
(
def
_generate
(
self
,
self
,
input_ids
=
None
,
input_ids
=
None
,
...
@@ -1483,6 +1508,9 @@ class TFGenerationMixin:
...
@@ -1483,6 +1508,9 @@ class TFGenerationMixin:
# generate sequences without allowing bad_words to be generated
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
```"""
```"""
# 0. Validate model kwargs
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# 1. Set generation parameters if not already defined
# 1. Set generation parameters if not already defined
length_penalty
=
length_penalty
if
length_penalty
is
not
None
else
self
.
config
.
length_penalty
length_penalty
=
length_penalty
if
length_penalty
is
not
None
else
self
.
config
.
length_penalty
early_stopping
=
early_stopping
if
early_stopping
is
not
None
else
self
.
config
.
early_stopping
early_stopping
=
early_stopping
if
early_stopping
is
not
None
else
self
.
config
.
early_stopping
...
...
tests/generation/test_generation_tf_utils.py
0 → 100644
View file @
9196f48b
# coding=utf-8
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
tempfile
import
unittest
from
transformers
import
is_tf_available
from
transformers.testing_utils
import
require_tf
,
slow
if
is_tf_available
():
import
tensorflow
as
tf
from
transformers
import
AutoTokenizer
,
TFAutoModelForCausalLM
,
TFAutoModelForSeq2SeqLM
,
tf_top_k_top_p_filtering
@
require_tf
class
UtilsFunctionsTest
(
unittest
.
TestCase
):
# tests whether the top_k_top_p_filtering function behaves as expected
def
test_top_k_top_p_filtering
(
self
):
logits
=
tf
.
convert_to_tensor
(
[
[
8.2220991
,
# 3rd highest value; idx. 0
-
0.5620044
,
5.23229752
,
4.0386393
,
-
6.8798378
,
-
0.54785802
,
-
3.2012153
,
2.92777176
,
1.88171953
,
7.35341276
,
# 5th highest value; idx. 9
8.43207833
,
# 2nd highest value; idx. 10
-
9.85711836
,
-
5.96209236
,
-
1.13039161
,
-
7.1115294
,
-
0.8369633
,
-
5.3186408
,
7.06427407
,
0.81369344
,
-
0.82023817
,
-
5.9179796
,
0.58813443
,
-
6.99778438
,
4.71551189
,
-
0.18771637
,
7.44020759
,
# 4th highest value; idx. 25
9.38450987
,
# 1st highest value; idx. 26
2.12662941
,
-
9.32562038
,
2.35652522
,
],
# cumulative prob of 5 highest values <= 0.6
[
0.58425518
,
4.53139238
,
-
5.57510464
,
-
6.28030699
,
-
7.19529503
,
-
4.02122551
,
1.39337037
,
-
6.06707057
,
1.59480517
,
-
9.643119
,
0.03907799
,
0.67231762
,
-
8.88206726
,
6.27115922
,
# 4th highest value; idx. 13
2.28520723
,
4.82767506
,
4.30421368
,
8.8275313
,
# 2nd highest value; idx. 17
5.44029958
,
# 5th highest value; idx. 18
-
4.4735794
,
7.38579536
,
# 3rd highest value; idx. 20
-
2.91051663
,
2.61946077
,
-
2.5674762
,
-
9.48959302
,
-
4.02922645
,
-
1.35416918
,
9.67702323
,
# 1st highest value; idx. 27
-
5.89478553
,
1.85370467
,
],
# cumulative prob of 5 highest values <= 0.6
],
dtype
=
tf
.
float32
,
)
non_inf_expected_idx
=
tf
.
convert_to_tensor
(
[[
0
,
0
],
[
0
,
9
],
[
0
,
10
],
[
0
,
25
],
[
0
,
26
],
[
1
,
13
],
[
1
,
17
],
[
1
,
18
],
[
1
,
20
],
[
1
,
27
]],
dtype
=
tf
.
int32
,
)
# expected non filtered idx as noted above
non_inf_expected_output
=
tf
.
convert_to_tensor
(
[
8.222099
,
7.3534126
,
8.432078
,
7.4402075
,
9.38451
,
6.271159
,
8.827531
,
5.4402995
,
7.3857956
,
9.677023
],
dtype
=
tf
.
float32
,
)
# expected non filtered values as noted above
output
=
tf_top_k_top_p_filtering
(
logits
,
top_k
=
10
,
top_p
=
0.6
,
min_tokens_to_keep
=
4
)
non_inf_output
=
output
[
output
!=
-
float
(
"inf"
)]
non_inf_idx
=
tf
.
cast
(
tf
.
where
(
tf
.
not_equal
(
output
,
tf
.
constant
(
-
float
(
"inf"
),
dtype
=
tf
.
float32
))),
dtype
=
tf
.
int32
,
)
tf
.
debugging
.
assert_near
(
non_inf_output
,
non_inf_expected_output
,
rtol
=
1e-12
)
tf
.
debugging
.
assert_equal
(
non_inf_idx
,
non_inf_expected_idx
)
@
require_tf
class
TFGenerationIntegrationTests
(
unittest
.
TestCase
):
@
slow
def
test_generate_tf_function_export
(
self
):
test_model
=
TFAutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
max_length
=
2
class
DummyModel
(
tf
.
Module
):
def
__init__
(
self
,
model
):
super
(
DummyModel
,
self
).
__init__
()
self
.
model
=
model
@
tf
.
function
(
input_signature
=
(
tf
.
TensorSpec
((
None
,
max_length
),
tf
.
int32
,
name
=
"input_ids"
),
tf
.
TensorSpec
((
None
,
max_length
),
tf
.
int32
,
name
=
"attention_mask"
),
),
jit_compile
=
True
,
)
def
serving
(
self
,
input_ids
,
attention_mask
):
outputs
=
self
.
model
.
generate
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
max_new_tokens
=
max_length
,
return_dict_in_generate
=
True
,
)
return
{
"sequences"
:
outputs
[
"sequences"
]}
dummy_input_ids
=
[[
2
,
0
],
[
102
,
103
]]
dummy_attention_masks
=
[[
1
,
0
],
[
1
,
1
]]
dummy_model
=
DummyModel
(
model
=
test_model
)
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
tf
.
saved_model
.
save
(
dummy_model
,
tmp_dir
,
signatures
=
{
"serving_default"
:
dummy_model
.
serving
})
serving_func
=
tf
.
saved_model
.
load
(
tmp_dir
).
signatures
[
"serving_default"
]
for
batch_size
in
range
(
1
,
len
(
dummy_input_ids
)
+
1
):
inputs
=
{
"input_ids"
:
tf
.
constant
(
dummy_input_ids
[:
batch_size
]),
"attention_mask"
:
tf
.
constant
(
dummy_attention_masks
[:
batch_size
]),
}
tf_func_outputs
=
serving_func
(
**
inputs
)[
"sequences"
]
tf_model_outputs
=
test_model
.
generate
(
**
inputs
,
max_new_tokens
=
max_length
)
tf
.
debugging
.
assert_equal
(
tf_func_outputs
,
tf_model_outputs
)
def
test_validate_generation_inputs
(
self
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-t5"
)
model
=
TFAutoModelForSeq2SeqLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-t5"
)
encoder_input_str
=
"Hello world"
input_ids
=
tokenizer
(
encoder_input_str
,
return_tensors
=
"tf"
).
input_ids
# typos are quickly detected (the correct argument is `do_sample`)
with
self
.
assertRaisesRegex
(
ValueError
,
"do_samples"
):
model
.
generate
(
input_ids
,
do_samples
=
True
)
# arbitrary arguments that will not be used anywhere are also not accepted
with
self
.
assertRaisesRegex
(
ValueError
,
"foo"
):
fake_model_kwargs
=
{
"foo"
:
"bar"
}
model
.
generate
(
input_ids
,
**
fake_model_kwargs
)
tests/generation/test_generation_utils.py
View file @
9196f48b
...
@@ -2704,8 +2704,8 @@ class GenerationIntegrationTests(unittest.TestCase):
...
@@ -2704,8 +2704,8 @@ class GenerationIntegrationTests(unittest.TestCase):
model
.
generate
(
input_ids
,
force_words_ids
=
[[[
-
1
]]])
model
.
generate
(
input_ids
,
force_words_ids
=
[[[
-
1
]]])
def
test_validate_generation_inputs
(
self
):
def
test_validate_generation_inputs
(
self
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"
patrickvonplaten/t5-
tiny-random"
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"
hf-internal-testing/
tiny-random
-t5
"
)
model
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
"
patrickvonplaten/t5-
tiny-random"
)
model
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
"
hf-internal-testing/
tiny-random
-t5
"
)
encoder_input_str
=
"Hello world"
encoder_input_str
=
"Hello world"
input_ids
=
tokenizer
(
encoder_input_str
,
return_tensors
=
"pt"
).
input_ids
input_ids
=
tokenizer
(
encoder_input_str
,
return_tensors
=
"pt"
).
input_ids
...
...
tests/test_modeling_tf_common.py
View file @
9196f48b
...
@@ -75,11 +75,9 @@ if is_tf_available():
...
@@ -75,11 +75,9 @@ if is_tf_available():
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
,
BertConfig
,
BertConfig
,
TFAutoModel
,
TFAutoModel
,
TFAutoModelForCausalLM
,
TFAutoModelForSequenceClassification
,
TFAutoModelForSequenceClassification
,
TFBertModel
,
TFBertModel
,
TFSharedEmbeddings
,
TFSharedEmbeddings
,
tf_top_k_top_p_filtering
,
)
)
from
transformers.generation_tf_utils
import
(
from
transformers.generation_tf_utils
import
(
TFBeamSampleDecoderOnlyOutput
,
TFBeamSampleDecoderOnlyOutput
,
...
@@ -1824,100 +1822,6 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None, dtype=None):
...
@@ -1824,100 +1822,6 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None, dtype=None):
@
require_tf
@
require_tf
class
UtilsFunctionsTest
(
unittest
.
TestCase
):
class
UtilsFunctionsTest
(
unittest
.
TestCase
):
# tests whether the top_k_top_p_filtering function behaves as expected
def
test_top_k_top_p_filtering
(
self
):
logits
=
tf
.
convert_to_tensor
(
[
[
8.2220991
,
# 3rd highest value; idx. 0
-
0.5620044
,
5.23229752
,
4.0386393
,
-
6.8798378
,
-
0.54785802
,
-
3.2012153
,
2.92777176
,
1.88171953
,
7.35341276
,
# 5th highest value; idx. 9
8.43207833
,
# 2nd highest value; idx. 10
-
9.85711836
,
-
5.96209236
,
-
1.13039161
,
-
7.1115294
,
-
0.8369633
,
-
5.3186408
,
7.06427407
,
0.81369344
,
-
0.82023817
,
-
5.9179796
,
0.58813443
,
-
6.99778438
,
4.71551189
,
-
0.18771637
,
7.44020759
,
# 4th highest value; idx. 25
9.38450987
,
# 1st highest value; idx. 26
2.12662941
,
-
9.32562038
,
2.35652522
,
],
# cumulative prob of 5 highest values <= 0.6
[
0.58425518
,
4.53139238
,
-
5.57510464
,
-
6.28030699
,
-
7.19529503
,
-
4.02122551
,
1.39337037
,
-
6.06707057
,
1.59480517
,
-
9.643119
,
0.03907799
,
0.67231762
,
-
8.88206726
,
6.27115922
,
# 4th highest value; idx. 13
2.28520723
,
4.82767506
,
4.30421368
,
8.8275313
,
# 2nd highest value; idx. 17
5.44029958
,
# 5th highest value; idx. 18
-
4.4735794
,
7.38579536
,
# 3rd highest value; idx. 20
-
2.91051663
,
2.61946077
,
-
2.5674762
,
-
9.48959302
,
-
4.02922645
,
-
1.35416918
,
9.67702323
,
# 1st highest value; idx. 27
-
5.89478553
,
1.85370467
,
],
# cumulative prob of 5 highest values <= 0.6
],
dtype
=
tf
.
float32
,
)
non_inf_expected_idx
=
tf
.
convert_to_tensor
(
[[
0
,
0
],
[
0
,
9
],
[
0
,
10
],
[
0
,
25
],
[
0
,
26
],
[
1
,
13
],
[
1
,
17
],
[
1
,
18
],
[
1
,
20
],
[
1
,
27
]],
dtype
=
tf
.
int32
,
)
# expected non filtered idx as noted above
non_inf_expected_output
=
tf
.
convert_to_tensor
(
[
8.222099
,
7.3534126
,
8.432078
,
7.4402075
,
9.38451
,
6.271159
,
8.827531
,
5.4402995
,
7.3857956
,
9.677023
],
dtype
=
tf
.
float32
,
)
# expected non filtered values as noted above
output
=
tf_top_k_top_p_filtering
(
logits
,
top_k
=
10
,
top_p
=
0.6
,
min_tokens_to_keep
=
4
)
non_inf_output
=
output
[
output
!=
-
float
(
"inf"
)]
non_inf_idx
=
tf
.
cast
(
tf
.
where
(
tf
.
not_equal
(
output
,
tf
.
constant
(
-
float
(
"inf"
),
dtype
=
tf
.
float32
))),
dtype
=
tf
.
int32
,
)
tf
.
debugging
.
assert_near
(
non_inf_output
,
non_inf_expected_output
,
rtol
=
1e-12
)
tf
.
debugging
.
assert_equal
(
non_inf_idx
,
non_inf_expected_idx
)
def
test_cached_files_are_used_when_internet_is_down
(
self
):
def
test_cached_files_are_used_when_internet_is_down
(
self
):
# A mock response for an HTTP head request to emulate server down
# A mock response for an HTTP head request to emulate server down
response_mock
=
mock
.
Mock
()
response_mock
=
mock
.
Mock
()
...
@@ -2179,46 +2083,6 @@ class UtilsFunctionsTest(unittest.TestCase):
...
@@ -2179,46 +2083,6 @@ class UtilsFunctionsTest(unittest.TestCase):
for
p1
,
p2
in
zip
(
model
.
weights
,
new_model
.
weights
):
for
p1
,
p2
in
zip
(
model
.
weights
,
new_model
.
weights
):
self
.
assertTrue
(
np
.
allclose
(
p1
.
numpy
(),
p2
.
numpy
()))
self
.
assertTrue
(
np
.
allclose
(
p1
.
numpy
(),
p2
.
numpy
()))
def
test_generate_tf_function_export
(
self
):
test_model
=
TFAutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
max_length
=
2
class
DummyModel
(
tf
.
Module
):
def
__init__
(
self
,
model
):
super
(
DummyModel
,
self
).
__init__
()
self
.
model
=
model
@
tf
.
function
(
input_signature
=
(
tf
.
TensorSpec
((
None
,
max_length
),
tf
.
int32
,
name
=
"input_ids"
),
tf
.
TensorSpec
((
None
,
max_length
),
tf
.
int32
,
name
=
"attention_mask"
),
),
jit_compile
=
True
,
)
def
serving
(
self
,
input_ids
,
attention_mask
):
outputs
=
self
.
model
.
generate
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
max_new_tokens
=
max_length
,
return_dict_in_generate
=
True
,
)
return
{
"sequences"
:
outputs
[
"sequences"
]}
dummy_input_ids
=
[[
2
,
0
],
[
102
,
103
]]
dummy_attention_masks
=
[[
1
,
0
],
[
1
,
1
]]
dummy_model
=
DummyModel
(
model
=
test_model
)
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
tf
.
saved_model
.
save
(
dummy_model
,
tmp_dir
,
signatures
=
{
"serving_default"
:
dummy_model
.
serving
})
serving_func
=
tf
.
saved_model
.
load
(
tmp_dir
).
signatures
[
"serving_default"
]
for
batch_size
in
range
(
1
,
len
(
dummy_input_ids
)
+
1
):
inputs
=
{
"input_ids"
:
tf
.
constant
(
dummy_input_ids
[:
batch_size
]),
"attention_mask"
:
tf
.
constant
(
dummy_attention_masks
[:
batch_size
]),
}
tf_func_outputs
=
serving_func
(
**
inputs
)[
"sequences"
]
tf_model_outputs
=
test_model
.
generate
(
**
inputs
,
max_new_tokens
=
max_length
)
tf
.
debugging
.
assert_equal
(
tf_func_outputs
,
tf_model_outputs
)
@
require_tf
@
require_tf
@
is_staging_test
@
is_staging_test
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment