Unverified Commit 0f78529f authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: general TF XLA constrastive search are now slow tests (#20277)

* move contrastive search test to slow
parent 2062c285
...@@ -1800,7 +1800,7 @@ class TFModelTesterMixin: ...@@ -1800,7 +1800,7 @@ class TFModelTesterMixin:
model.compile(optimizer="sgd", run_eagerly=True) model.compile(optimizer="sgd", run_eagerly=True)
model.train_on_batch(test_batch, test_batch_labels) model.train_on_batch(test_batch, test_batch_labels)
def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs): def _test_xla_generate(self, **generate_kwargs):
def _generate_and_check_results(model, config, inputs_dict): def _generate_and_check_results(model, config, inputs_dict):
if "input_ids" in inputs_dict: if "input_ids" in inputs_dict:
inputs = inputs_dict["input_ids"] inputs = inputs_dict["input_ids"]
...@@ -1826,20 +1826,7 @@ class TFModelTesterMixin: ...@@ -1826,20 +1826,7 @@ class TFModelTesterMixin:
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.eos_token_id = None # Generate until max length config.eos_token_id = None # Generate until max length
config.max_length = max_length
config.do_sample = False config.do_sample = False
config.num_beams = num_beams
config.num_return_sequences = num_return_sequences
# fix config for models with additional sequence-length limiting settings
for var_name in ["max_position_embeddings", "max_target_positions"]:
if hasattr(config, var_name):
try:
setattr(config, var_name, max_length)
except NotImplementedError:
# xlnet will raise an exception when trying to set
# max_position_embeddings.
pass
model = model_class(config) model = model_class(config)
...@@ -1856,23 +1843,18 @@ class TFModelTesterMixin: ...@@ -1856,23 +1843,18 @@ class TFModelTesterMixin:
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
""" """
num_beams = 1 self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=3)
num_return_sequences = 1
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length)
@slow
def test_xla_generate_contrastive(self): def test_xla_generate_contrastive(self):
""" """
Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the Slow and challenging version of `test_xla_generate_fast` for contrastive search -- contrastive search directly
model cache and other outputs, and this test ensures that they are in a valid format that is also supported manipulates the model cache and other outputs, and this test ensures that they are in a valid format that is
by XLA. also supported by XLA.
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
""" """
num_beams = 1 self._test_xla_generate(num_beams=1, num_return_sequences=1, max_new_tokens=64, penalty_alpha=0.5, top_k=4)
num_return_sequences = 1
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5)
@slow @slow
def test_xla_generate_slow(self): def test_xla_generate_slow(self):
...@@ -1883,10 +1865,7 @@ class TFModelTesterMixin: ...@@ -1883,10 +1865,7 @@ class TFModelTesterMixin:
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception
""" """
num_beams = 8 self._test_xla_generate(num_beams=8, num_return_sequences=2, max_new_tokens=128)
num_return_sequences = 2
max_length = 128
self._test_xla_generate(num_beams, num_return_sequences, max_length)
def _generate_random_bad_tokens(self, num_bad_tokens, model): def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens # special tokens cannot be bad tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment