Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2edf9a85
Unverified
Commit
2edf9a85
authored
Feb 09, 2023
by
Joao Gante
Committed by
GitHub
Feb 09, 2023
Browse files
Generate: TF `.generate()` can now be exported with dynamic length (#21474)
parent
e69f9715
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
68 additions
and
20 deletions
+68
-20
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+18
-13
src/transformers/models/gpt2/modeling_tf_gpt2.py
src/transformers/models/gpt2/modeling_tf_gpt2.py
+1
-1
tests/generation/test_tf_utils.py
tests/generation/test_tf_utils.py
+49
-6
No files found.
src/transformers/generation/tf_utils.py
View file @
2edf9a85
...
@@ -849,7 +849,7 @@ class TFGenerationMixin:
...
@@ -849,7 +849,7 @@ class TFGenerationMixin:
input_ids
=
inputs_tensor
input_ids
=
inputs_tensor
# 7. Prepare `max_length` depending on other stopping criteria.
# 7. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length
=
input_ids
.
shape
[
-
1
]
input_ids_seq_length
=
shape_list
(
input_ids
)
[
-
1
]
has_default_max_length
=
kwargs
.
get
(
"max_length"
)
is
None
and
generation_config
.
max_length
is
not
None
has_default_max_length
=
kwargs
.
get
(
"max_length"
)
is
None
and
generation_config
.
max_length
is
not
None
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
:
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
:
warnings
.
warn
(
warnings
.
warn
(
...
@@ -869,18 +869,23 @@ class TFGenerationMixin:
...
@@ -869,18 +869,23 @@ class TFGenerationMixin:
UserWarning
,
UserWarning
,
)
)
if
generation_config
.
min_length
is
not
None
and
generation_config
.
min_length
>
generation_config
.
max_length
:
# If the input length is a tensor (i.e. dynamic length), skip length checks
raise
ValueError
(
if
not
isinstance
(
input_ids_seq_length
,
tf
.
Tensor
):
f
"Unfeasable length constraints: the minimum length (
{
generation_config
.
min_length
}
) is larger than"
if
(
f
" the maximum length (
{
generation_config
.
max_length
}
)"
generation_config
.
min_length
is
not
None
)
and
generation_config
.
min_length
>
generation_config
.
max_length
if
input_ids_seq_length
>=
generation_config
.
max_length
:
):
input_ids_string
=
"decoder_input_ids"
if
self
.
config
.
is_encoder_decoder
else
"input_ids"
raise
ValueError
(
logger
.
warning
(
f
"Unfeasable length constraints: the minimum length (
{
generation_config
.
min_length
}
) is larger"
f
"Input length of
{
input_ids_string
}
is
{
input_ids_seq_length
}
, but `max_length` is set to"
f
" than the maximum length (
{
generation_config
.
max_length
}
)"
f
"
{
generation_config
.
max_length
}
. This can lead to unexpected behavior. You should consider"
)
" increasing`max_new_tokens`."
if
input_ids_seq_length
>=
generation_config
.
max_length
:
)
input_ids_string
=
"decoder_input_ids"
if
self
.
config
.
is_encoder_decoder
else
"input_ids"
logger
.
warning
(
f
"Input length of
{
input_ids_string
}
is
{
input_ids_seq_length
}
, but `max_length` is set to"
f
"
{
generation_config
.
max_length
}
. This can lead to unexpected behavior. You should consider"
" increasing`max_new_tokens`."
)
# 8. determine generation mode
# 8. determine generation mode
is_contrastive_search_gen_mode
=
(
is_contrastive_search_gen_mode
=
(
...
...
src/transformers/models/gpt2/modeling_tf_gpt2.py
View file @
2edf9a85
...
@@ -182,7 +182,7 @@ class TFAttention(tf.keras.layers.Layer):
...
@@ -182,7 +182,7 @@ class TFAttention(tf.keras.layers.Layer):
key
=
self
.
split_heads
(
key
)
key
=
self
.
split_heads
(
key
)
value
=
self
.
split_heads
(
value
)
value
=
self
.
split_heads
(
value
)
if
layer_past
is
not
None
:
if
layer_past
is
not
None
:
past_key
,
past_value
=
tf
.
unstack
(
layer_past
,
axis
=
0
)
past_key
,
past_value
=
tf
.
unstack
(
layer_past
,
axis
=
0
,
num
=
2
)
key
=
tf
.
concat
([
past_key
,
key
],
axis
=-
2
)
key
=
tf
.
concat
([
past_key
,
key
],
axis
=-
2
)
value
=
tf
.
concat
([
past_value
,
value
],
axis
=-
2
)
value
=
tf
.
concat
([
past_value
,
value
],
axis
=-
2
)
...
...
tests/generation/test_tf_utils.py
View file @
2edf9a85
...
@@ -144,9 +144,10 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
...
@@ -144,9 +144,10 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
}
}
@
slow
@
slow
def
test_generate_tf_function_export
(
self
):
def
test_generate_tf_function_export
_fixed_input_length
(
self
):
test_model
=
TFAutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
test_model
=
TFAutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
max_length
=
2
input_length
=
2
max_new_tokens
=
2
class
DummyModel
(
tf
.
Module
):
class
DummyModel
(
tf
.
Module
):
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
):
...
@@ -155,8 +156,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
...
@@ -155,8 +156,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
@
tf
.
function
(
@
tf
.
function
(
input_signature
=
(
input_signature
=
(
tf
.
TensorSpec
((
None
,
max
_length
),
tf
.
int32
,
name
=
"input_ids"
),
tf
.
TensorSpec
((
None
,
input
_length
),
tf
.
int32
,
name
=
"input_ids"
),
tf
.
TensorSpec
((
None
,
max
_length
),
tf
.
int32
,
name
=
"attention_mask"
),
tf
.
TensorSpec
((
None
,
input
_length
),
tf
.
int32
,
name
=
"attention_mask"
),
),
),
jit_compile
=
True
,
jit_compile
=
True
,
)
)
...
@@ -164,7 +165,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
...
@@ -164,7 +165,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
outputs
=
self
.
model
.
generate
(
outputs
=
self
.
model
.
generate
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
max_new_tokens
=
max_
length
,
max_new_tokens
=
max_
new_tokens
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
)
)
return
{
"sequences"
:
outputs
[
"sequences"
]}
return
{
"sequences"
:
outputs
[
"sequences"
]}
...
@@ -181,5 +182,47 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
...
@@ -181,5 +182,47 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
"attention_mask"
:
tf
.
constant
(
dummy_attention_masks
[:
batch_size
]),
"attention_mask"
:
tf
.
constant
(
dummy_attention_masks
[:
batch_size
]),
}
}
tf_func_outputs
=
serving_func
(
**
inputs
)[
"sequences"
]
tf_func_outputs
=
serving_func
(
**
inputs
)[
"sequences"
]
tf_model_outputs
=
test_model
.
generate
(
**
inputs
,
max_new_tokens
=
max_length
)
tf_model_outputs
=
test_model
.
generate
(
**
inputs
,
max_new_tokens
=
max_new_tokens
)
tf
.
debugging
.
assert_equal
(
tf_func_outputs
,
tf_model_outputs
)
@
slow
def
test_generate_tf_function_export_fixed_batch_size
(
self
):
test_model
=
TFAutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
batch_size
=
1
max_new_tokens
=
2
class
DummyModel
(
tf
.
Module
):
def
__init__
(
self
,
model
):
super
(
DummyModel
,
self
).
__init__
()
self
.
model
=
model
@
tf
.
function
(
input_signature
=
(
tf
.
TensorSpec
((
batch_size
,
None
),
tf
.
int32
,
name
=
"input_ids"
),
tf
.
TensorSpec
((
batch_size
,
None
),
tf
.
int32
,
name
=
"attention_mask"
),
),
jit_compile
=
True
,
)
def
serving
(
self
,
input_ids
,
attention_mask
):
outputs
=
self
.
model
.
generate
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
max_new_tokens
=
max_new_tokens
,
return_dict_in_generate
=
True
,
)
return
{
"sequences"
:
outputs
[
"sequences"
]}
dummy_input_ids
=
[[
2
],
[
102
,
103
]]
dummy_attention_masks
=
[[
1
],
[
1
,
1
]]
dummy_model
=
DummyModel
(
model
=
test_model
)
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
tf
.
saved_model
.
save
(
dummy_model
,
tmp_dir
,
signatures
=
{
"serving_default"
:
dummy_model
.
serving
})
serving_func
=
tf
.
saved_model
.
load
(
tmp_dir
).
signatures
[
"serving_default"
]
for
input_row
in
range
(
len
(
dummy_input_ids
)):
inputs
=
{
"input_ids"
:
tf
.
constant
([
dummy_input_ids
[
input_row
]]),
"attention_mask"
:
tf
.
constant
([
dummy_attention_masks
[
input_row
]]),
}
tf_func_outputs
=
serving_func
(
**
inputs
)[
"sequences"
]
tf_model_outputs
=
test_model
.
generate
(
**
inputs
,
max_new_tokens
=
max_new_tokens
)
tf
.
debugging
.
assert_equal
(
tf_func_outputs
,
tf_model_outputs
)
tf
.
debugging
.
assert_equal
(
tf_func_outputs
,
tf_model_outputs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment