"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "101a6cd276d454c6ab07aff3c54e598ff83d537c"
Unverified commit 9d889f87, authored by Joao Gante, committed by GitHub

Cache: add new flag to distinguish models that support `Cache` but not static cache (#30800)

* jamba cache

* new flag

* generate exception
parent 17cc71e1
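
In practice, the new flag changes what happens when a caller requests a static cache at generation time. Below is a minimal sketch of the intended behaviour, assuming a transformers build that includes this change; the checkpoint name and prompt are illustrative and not part of the diff.

```python
# Sketch of the behaviour introduced by this change (checkpoint is illustrative).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # LlamaPreTrainedModel sets both cache flags
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")

# Allowed: Llama declares _supports_cache_class = True and _supports_static_cache = True.
out = model.generate(**inputs, cache_implementation="static", max_new_tokens=8)

# For a model whose class sets _supports_cache_class = True but not
# _supports_static_cache (e.g. the Jamba- and Mistral-style classes touched below),
# the same call now fails fast with a ValueError pointing to issue #28981,
# instead of breaking later inside _get_static_cache.
```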
@@ -1616,6 +1616,11 @@ class GenerationMixin:
                     "issue: https://github.com/huggingface/transformers/issues/28981."
                 )
             if generation_config.cache_implementation == "static":
+                if not self._supports_static_cache:
+                    raise ValueError(
+                        "This model does not support `cache_implementation='static'`. Please check the following "
+                        "issue: https://github.com/huggingface/transformers/issues/28981"
+                    )
                 model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
 
         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
@@ -1280,8 +1280,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     # SDPA support
     _supports_sdpa = False
 
-    # Has support for a `Cache` instance as `past_key_values`
+    # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
     _supports_cache_class = False
+    _supports_static_cache = False
 
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
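
For model authors, the two class attributes are meant to be declared side by side. The class below is a hypothetical example (the class name is made up here, not part of the diff) of a model that accepts a `Cache` object but cannot use `StaticCache`:

```python
# Hypothetical example, not part of the diff: a model that takes a `Cache`
# instance as `past_key_values` but whose cache layout rules out `StaticCache`.
from transformers import PreTrainedModel


class MyCustomPreTrainedModel(PreTrainedModel):
    _supports_cache_class = True    # accepts a `Cache` object as `past_key_values`
    _supports_static_cache = False  # generate(cache_implementation="static") now raises
```

Because the base class defaults `_supports_static_cache` to `False`, existing models keep their previous behaviour unless they opt in, as the per-model hunks below do.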
@@ -720,6 +720,7 @@ class CoherePreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -938,6 +938,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     def _init_weights(self, module: nn.Module):
         std = self.config.initializer_range
@@ -703,6 +703,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -1341,6 +1341,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
     _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of Idefics2 isn't meant for training from scratch - only
@@ -1261,6 +1261,7 @@ class JambaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True  # Note: only supports HybridMambaAttentionDynamicCache
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -799,6 +799,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -810,6 +810,7 @@ class MistralPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -989,6 +989,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -776,6 +776,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -457,6 +457,7 @@ class PersimmonPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["PersimmonDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -825,6 +825,7 @@ class PhiPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -921,6 +921,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = False
+    _supports_cache_class = True
 
     _version = "0.0.5"
@@ -821,6 +821,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -975,6 +975,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -541,7 +541,6 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = ["cache"]
     _supports_flash_attn_2 = False
     _supports_sdpa = False  # we can't compare with eager for now
-    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width)
@@ -799,6 +799,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -4365,7 +4365,7 @@ class ModelTesterMixin:
             self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
 
         for model_class in self.all_generative_model_classes:
-            if not model_class._supports_cache_class:
+            if not model_class._supports_static_cache:
                 self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
 
             config, _ = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config).to(device=torch_device, dtype=torch.float32)
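
Callers who want to stay on the safe side of the new exception can mirror the check used in the test above and inspect the flags before picking a cache implementation. A rough sketch follows, with an illustrative checkpoint and no claim about which flags that particular model actually sets:

```python
# Choose generate() kwargs based on the flags this change introduces.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")  # illustrative

if model._supports_static_cache:
    gen_kwargs = {"cache_implementation": "static"}  # compile-friendly StaticCache
elif model._supports_cache_class:
    gen_kwargs = {}  # default dynamic `Cache` (or a model-specific Cache subclass)
else:
    gen_kwargs = {}  # legacy tuple-of-tensors cache

# outputs = model.generate(**inputs, **gen_kwargs, max_new_tokens=16)
```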