chenpangpang/transformers · Commit 9196f48b (unverified), authored Sep 02, 2022 by Joao Gante, committed via GitHub on Sep 02, 2022
Generate: validate `model_kwargs` on TF (and catch typos in generate arguments) (#18651)
Parent: c5be7cae
Showing 4 changed files with 214 additions and 139 deletions:
- src/transformers/generation_tf_utils.py (+29, -1)
- tests/generation/test_generation_tf_utils.py (+183, -0)
- tests/generation/test_generation_utils.py (+2, -2)
- tests/test_modeling_tf_common.py (+0, -136)
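Before the per-file diffs, a quick illustration of the behavior this commit adds, mirroring the new `test_validate_generation_inputs` test below (the tiny T5 checkpoint is the one that test uses):

```python
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
input_ids = tokenizer("Hello world", return_tensors="tf").input_ids

# The correct argument is `do_sample`. Before this commit the typo was silently
# ignored; with `_validate_model_kwargs` in place, generate() raises:
# ValueError: The following `model_kwargs` are not used by the model: ['do_samples'] ...
model.generate(input_ids, do_samples=True)
```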
src/transformers/generation_tf_utils.py
```diff
@@ -579,6 +579,7 @@ class TFGenerationMixin:
         do_sample = do_sample if do_sample is not None else self.config.do_sample

         if do_sample is False or num_beams == 1:
+            seed = model_kwargs.pop("seed", None)
             return self._generate(
                 input_ids=input_ids,
                 max_length=max_length,
@@ -601,13 +602,14 @@ class TFGenerationMixin:
                 attention_mask=attention_mask,
                 decoder_start_token_id=decoder_start_token_id,
                 use_cache=use_cache,
-                seed=model_kwargs.pop("seed", None),
+                seed=seed,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 forced_bos_token_id=forced_bos_token_id,
                 forced_eos_token_id=forced_eos_token_id,
+                **model_kwargs,
             )

         # We cannot generate if the model does not have a LM head
```
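The two hunks above make `generate` forward `**model_kwargs` into `_generate` and hoist the `seed` pop out of the call. `seed` is consumed by `generate` itself rather than by the model's forward pass, so it has to leave `model_kwargs` before the remaining keys are checked against the model's signature. A minimal sketch of the pop (values are illustrative):

```python
# `seed` is a generate-level argument, not a model input, so it is extracted
# before the leftover kwargs are validated against the model's signature.
model_kwargs = {"attention_mask": [[1, 1]], "seed": [42, 0]}
seed = model_kwargs.pop("seed", None)  # None when the caller passed no seed
assert "seed" not in model_kwargs      # only genuine model inputs remain
```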
```diff
@@ -1288,6 +1290,29 @@ class TFGenerationMixin:
         else:
             return logits

+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+        # Excludes arguments that are handled before calling any model function
+        if self.config.is_encoder_decoder:
+            for key in ["decoder_input_ids"]:
+                model_kwargs.pop(key, None)
+
+        unused_model_args = []
+        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
+        if "kwargs" in model_args:
+            model_args |= set(inspect.signature(self.call).parameters)
+        for key, value in model_kwargs.items():
+            if value is not None and key not in model_args:
+                unused_model_args.append(key)
+
+        if unused_model_args:
+            raise ValueError(
+                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+                " generate arguments will also show up in this list)"
+            )
+
     def _generate(
         self,
         input_ids=None,
```
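The heart of `_validate_model_kwargs` is plain signature introspection via `inspect.signature`. A self-contained sketch of the same idea outside the mixin (all names here are illustrative, not from the commit):

```python
import inspect


def check_kwargs(fn, kwargs):
    # Collect the parameter names `fn` accepts, then flag everything else.
    accepted = set(inspect.signature(fn).parameters)
    unused = [k for k, v in kwargs.items() if v is not None and k not in accepted]
    if unused:
        raise ValueError(f"The following kwargs are not used: {unused}")


def forward(input_ids, attention_mask=None):
    ...


check_kwargs(forward, {"attention_mask": [[1, 1]]})      # fine: known parameter
try:
    check_kwargs(forward, {"attnetion_mask": [[1, 1]]})  # deliberate typo
except ValueError as err:
    print(err)  # The following kwargs are not used: ['attnetion_mask']
```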
```diff
@@ -1483,6 +1508,9 @@ class TFGenerationMixin:
         # generate sequences without allowing bad_words to be generated
         outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
         ```"""
+        # 0. Validate model kwargs
+        self._validate_model_kwargs(model_kwargs.copy())
+
         # 1. Set generation parameters if not already defined
         length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
         early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
```
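Note that `_generate` validates a copy. `_validate_model_kwargs` pops the keys it deliberately excludes (e.g. `decoder_input_ids`), and those keys must still reach the model afterwards. A toy illustration (the `validate` helper is hypothetical):

```python
def validate(kwargs):
    kwargs.pop("decoder_input_ids", None)  # mutates the dict it receives
    # ... signature checks as in `_validate_model_kwargs` would go here ...


model_kwargs = {"decoder_input_ids": [[0]], "attention_mask": [[1]]}
validate(model_kwargs.copy())               # mutation is confined to the copy
assert "decoder_input_ids" in model_kwargs  # original still reaches the model
```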
tests/generation/test_generation_tf_utils.py (new file, mode 100644)
```python
# coding=utf-8
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow


if is_tf_available():
    import tensorflow as tf

    from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
```
```python
@require_tf
class UtilsFunctionsTest(unittest.TestCase):
    # tests whether the top_k_top_p_filtering function behaves as expected
    def test_top_k_top_p_filtering(self):
        logits = tf.convert_to_tensor(
            [
                [
                    8.2220991,  # 3rd highest value; idx. 0
                    -0.5620044,
                    5.23229752,
                    4.0386393,
                    -6.8798378,
                    -0.54785802,
                    -3.2012153,
                    2.92777176,
                    1.88171953,
                    7.35341276,  # 5th highest value; idx. 9
                    8.43207833,  # 2nd highest value; idx. 10
                    -9.85711836,
                    -5.96209236,
                    -1.13039161,
                    -7.1115294,
                    -0.8369633,
                    -5.3186408,
                    7.06427407,
                    0.81369344,
                    -0.82023817,
                    -5.9179796,
                    0.58813443,
                    -6.99778438,
                    4.71551189,
                    -0.18771637,
                    7.44020759,  # 4th highest value; idx. 25
                    9.38450987,  # 1st highest value; idx. 26
                    2.12662941,
                    -9.32562038,
                    2.35652522,
                ],  # cumulative prob of 5 highest values <= 0.6
                [
                    0.58425518,
                    4.53139238,
                    -5.57510464,
                    -6.28030699,
                    -7.19529503,
                    -4.02122551,
                    1.39337037,
                    -6.06707057,
                    1.59480517,
                    -9.643119,
                    0.03907799,
                    0.67231762,
                    -8.88206726,
                    6.27115922,  # 4th highest value; idx. 13
                    2.28520723,
                    4.82767506,
                    4.30421368,
                    8.8275313,  # 2nd highest value; idx. 17
                    5.44029958,  # 5th highest value; idx. 18
                    -4.4735794,
                    7.38579536,  # 3rd highest value; idx. 20
                    -2.91051663,
                    2.61946077,
                    -2.5674762,
                    -9.48959302,
                    -4.02922645,
                    -1.35416918,
                    9.67702323,  # 1st highest value; idx. 27
                    -5.89478553,
                    1.85370467,
                ],  # cumulative prob of 5 highest values <= 0.6
            ],
            dtype=tf.float32,
        )

        non_inf_expected_idx = tf.convert_to_tensor(
            [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
            dtype=tf.int32,
        )  # expected non filtered idx as noted above

        non_inf_expected_output = tf.convert_to_tensor(
            [8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023],
            dtype=tf.float32,
        )  # expected non filtered values as noted above

        output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)

        non_inf_output = output[output != -float("inf")]
        non_inf_idx = tf.cast(
            tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))),
            dtype=tf.int32,
        )

        tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
        tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
```
```python
@require_tf
class TFGenerationIntegrationTests(unittest.TestCase):
    @slow
    def test_generate_tf_function_export(self):
        test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        max_length = 2

        class DummyModel(tf.Module):
            def __init__(self, model):
                super(DummyModel, self).__init__()
                self.model = model

            @tf.function(
                input_signature=(
                    tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
                    tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
                ),
                jit_compile=True,
            )
            def serving(self, input_ids, attention_mask):
                outputs = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_length,
                    return_dict_in_generate=True,
                )
                return {"sequences": outputs["sequences"]}

        dummy_input_ids = [[2, 0], [102, 103]]
        dummy_attention_masks = [[1, 0], [1, 1]]
        dummy_model = DummyModel(model=test_model)
        with tempfile.TemporaryDirectory() as tmp_dir:
            tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
            serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]

            for batch_size in range(1, len(dummy_input_ids) + 1):
                inputs = {
                    "input_ids": tf.constant(dummy_input_ids[:batch_size]),
                    "attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
                }
                tf_func_outputs = serving_func(**inputs)["sequences"]
                tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
                tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)

    def test_validate_generation_inputs(self):
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
        model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")

        encoder_input_str = "Hello world"
        input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids

        # typos are quickly detected (the correct argument is `do_sample`)
        with self.assertRaisesRegex(ValueError, "do_samples"):
            model.generate(input_ids, do_samples=True)

        # arbitrary arguments that will not be used anywhere are also not accepted
        with self.assertRaisesRegex(ValueError, "foo"):
            fake_model_kwargs = {"foo": "bar"}
            model.generate(input_ids, **fake_model_kwargs)
```
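`test_generate_tf_function_export` above checks that `generate` can be traced with `tf.function` and exported as a SavedModel signature. A sketch of how such an exported artifact could be consumed outside the test (the path is hypothetical):

```python
import tensorflow as tf

loaded = tf.saved_model.load("/path/to/exported_model")  # hypothetical export dir
serving_fn = loaded.signatures["serving_default"]

outputs = serving_fn(
    input_ids=tf.constant([[2, 0]], dtype=tf.int32),
    attention_mask=tf.constant([[1, 1]], dtype=tf.int32),
)
print(outputs["sequences"])  # generated token ids
```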
tests/generation/test_generation_utils.py
```diff
@@ -2704,8 +2704,8 @@ class GenerationIntegrationTests(unittest.TestCase):
             model.generate(input_ids, force_words_ids=[[[-1]]])

     def test_validate_generation_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
-        model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+        model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")

         encoder_input_str = "Hello world"
         input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
```
tests/test_modeling_tf_common.py
```diff
@@ -75,11 +75,9 @@ if is_tf_available():
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
         TFAutoModel,
-        TFAutoModelForCausalLM,
         TFAutoModelForSequenceClassification,
         TFBertModel,
         TFSharedEmbeddings,
-        tf_top_k_top_p_filtering,
     )
     from transformers.generation_tf_utils import (
         TFBeamSampleDecoderOnlyOutput,
```
```diff
@@ -1824,100 +1822,6 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None, dtype=None):
 @require_tf
 class UtilsFunctionsTest(unittest.TestCase):
-    # tests whether the top_k_top_p_filtering function behaves as expected
-    def test_top_k_top_p_filtering(self):
-        ...

     def test_cached_files_are_used_when_internet_is_down(self):
         # A mock response for an HTTP head request to emulate server down
         response_mock = mock.Mock()
```

(The removed body of `test_top_k_top_p_filtering`, elided above, is identical to the version added in tests/generation/test_generation_tf_utils.py and is not repeated here.)
```diff
@@ -2179,46 +2083,6 @@ class UtilsFunctionsTest(unittest.TestCase):
         for p1, p2 in zip(model.weights, new_model.weights):
             self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

-    def test_generate_tf_function_export(self):
-        ...

 @require_tf
 @is_staging_test
```

(The removed body of `test_generate_tf_function_export`, elided above, matches the version added in tests/generation/test_generation_tf_utils.py, where it is additionally marked `@slow`.)