chenpangpang / transformers

Unverified commit 0f78529f, authored Nov 17, 2022 by Joao Gante, committed by GitHub on Nov 17, 2022
Generate: general TF XLA contrastive search tests are now slow tests (#20277)
* move contrastive search test to slow
Parent: 2062c285

Showing 1 changed file with 8 additions and 29 deletions.

tests/test_modeling_tf_common.py (+8, -29)
@@ -1800,7 +1800,7 @@ class TFModelTesterMixin:
             model.compile(optimizer="sgd", run_eagerly=True)
             model.train_on_batch(test_batch, test_batch_labels)
 
-    def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs):
+    def _test_xla_generate(self, **generate_kwargs):
         def _generate_and_check_results(model, config, inputs_dict):
             if "input_ids" in inputs_dict:
                 inputs = inputs_dict["input_ids"]
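Note: the body of _generate_and_check_results is collapsed in this diff; from the signature change, the forwarded generate_kwargs evidently end up on the generate() call itself. A minimal sketch of that eager-vs-XLA check follows; the helper name check_xla_matches_eager is hypothetical, not from the diff:

    import tensorflow as tf

    def check_xla_matches_eager(model, inputs, **generate_kwargs):
        # Hypothetical helper mirroring what the collapsed test body is
        # believed to do: generate eagerly, generate again under XLA with
        # the same forwarded kwargs, and require identical token ids.
        generated = model.generate(inputs, **generate_kwargs).numpy()
        xla_generate = tf.function(model.generate, jit_compile=True)
        generated_xla = xla_generate(inputs, **generate_kwargs).numpy()
        assert generated.tolist() == generated_xla.tolist()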
@@ -1826,20 +1826,7 @@ class TFModelTesterMixin:
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             config.eos_token_id = None  # Generate until max length
-            config.max_length = max_length
             config.do_sample = False
-            config.num_beams = num_beams
-            config.num_return_sequences = num_return_sequences
-
-            # fix config for models with additional sequence-length limiting settings
-            for var_name in ["max_position_embeddings", "max_target_positions"]:
-                if hasattr(config, var_name):
-                    try:
-                        setattr(config, var_name, max_length)
-                    except NotImplementedError:
-                        # xlnet will raise an exception when trying to set
-                        # max_position_embeddings.
-                        pass
             model = model_class(config)
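Note: the block deleted above copied max_length into max_position_embeddings / max_target_positions so that models with short position limits could generate that far (with a special case for xlnet). Since the callers below now pass max_new_tokens instead of an absolute max_length, the generated sequences presumably stay within the tiny test models' limits, so the config surgery is no longer needed and every option travels through **generate_kwargs straight to generate().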
@@ -1856,23 +1843,18 @@ class TFModelTesterMixin:
         Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
         """
-        num_beams = 1
-        num_return_sequences = 1
-        max_length = 10
-        self._test_xla_generate(num_beams, num_return_sequences, max_length)
+        self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=3)
 
+    @slow
     def test_xla_generate_contrastive(self):
         """
-        Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the
-        model cache and other outputs, and this test ensures that they are in a valid format that is also supported by
-        XLA.
+        Slow and challenging version of `test_xla_generate_fast` for contrastive search -- contrastive search directly
+        manipulates the model cache and other outputs, and this test ensures that they are in a valid format that is
+        also supported by XLA.
 
         Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
         """
-        num_beams = 1
-        num_return_sequences = 1
-        max_length = 10
-        self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5)
+        self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=64, penalty_alpha=0.5, top_k=4)
 
     @slow
     def test_xla_generate_slow(self):
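Note: in generate(), passing penalty_alpha together with top_k selects contrastive search, which is why those two kwargs are all the test needs to add. A minimal usage sketch of the XLA path this test now exercises under @slow; the gpt2 checkpoint and prompt are just examples, not taken from this diff:

    import tensorflow as tf
    from transformers import AutoTokenizer, TFAutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = TFAutoModelForCausalLM.from_pretrained("gpt2")

    # Compile generate with XLA; contrastive search is selected by
    # penalty_alpha > 0 combined with top_k > 1.
    xla_generate = tf.function(model.generate, jit_compile=True)

    inputs = tokenizer(["TensorFlow is"], return_tensors="tf")
    outputs = xla_generate(**inputs, penalty_alpha=0.5, top_k=4, max_new_tokens=64)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

For repeated calls you would normally pad inputs to a fixed shape so that XLA does not retrace on every new input length.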
@@ -1883,10 +1865,7 @@ class TFModelTesterMixin:
         Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
         """
-        num_beams = 8
-        num_return_sequences = 2
-        max_length = 128
-        self._test_xla_generate(num_beams, num_return_sequences, max_length)
+        self._test_xla_generate(num_beams=8, num_return_sequences=2, max_new_tokens=128)
 
     def _generate_random_bad_tokens(self, num_bad_tokens, model):
         # special tokens cannot be bad tokens
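Note: @slow here is transformers.testing_utils.slow, which skips a test unless the RUN_SLOW=1 environment variable is set. Gating the contrastive-search test behind it, and shrinking the fast test to max_new_tokens=3, is what takes the heavy XLA checks out of the default CI run while keeping them runnable on demand.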