Unverified Commit 12b50c61 authored by Joao Gante, committed by GitHub

Generate: improve assisted generation tests (#27540)

parent 651408a0
@@ -23,6 +23,7 @@ import numpy as np
 from transformers import is_torch_available, pipeline
 from transformers.testing_utils import (
+    is_flaky,
     require_accelerate,
     require_torch,
     require_torch_multi_accelerator,
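The newly imported `is_flaky` decorator retries a test a few times and only reports a failure if every attempt fails, which is what lets the greedy-match test below tolerate the rare logit mismatch described in NOTE (1). As a rough illustration of the retry pattern only (not the actual `transformers.testing_utils` implementation; the name `flaky` and the `max_attempts` default are made up here):

import functools

def flaky(max_attempts: int = 5):
    """Toy retry decorator: re-run a test up to `max_attempts` times and only
    fail if every attempt fails. Illustrative only, not the real
    transformers.testing_utils.is_flaky."""
    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            last_error = None
            for _ in range(max_attempts):
                try:
                    return test_func(*args, **kwargs)
                except AssertionError as err:  # retry only on assertion failures
                    last_error = err
            raise last_error
        return wrapper
    return decorator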
@@ -1506,10 +1507,14 @@ class GenerationTesterMixin:
         )
         self.assertListEqual(low_output.tolist(), high_output.tolist())

-    @slow  # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
+    @is_flaky()  # Read NOTE (1) below. If there are API issues, all attempts will fail.
     def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
-        # It breaks the pattern in the tests above, for multiple reasons:
+        # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
+        # shape differences -- and it may result in a different output. The input shape difference happens in the
+        # main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
+        # a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
+        # NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
         #  - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
         #    prepare the assistant encoder outputs in the main generate body);
         #  - assisted_decoding does not support `use_cache = False`
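NOTE (1) above boils down to a numerical detail: when the main model scores several drafted candidates in one forward pass, its matmuls run with different shapes than in token-by-token greedy decoding, and different shapes may be dispatched to different kernels or reduction orders. A minimal sketch of the effect, with a plain `torch.nn.Linear` standing in for an LM head (sizes are arbitrary, not taken from the tests):

import torch

torch.manual_seed(0)
lm_head = torch.nn.Linear(256, 1000, bias=False)  # stand-in for the model's LM head
hidden = torch.randn(5, 256)                      # 5 candidate positions

batched = lm_head(hidden)                                          # all candidates in one call (assisted path)
one_by_one = torch.cat([lm_head(h.unsqueeze(0)) for h in hidden])  # one position at a time (greedy path)

# The difference is usually zero or on the order of float rounding error, but when two
# logits are nearly tied even a ~1e-7 gap can flip the argmax and change one token,
# hence @is_flaky() instead of a hard guarantee of identical outputs.
print((batched - one_by_one).abs().max())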
@@ -1520,77 +1525,82 @@ class GenerationTesterMixin:
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
                 model_name in model_class.__name__.lower()
-                for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
+                for model_name in [
+                    "bigbirdpegasus",
+                    "led",
+                    "mega",
+                    "speech2text",
+                    "git",
+                    "prophetnet",
+                    "seamlessm4t",
+                    "clvp",
+                ]
             ):
                 self.skipTest("May fix in the future: need model-specific fixes")

-            # This for loop is a naive and temporary effort to make the test less flaky.
-            failed = 0
-            for i in range(10):
-                # enable cache
-                config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
+            # enable cache
+            config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)

-                # NOTE: assisted generation only works with cache on at the moment.
-                if not hasattr(config, "use_cache"):
-                    self.skipTest("This model doesn't support caching")
+            # NOTE: assisted generation only works with cache on at the moment.
+            if not hasattr(config, "use_cache"):
+                self.skipTest("This model doesn't support caching")

-                config.use_cache = True
-                config.is_decoder = True
-                model = model_class(config).to(torch_device).eval()
-                output_greedy = model.generate(
-                    input_ids,
-                    attention_mask=attention_mask,
-                    max_length=max_length,
-                    num_beams=1,
-                    do_sample=False,
-                    output_scores=True,
-                    output_hidden_states=True,
-                    output_attentions=True,
-                    return_dict_in_generate=True,
-                )
-                # Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
-                # be correct
-                output_assisted = model.generate(
-                    input_ids,
-                    attention_mask=attention_mask,
-                    max_length=max_length,
-                    num_beams=1,
-                    do_sample=False,
-                    assistant_model=model,
-                    output_scores=True,
-                    output_hidden_states=True,
-                    output_attentions=True,
-                    return_dict_in_generate=True,
-                )
-                try:
-                    self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
-                    for output in (output_greedy, output_assisted):
-                        self._check_outputs(output, input_ids, model.config, use_cache=True)
-                except AssertionError:
-                    failed += 1
-                    if failed > 1:
-                        self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
-                        for output in (output_greedy, output_assisted):
-                            self._check_outputs(output, input_ids, model.config, use_cache=True)
+            config.use_cache = True
+            config.is_decoder = True
+            model = model_class(config).to(torch_device).eval()
+            # Sets assisted generation arguments such that:
+            # a) no EOS is generated, to ensure generation doesn't break early
+            # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
+            #    the assistant model is correct
+            # c) there are at least two forward passes in the main model, to ensure the input preparation of
+            #    the main model is correct
+            generation_kwargs = {
+                "eos_token_id": -1,  # see a)
+                "max_new_tokens": 4,  # see c)
+                "num_beams": 1,
+                "do_sample": False,
+                "output_scores": True,
+                "output_hidden_states": True,
+                "output_attentions": True,
+                "return_dict_in_generate": True,
+            }
+            output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

+            assistant_model = model
+            assistant_model.generation_config.num_assistant_tokens = 2  # see b)
+            assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
+            generation_kwargs.update({"assistant_model": assistant_model})
+            output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

+            # The two outputs must match and their shape must be as expected
+            self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
+            for output in (output_greedy, output_assisted):
+                self._check_outputs(output, input_ids, model.config, use_cache=True)

-    @unittest.skip("Failing for a lot of models du to attention mask size missmatch. Works well when standalone.")
     def test_assisted_decoding_sample(self):
-        # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
-        # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
+        # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
+        # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
+        # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
         for model_class in self.all_generative_model_classes:
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
                 model_name in model_class.__name__.lower()
-                for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
+                for model_name in [
+                    "bigbirdpegasus",
+                    "led",
+                    "mega",
+                    "speech2text",
+                    "git",
+                    "prophetnet",
+                    "seamlessm4t",
+                    "clvp",
+                ]
             ):
                 self.skipTest("May fix in the future: need model-specific fixes")

             # enable cache
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
+            config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)

             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
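Outside of the test harness, the same knobs the hunk above pins down (`num_assistant_tokens`, `num_assistant_tokens_schedule`, `eos_token_id=-1`, `max_new_tokens`) can be set on any decoder model passed as `assistant_model`. A rough sketch mirroring the test setup, where the model drafts for itself ("gpt2" is just a convenient public checkpoint chosen for illustration, not one of the models exercised by this test file):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

assistant = model  # as in the test: the model drafts for itself, so every drafted greedy token is accepted
assistant.generation_config.num_assistant_tokens = 2                     # b) always draft two tokens per call
assistant.generation_config.num_assistant_tokens_schedule = "constant"   # b) keep the draft length fixed

inputs = tokenizer("The quick brown fox", return_tensors="pt")
out = model.generate(
    **inputs,
    assistant_model=assistant,
    eos_token_id=-1,    # a) never match EOS, so generation cannot stop early
    max_new_tokens=4,   # c) enough new tokens to force at least two main-model forward passes
    do_sample=False,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))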
@@ -1599,18 +1609,27 @@ class GenerationTesterMixin:
             config.use_cache = True
             config.is_decoder = True
             model = model_class(config).to(torch_device).eval()
-            output_assisted = model.generate(
-                input_ids,
-                attention_mask=attention_mask,
-                max_length=max_length,
-                num_beams=1,
-                do_sample=True,
-                assistant_model=model,  # triggers assisted decoding
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
+            # Sets assisted generation arguments such that:
+            # a) no EOS is generated, to ensure generation doesn't break early
+            # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
+            #    the assistant model is correct
+            # c) there are at least two forward passes in the main model, to ensure the input preparation of
+            #    the main model is correct
+            assistant_model = model
+            assistant_model.generation_config.num_assistant_tokens = 2  # see b)
+            assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
+            generation_kwargs = {
+                "eos_token_id": -1,  # see a)
+                "max_new_tokens": 4,  # see c)
+                "num_beams": 1,
+                "do_sample": True,
+                "assistant_model": assistant_model,
+                "output_scores": True,
+                "output_hidden_states": True,
+                "output_attentions": True,
+                "return_dict_in_generate": True,
+            }
+            output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

             self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
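For completeness, the typical production use of the feature these tests cover pairs a large target model with a much smaller assistant that shares its tokenizer, so accepted drafts translate into real speedups. A hedged sketch, with illustrative checkpoint names only (any pair sharing a tokenizer works):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large").to(device)
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)  # small draft model, same tokenizer family

inputs = tokenizer("Assisted generation lets a small model draft tokens that", return_tensors="pt").to(device)
out = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))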