Unverified Commit 90b4adc1 authored by Joao Gante, committed by GitHub

Generate: skip tests on unsupported models instead of passing (#27265)

parent 26d8d5f2
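The change replaces bare `return` statements in the `GenerationTesterMixin` tests with `self.skipTest(...)`: an early `return` makes unittest count the test as passed even though nothing was verified, while `skipTest` raises `unittest.SkipTest` and the test is reported as skipped with a reason. A minimal sketch of the difference (hypothetical test class and flag, not part of this commit):

    import unittest

    class CacheSupportExample(unittest.TestCase):
        # Hypothetical stand-in for the `hasattr(config, "use_cache")` check in the real tests
        supports_cache = False

        def test_old_pattern_early_return(self):
            # Old pattern: the body is cut short, but the runner still counts the test as PASSED
            if not self.supports_cache:
                return
            self.assertTrue(self.supports_cache)

        def test_new_pattern_skip(self):
            # New pattern: raises unittest.SkipTest, so the runner reports SKIPPED with the reason
            if not self.supports_cache:
                self.skipTest("This model doesn't support caching")
            self.assertTrue(self.supports_cache)

    if __name__ == "__main__":
        unittest.main(verbosity=2)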
@@ -749,8 +749,7 @@ class GenerationTesterMixin:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
 
             if not hasattr(config, "use_cache"):
-                # only relevant if model has "use_cache"
-                return
+                self.skipTest("This model doesn't support caching")
 
             config.use_cache = True
             config.is_decoder = True
@@ -983,8 +982,7 @@ class GenerationTesterMixin:
             config.forced_eos_token_id = None
 
             if not hasattr(config, "use_cache"):
-                # only relevant if model has "use_cache"
-                return
+                self.skipTest("This model doesn't support caching")
 
             model = model_class(config).to(torch_device).eval()
             if model.config.is_encoder_decoder:
@@ -1420,13 +1418,13 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
+                self.skipTest("Won't fix: old model with different cache format")
 
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
 
             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1441,14 +1439,14 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
+                self.skipTest("Won't fix: old model with different cache format")
 
             # enable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
 
             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1472,18 +1470,16 @@ class GenerationTesterMixin:
     def test_contrastive_generate_low_memory(self):
         # Check that choosing 'low_memory' does not change the model output
         for model_class in self.all_generative_model_classes:
-            # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
-            if any(
-                model_name in model_class.__name__.lower()
-                for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
-            ):
-                return
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
+                self.skipTest("Won't fix: old model with different cache format")
+            if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
+                self.skipTest("TODO: fix me")
 
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
 
             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1510,8 +1506,6 @@ class GenerationTesterMixin:
             )
             self.assertListEqual(low_output.tolist(), high_output.tolist())
-            return
-
 
     @slow  # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
     def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
@@ -1522,15 +1516,13 @@ class GenerationTesterMixin:
         # - assisted_decoding does not support `batch_size > 1`
 
         for model_class in self.all_generative_model_classes:
-            # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
-            # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
+                self.skipTest("Won't fix: old model with different cache format")
             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
             ):
-                return
+                self.skipTest("May fix in the future: need model-specific fixes")
 
             # This for loop is a naive and temporary effort to make the test less flaky.
             failed = 0
@@ -1540,7 +1532,7 @@ class GenerationTesterMixin:
                 # NOTE: assisted generation only works with cache on at the moment.
                 if not hasattr(config, "use_cache"):
-                    return
+                    self.skipTest("This model doesn't support caching")
                 config.use_cache = True
                 config.is_decoder = True
@@ -1587,24 +1579,21 @@ class GenerationTesterMixin:
     def test_assisted_decoding_sample(self):
         # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
         # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
         for model_class in self.all_generative_model_classes:
-            # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
-            # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
+                self.skipTest("Won't fix: old model with different cache format")
             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
             ):
-                return
+                self.skipTest("May fix in the future: need model-specific fixes")
 
             # enable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
 
             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1716,7 +1705,7 @@ class GenerationTesterMixin:
             # If it doesn't support cache, pass the test
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
 
             model = model_class(config).to(torch_device)
             if "use_cache" not in inputs:
@@ -1725,7 +1714,7 @@ class GenerationTesterMixin:
             # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
             if "past_key_values" not in outputs:
-                return
+                self.skipTest("This model doesn't return `past_key_values`")
 
             num_hidden_layers = (
                 getattr(config, "decoder_layers", None)
@@ -1832,18 +1821,15 @@ class GenerationTesterMixin:
     def test_generate_continue_from_past_key_values(self):
         # Tests that we can continue generating from past key values, returned from a previous `generate` call
         for model_class in self.all_generative_model_classes:
-            # won't fix: old models with unique inputs/caches/others
             if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
-                return
-            # may fix in the future: needs modeling or test input preparation fixes for compatibility
+                self.skipTest("Won't fix: old model with unique inputs/caches/other")
             if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
-                return
+                self.skipTest("TODO: needs modeling or test input preparation fixes for compatibility")
 
             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
 
-            # If it doesn't support cache, pass the test
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
 
             # Let's make it always:
             # 1. use cache (for obvious reasons)
@@ -1862,10 +1848,10 @@ class GenerationTesterMixin:
             model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
             model.generation_config.forced_eos_token_id = None
 
-            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
+            # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
             outputs = model(**inputs)
             if "past_key_values" not in outputs:
-                return
+                self.skipTest("This model doesn't return `past_key_values`")
 
             # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
             outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
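With this change, unsupported models show up in the test report as skips instead of silent passes. One way to see the effect (paths and filters assumed from the usual transformers test layout, adjust to the model suite of interest):

    python -m pytest tests/models/gpt2 -k "contrastive_generate or assisted_decoding" -rs

pytest's `-rs` flag prints each skipped test together with its reason, e.g. "This model doesn't support caching", which the old bare `return` never surfaced.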