Unverified Commit d8f8a9cd authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

CI: more models wo cache support (#30780)

parent 5ad960f1
...@@ -810,7 +810,6 @@ class MistralPreTrainedModel(PreTrainedModel): ...@@ -810,7 +810,6 @@ class MistralPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
......
...@@ -989,7 +989,6 @@ class MixtralPreTrainedModel(PreTrainedModel): ...@@ -989,7 +989,6 @@ class MixtralPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
......
...@@ -457,7 +457,6 @@ class PersimmonPreTrainedModel(PreTrainedModel): ...@@ -457,7 +457,6 @@ class PersimmonPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["PersimmonDecoderLayer"] _no_split_modules = ["PersimmonDecoderLayer"]
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
......
...@@ -825,7 +825,6 @@ class PhiPreTrainedModel(PreTrainedModel): ...@@ -825,7 +825,6 @@ class PhiPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
......
...@@ -921,7 +921,6 @@ class Phi3PreTrainedModel(PreTrainedModel): ...@@ -921,7 +921,6 @@ class Phi3PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = False _supports_sdpa = False
_supports_cache_class = True
_version = "0.0.5" _version = "0.0.5"
......
...@@ -975,7 +975,6 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): ...@@ -975,7 +975,6 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
......
...@@ -799,7 +799,6 @@ class Starcoder2PreTrainedModel(PreTrainedModel): ...@@ -799,7 +799,6 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
......
Markdown is supported
0% — or attach files by dragging & dropping or selecting them.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment