Unverified commit 0f78529f, authored by Joao Gante and committed by GitHub
Browse files

Generate: general TF XLA contrastive search tests are now slow tests (#20277)

* move contrastive search test to slow
parent 2062c285
......@@ -1800,7 +1800,7 @@ class TFModelTesterMixin:
model.compile(optimizer="sgd", run_eagerly=True)
model.train_on_batch(test_batch, test_batch_labels)
def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs):
def _test_xla_generate(self, **generate_kwargs):
def _generate_and_check_results(model, config, inputs_dict):
if "input_ids" in inputs_dict:
inputs = inputs_dict["input_ids"]
......@@ -1826,20 +1826,7 @@ class TFModelTesterMixin:
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.eos_token_id = None # Generate until max length
config.max_length = max_length
config.do_sample = False
config.num_beams = num_beams
config.num_return_sequences = num_return_sequences
# fix config for models with additional sequence-length limiting settings
for var_name in ["max_position_embeddings", "max_target_positions"]:
if hasattr(config, var_name):
try:
setattr(config, var_name, max_length)
except NotImplementedError:
# xlnet will raise an exception when trying to set
# max_position_embeddings.
pass
model = model_class(config)
......@@ -1856,23 +1843,18 @@ class TFModelTesterMixin:
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
"""
num_beams = 1
num_return_sequences = 1
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length)
self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=3)
@slow
def test_xla_generate_contrastive(self):
"""
Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the
model cache and other outputs, and this test ensures that they are in a valid format that is also supported
by XLA.
Slow and challenging version of `test_xla_generate_fast` for contrastive search -- contrastive search directly
manipulates the model cache and other outputs, and this test ensures that they are in a valid format that is
also supported by XLA.
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
"""
num_beams = 1
num_return_sequences = 1
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5)
self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=64, penalty_alpha=0.5, top_k=4)
@slow
def test_xla_generate_slow(self):
......@@ -1883,10 +1865,7 @@ class TFModelTesterMixin:
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
"""
num_beams = 8
num_return_sequences = 2
max_length = 128
self._test_xla_generate(num_beams, num_return_sequences, max_length)
self._test_xla_generate(num_beams=8, num_return_sequences=2, max_new_tokens=128)
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment