"...data/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "a55305655e28efa768f35c460c971f40e300c8d1"
Unverified Commit 83259e40 authored by Joao Gante, committed by GitHub

Mamba: add generative tests (#31478)

parent 7d683f7b
@@ -1830,6 +1830,12 @@ class GenerationMixin:
                 raise ValueError("assisted generate requires `use_cache=True`")
             if generation_config.cache_implementation == "static":
                 raise ValueError("assisted generate is not supported with `static_cache`")
+            if self._is_stateful:
+                # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
+                # which is not possible with stateful models (they can't reset to a previous subset of generated text)
+                raise ValueError(
+                    f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
+                )
 
             # 11. Get the candidate generator, given the parameterization
             candidate_generator = self._get_candidate_generator(
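With this guard, requesting assisted generation on a stateful model fails fast with a clear error instead of producing silently wrong continuations. A minimal sketch of how the error surfaces (the checkpoint name is illustrative; any model whose class sets `_is_stateful = True` behaves the same):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; the guard only looks at the class-level `_is_stateful` flag.
checkpoint = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
assistant = AutoModelForCausalLM.from_pretrained(checkpoint)

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
try:
    model.generate(input_ids, assistant_model=assistant, max_new_tokens=10)
except ValueError as err:
    # "assisted generation is not supported with stateful models, such as MambaForCausalLM"
    print(err)
```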
@@ -1867,6 +1873,11 @@ class GenerationMixin:
         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
             if not model_kwargs["use_cache"]:
                 raise ValueError("Contrastive search requires `use_cache=True`")
+            if self._is_stateful:
+                # Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
+                raise ValueError(
+                    f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
+                )
 
             result = self._contrastive_search(
                 input_ids,
......
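Contrastive search needs the same rollback ability, since it scores several candidate tokens per step before committing to one, so it gets the analogous guard. Continuing the sketch above, the degeneration-penalty parameters that select contrastive search now raise the same kind of error on a stateful model:

```python
try:
    # penalty_alpha together with top_k > 1 selects contrastive search as the generation mode
    model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_new_tokens=10)
except ValueError as err:
    # "contrastive search is not supported with stateful models, such as MambaForCausalLM"
    print(err)
```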
@@ -1281,6 +1281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     is_parallelizable = False
     supports_gradient_checkpointing = False
+    _is_stateful = False
 
     # Flash Attention 2 support
     _supports_flash_attn_2 = False
......
@@ -1266,6 +1266,7 @@ class JambaPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True  # Note: only supports HybridMambaAttentionDynamicCache
+    _is_stateful = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
......
@@ -354,6 +354,7 @@ class MambaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "backbone"
     _no_split_modules = ["MambaBlock"]
     supports_gradient_checkpointing = True
+    _is_stateful = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
......
@@ -394,6 +394,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["RwkvBlock"]
     _keep_in_fp32_modules = ["time_decay", "time_first"]
     supports_gradient_checkpointing = True
+    _is_stateful = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
......
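`_is_stateful` is an ordinary class attribute, so `generate()` can consult it before running any forward pass, and third-party code can do the same. A quick check of the resulting defaults (minimal sketch; requires PyTorch so the model classes are importable):

```python
from transformers import GPT2LMHeadModel, MambaForCausalLM, RwkvForCausalLM

# Standard KV-cache transformers keep the PreTrainedModel default.
print(GPT2LMHeadModel._is_stateful)   # False
# Recurrent-state architectures opt in, which disables rollback-based decoding strategies.
print(MambaForCausalLM._is_stateful)  # True
print(RwkvForCausalLM._is_stateful)   # True
```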
@@ -102,7 +102,11 @@ class GenerationTesterMixin:
         if isinstance(config.eos_token_id, int):
             config.eos_token_id = [config.eos_token_id]
         config.pad_token_id = config.eos_token_id[0]
-        attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+        if self.has_attentions:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+        else:
+            attention_mask = None
 
         # It is important set set the eos_token_id to None to ensure that no sequences
         # shorter than `max_length` can be generated
@@ -437,7 +441,7 @@ class GenerationTesterMixin:
             output_scores=True,
             output_logits=True,
             output_hidden_states=True,
-            output_attentions=True,
+            output_attentions=self.has_attentions,
             return_dict_in_generate=True,
         )
@@ -471,7 +475,7 @@ class GenerationTesterMixin:
             output_scores=True,
             output_logits=True,
             output_hidden_states=True,
-            output_attentions=True,
+            output_attentions=self.has_attentions,
             return_dict_in_generate=True,
         )
@@ -529,7 +533,7 @@ class GenerationTesterMixin:
             output_scores=True,
             output_logits=True,
             output_hidden_states=True,
-            output_attentions=True,
+            output_attentions=self.has_attentions,
             return_dict_in_generate=True,
         )
@@ -595,7 +599,7 @@ class GenerationTesterMixin:
             output_scores=True,
             output_logits=True,
             output_hidden_states=True,
-            output_attentions=True,
+            output_attentions=self.has_attentions,
             return_dict_in_generate=True,
         )
 
         if model.config.is_encoder_decoder:
@@ -642,7 +646,7 @@ class GenerationTesterMixin:
             output_scores=True,
             output_logits=True,
             output_hidden_states=True,
-            output_attentions=True,
+            output_attentions=self.has_attentions,
             return_dict_in_generate=True,
         )
@@ -733,7 +737,7 @@ class GenerationTesterMixin:
             output_scores=True,
             output_logits=True,
             output_hidden_states=True,
-            output_attentions=True,
+            output_attentions=self.has_attentions,
             return_dict_in_generate=True,
         )
@@ -834,7 +838,7 @@ class GenerationTesterMixin:
             output_scores=True,
             output_logits=True,
             output_hidden_states=True,
-            output_attentions=True,
+            output_attentions=self.has_attentions,
             return_dict_in_generate=True,
         )
 
         if model.config.is_encoder_decoder:
@@ -952,7 +956,7 @@ class GenerationTesterMixin:
             output_scores=True,
             output_logits=True,
             output_hidden_states=True,
-            output_attentions=True,
+            output_attentions=self.has_attentions,
             return_dict_in_generate=True,
         )
@@ -973,6 +977,9 @@ class GenerationTesterMixin:
     def test_contrastive_generate(self):
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support contrastive search generation")
+
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
@@ -997,6 +1004,9 @@ class GenerationTesterMixin:
     def test_contrastive_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support contrastive search generation")
+
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
@@ -1017,7 +1027,7 @@ class GenerationTesterMixin:
             output_scores=True,
             output_logits=True,
             output_hidden_states=True,
-            output_attentions=True,
+            output_attentions=self.has_attentions,
             return_dict_in_generate=True,
         )
@@ -1030,9 +1040,12 @@ class GenerationTesterMixin:
     def test_contrastive_generate_low_memory(self):
         # Check that choosing 'low_memory' does not change the model output
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support contrastive search generation")
+
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
                 self.skipTest("Won't fix: old model with different cache format")
-            if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
                 self.skipTest("TODO: fix me")
 
             config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
@@ -1069,6 +1082,8 @@ class GenerationTesterMixin:
     def test_beam_search_low_memory(self):
         # Check that choosing 'low_memory' does not change the model output
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("May fix in the future: need custom cache handling")
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
@@ -1115,6 +1130,8 @@ class GenerationTesterMixin:
         # - assisted_decoding does not support `batch_size > 1`
 
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support assisted generation")
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
@@ -1156,7 +1173,7 @@ class GenerationTesterMixin:
             "output_scores": True,
             "output_logits": True,
             "output_hidden_states": True,
-            "output_attentions": True,
+            "output_attentions": self.has_attentions,
             "return_dict_in_generate": True,
         }
 
         output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@@ -1184,6 +1201,8 @@ class GenerationTesterMixin:
         # This test is mostly a copy of test_assisted_decoding_matches_greedy_search
 
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support assisted generation")
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
@@ -1225,7 +1244,7 @@ class GenerationTesterMixin:
             "output_scores": True,
             "output_logits": True,
             "output_hidden_states": True,
-            "output_attentions": True,
+            "output_attentions": self.has_attentions,
             "return_dict_in_generate": True,
         }
@@ -1244,6 +1263,8 @@ class GenerationTesterMixin:
        # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
        # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support assisted generation")
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
@@ -1289,7 +1310,7 @@ class GenerationTesterMixin:
             "output_scores": True,
             "output_logits": True,
             "output_hidden_states": True,
-            "output_attentions": True,
+            "output_attentions": self.has_attentions,
             "return_dict_in_generate": True,
         }
 
         output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@@ -1326,7 +1347,7 @@ class GenerationTesterMixin:
                 input_ids,
                 attention_mask=attention_mask,
                 num_beams=1,
-                output_attentions=True,
+                output_attentions=self.has_attentions,
                 return_dict_in_generate=True,
                 remove_invalid_values=True,
                 **{name: mask},
@@ -1344,6 +1365,10 @@ class GenerationTesterMixin:
         if len(self.all_generative_model_classes) == 0:
             self.skipTest(reason="No generative architecture available for this model.")
 
+        # - The model must support padding
+        if not self.has_attentions:
+            self.skipTest(reason="This model doesn't support padding.")
+
         # - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
         decoder_only_classes = []
         for model_class in self.all_generative_model_classes:
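The padding requirement comes from how batched decoder-only generation works: shorter prompts are left-padded and the attention mask tells the model to ignore those positions. A model with no attention layers has no attention mask to carry that information, so left-padding is not well defined for it and the test is skipped. A small illustration of the pattern the test exercises, with GPT-2 standing in as an arbitrary decoder-only model:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # decoder-only models are left-padded for generation
model = AutoModelForCausalLM.from_pretrained("gpt2")

batch = tokenizer(["Hi", "A much longer prompt"], padding=True, return_tensors="pt")
# The attention mask marks the left-pad positions so the model ignores them.
out = model.generate(**batch, max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```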
@@ -1704,30 +1729,31 @@ class GenerationTesterMixin:
         self._check_logits(num_sequences_in_output, output.logits, config=config)
 
         # Attentions
-        if config.is_encoder_decoder:
-            # encoder
-            self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
-            # decoder
-            self._check_attentions_for_generate(
-                num_sequences_in_output,
-                output.decoder_attentions,
-                min_length=1,
-                max_length=output.sequences.shape[-1],
-                config=config,
-                use_cache=use_cache,
-            )
-        else:
-            # if use_cache first input is equal to no use_cache, so skip here
-            attentions = output.attentions if not use_cache else output.attentions[1:]
-            min_length = seq_length if not use_cache else seq_length + 1
-            self._check_attentions_for_generate(
-                num_sequences_in_output,
-                attentions=attentions,
-                min_length=min_length,
-                max_length=output.sequences.shape[-1],
-                config=config,
-                use_cache=use_cache,
-            )
+        if self.has_attentions:
+            if config.is_encoder_decoder:
+                # encoder
+                self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
+                # decoder
+                self._check_attentions_for_generate(
+                    num_sequences_in_output,
+                    output.decoder_attentions,
+                    min_length=1,
+                    max_length=output.sequences.shape[-1],
+                    config=config,
+                    use_cache=use_cache,
+                )
+            else:
+                # if use_cache first input is equal to no use_cache, so skip here
+                attentions = output.attentions if not use_cache else output.attentions[1:]
+                min_length = seq_length if not use_cache else seq_length + 1
+                self._check_attentions_for_generate(
+                    num_sequences_in_output,
+                    attentions=attentions,
+                    min_length=min_length,
+                    max_length=output.sequences.shape[-1],
+                    config=config,
+                    use_cache=use_cache,
+                )
 
         # Hidden States
         if config.is_encoder_decoder:
@@ -1763,7 +1789,7 @@ class GenerationTesterMixin:
         # 2. Some old models still return `output.past_key_values` even without `use_cache=True`
         # 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
         #    complete
-        models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba")
+        models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
         has_standard_cache = not any(
             model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
         )
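Mamba lands on this list because it does not use the standard `past_key_values` tuple at all: its forward pass returns the recurrent state as `cache_params` (a `MambaCache`), so there is nothing for the check above to validate. A minimal sketch with a tiny randomly initialized model, just to show the cache object involved:

```python
import torch
from transformers import MambaConfig, MambaForCausalLM

# Tiny random model; the point is only the type of the returned cache.
model = MambaForCausalLM(MambaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2))
out = model(torch.randint(0, 100, (1, 4)), use_cache=True)
# The recurrent state is returned as `cache_params`, not `past_key_values`.
print(type(out.cache_params))  # <class 'transformers....MambaCache'>
```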
......
@@ -503,10 +503,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         # They should result in very similar logits
         self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
 
-    @unittest.skip("Jamba has its own special cache type")  # FIXME: @gante
-    def test_assisted_decoding_matches_greedy_search_0_random(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
......
@@ -250,6 +250,8 @@ class MambaModelTester:
 @require_torch
 class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
+    all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else ()
+    has_attentions = False  # Mamba does not support attentions
     fx_compatible = False  # FIXME let's try to support this @ArthurZucker
     test_torchscript = False  # FIXME let's try to support this @ArthurZucker
     test_missing_keys = False
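These two attributes are the whole opt-in: `all_generative_model_classes` puts `MambaForCausalLM` under the common `GenerationTesterMixin` tests, and `has_attentions = False` tells the mixins to skip attention masks, pass `output_attentions=False`, and skip padding-dependent tests. A minimal, self-contained sketch of the same attribute layout (the real class additionally inherits the repo-internal `ModelTesterMixin`, `GenerationTesterMixin`, and `PipelineTesterMixin`):

```python
import unittest

from transformers import is_torch_available

if is_torch_available():
    from transformers import MambaForCausalLM, MambaModel


class MambaTestConfigSketch(unittest.TestCase):
    # Same attribute layout as MambaModelTest above; the shared mixins read these flags.
    all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else ()
    has_attentions = False  # Mamba has no attention layers

    def test_flags(self):
        self.assertFalse(self.has_attentions)
        if is_torch_available():
            self.assertTrue(all(cls._is_stateful for cls in self.all_generative_model_classes))
```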
@@ -292,10 +294,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_config(self):
         self.config_tester.run_common_tests()
 
-    @unittest.skip("No attention in mamba")
-    def test_retain_grad_hidden_states_attentions(self):
-        pass
-
     @require_torch_multi_gpu
     def test_multi_gpu_data_parallel_forward(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -364,14 +362,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
                 # check if it's a ones like
                 self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
 
-    @unittest.skip("Mamba does not use attention")
-    def test_attention_outputs(self):
-        r"""
-        Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models
-        it has a shape `batch_size, seq_len, hidden_size`.
-        """
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         model = MambaModel.from_pretrained("hf-internal-testing/mamba-130m")
......
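Assuming the usual repository layout, the newly enabled generation tests for Mamba can be run locally with something like `python -m pytest tests/models/mamba/test_modeling_mamba.py -k "generate" -v` (the path and `-k` filter are illustrative; any pytest selection targeting this test module works).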