chenpangpang / transformers · Unverified commit 2edf9a85

Generate: TF `.generate()` can now be exported with dynamic length (#21474)

Authored Feb 09, 2023 by Joao Gante; committed via GitHub on Feb 09, 2023. Parent commit: e69f9715
Showing 3 changed files with 68 additions and 20 deletions:

- src/transformers/generation/tf_utils.py (+18, -13)
- src/transformers/models/gpt2/modeling_tf_gpt2.py (+1, -1)
- tests/generation/test_tf_utils.py (+49, -6)
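In brief, the commit makes TF `.generate()` traceable inside a `tf.function` whose input signature leaves one axis dynamic. A minimal standalone sketch of the export side, condensed from the tests in this commit — the checkpoint name and `max_new_tokens=2` come from the diff, while `ExportableGenerator` and the export path are illustrative names:

```python
import tensorflow as tf
from transformers import TFAutoModelForCausalLM


class ExportableGenerator(tf.Module):
    """Wrapper mirroring the DummyModel used in the tests below."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    # (1, None): fixed batch size, *dynamic* sequence length. Before this
    # commit, tracing generate() with an unknown input length failed on
    # Python-level shape arithmetic and length checks.
    @tf.function(
        input_signature=(
            tf.TensorSpec((1, None), tf.int32, name="input_ids"),
            tf.TensorSpec((1, None), tf.int32, name="attention_mask"),
        ),
        jit_compile=True,
    )
    def serving(self, input_ids, attention_mask):
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=2,
            return_dict_in_generate=True,
        )
        return {"sequences": outputs["sequences"]}


model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
wrapper = ExportableGenerator(model)
export_dir = "/tmp/tf_generate_export"  # illustrative path
tf.saved_model.save(wrapper, export_dir, signatures={"serving_default": wrapper.serving})
```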
src/transformers/generation/tf_utils.py
```diff
@@ -849,7 +849,7 @@ class TFGenerationMixin:
             input_ids = inputs_tensor

         # 7. Prepare `max_length` depending on other stopping criteria.
-        input_ids_seq_length = input_ids.shape[-1]
+        input_ids_seq_length = shape_list(input_ids)[-1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
@@ -869,10 +869,15 @@ class TFGenerationMixin:
                 UserWarning,
             )

-        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
-            raise ValueError(
-                f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than"
-                f" the maximum length ({generation_config.max_length})"
-            )
+        # If the input length is a tensor (i.e. dynamic length), skip length checks
+        if not isinstance(input_ids_seq_length, tf.Tensor):
+            if (
+                generation_config.min_length is not None
+                and generation_config.min_length > generation_config.max_length
+            ):
+                raise ValueError(
+                    f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger"
+                    f" than the maximum length ({generation_config.max_length})"
+                )
-        if input_ids_seq_length >= generation_config.max_length:
-            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            if input_ids_seq_length >= generation_config.max_length:
+                input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
```
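Why the one-line swap in the hunk above matters: inside a `tf.function` traced with a dynamic sequence length, the static shape of the last axis is `None`, so `input_ids.shape[-1]` cannot drive the later length checks. `shape_list` (a real utility in `transformers`) returns the static dimension where known and the runtime `tf.shape` tensor where not, which is exactly what the new `isinstance(..., tf.Tensor)` guard keys on. A sketch with a simplified stand-in for `shape_list`:

```python
import tensorflow as tf


def shape_list_sketch(x):
    # Simplified stand-in for transformers' shape_list: static ints where
    # the dimension is known, runtime tensors where it is not.
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]


@tf.function(input_signature=(tf.TensorSpec((1, None), tf.int32),))
def trace_demo(input_ids):
    assert input_ids.shape[-1] is None  # static length unknown at trace time
    seq_length = shape_list_sketch(input_ids)[-1]  # a scalar tf.Tensor
    # The patched generate() skips Python-level min/max length validation
    # precisely when the sequence length is a Tensor like this one.
    tf.print("runtime length:", seq_length)
    return input_ids


trace_demo(tf.constant([[1, 2, 3]], dtype=tf.int32))  # prints: runtime length: 3
```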
src/transformers/models/gpt2/modeling_tf_gpt2.py
```diff
@@ -182,7 +182,7 @@ class TFAttention(tf.keras.layers.Layer):
         key = self.split_heads(key)
         value = self.split_heads(value)

         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=0)
+            past_key, past_value = tf.unstack(layer_past, axis=0, num=2)
             key = tf.concat([past_key, key], axis=-2)
             value = tf.concat([past_value, value], axis=-2)
```
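The `num=2` addition above exists for the same reason: `tf.unstack` must know at trace time how many tensors to produce, and it infers that from the static size of the unstacked axis. When tracing with partially dynamic shapes that size can be unknown and the call fails with "Cannot infer argument `num` from shape"; an explicit `num` fixes the graph to exactly two outputs (past key and past value). A hedged sketch, using an illustrative shape rather than GPT-2's real past shape:

```python
import tensorflow as tf


@tf.function(input_signature=(tf.TensorSpec((None, 4), tf.float32),))
def split_past(layer_past):
    # Without num=2 this line raises at trace time, because axis 0 has no
    # static size: "Cannot infer argument `num` from shape (None, 4)".
    past_key, past_value = tf.unstack(layer_past, axis=0, num=2)
    return past_key, past_value


key, value = split_past(tf.zeros((2, 4)))  # each has shape (4,)
```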
tests/generation/test_tf_utils.py
```diff
@@ -144,9 +144,10 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
         }

     @slow
-    def test_generate_tf_function_export(self):
+    def test_generate_tf_function_export_fixed_input_length(self):
         test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        max_length = 2
+        input_length = 2
+        max_new_tokens = 2

         class DummyModel(tf.Module):
             def __init__(self, model):
@@ -155,8 +156,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
             @tf.function(
                 input_signature=(
-                    tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
-                    tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
+                    tf.TensorSpec((None, input_length), tf.int32, name="input_ids"),
+                    tf.TensorSpec((None, input_length), tf.int32, name="attention_mask"),
                 ),
                 jit_compile=True,
             )
@@ -164,7 +165,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
                 outputs = self.model.generate(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
-                    max_new_tokens=max_length,
+                    max_new_tokens=max_new_tokens,
                     return_dict_in_generate=True,
                 )
                 return {"sequences": outputs["sequences"]}
@@ -181,5 +182,47 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
                 "attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
             }
             tf_func_outputs = serving_func(**inputs)["sequences"]
-            tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
+            tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
             tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
+
+    @slow
+    def test_generate_tf_function_export_fixed_batch_size(self):
+        test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        batch_size = 1
+        max_new_tokens = 2
+
+        class DummyModel(tf.Module):
+            def __init__(self, model):
+                super(DummyModel, self).__init__()
+                self.model = model
+
+            @tf.function(
+                input_signature=(
+                    tf.TensorSpec((batch_size, None), tf.int32, name="input_ids"),
+                    tf.TensorSpec((batch_size, None), tf.int32, name="attention_mask"),
+                ),
+                jit_compile=True,
+            )
+            def serving(self, input_ids, attention_mask):
+                outputs = self.model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=max_new_tokens,
+                    return_dict_in_generate=True,
+                )
+                return {"sequences": outputs["sequences"]}
+
+        dummy_input_ids = [[2], [102, 103]]
+        dummy_attention_masks = [[1], [1, 1]]
+        dummy_model = DummyModel(model=test_model)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
+            serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
+
+            for input_row in range(len(dummy_input_ids)):
+                inputs = {
+                    "input_ids": tf.constant([dummy_input_ids[input_row]]),
+                    "attention_mask": tf.constant([dummy_attention_masks[input_row]]),
+                }
+                tf_func_outputs = serving_func(**inputs)["sequences"]
+                tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
+                tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
```
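And the serving side, as a hedged sketch: assuming a SavedModel was exported as in the sketch near the top of this page (the `/tmp/tf_generate_export` path is illustrative), loaded signatures are called with keyword arguments and return a dict keyed by output name, so prompts of different lengths go through the same exported graph:

```python
import tensorflow as tf

export_dir = "/tmp/tf_generate_export"  # illustrative; produced by the export sketch above
serving_func = tf.saved_model.load(export_dir).signatures["serving_default"]

for prompt in ([2], [102, 103]):  # same dummy token ids as the tests
    out = serving_func(
        input_ids=tf.constant([prompt], dtype=tf.int32),
        attention_mask=tf.constant([[1] * len(prompt)], dtype=tf.int32),
    )
    print(out["sequences"].numpy())
```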