Unverified Commit 453e7488 authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

LLaVa: add cache class attribute (#32278)

cache class flag
parent 14ee2326
...@@ -126,6 +126,7 @@ class LlavaPreTrainedModel(PreTrainedModel): ...@@ -126,6 +126,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaVisionAttention"] _no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
# important: this ported version of Llava isn't meant for training from scratch - only # important: this ported version of Llava isn't meant for training from scratch - only
......
...@@ -232,6 +232,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): ...@@ -232,6 +232,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaNextVisionAttention"] _no_split_modules = ["LlavaNextVisionAttention"]
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
# important: this ported version of LlavaNext isn't meant for training from scratch - only # important: this ported version of LlavaNext isn't meant for training from scratch - only
......
...@@ -272,6 +272,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): ...@@ -272,6 +272,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaNextVideoVisionAttention"] _no_split_modules = ["LlavaNextVideoVisionAttention"]
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
# important: this ported version of LlavaNextVideo isn't meant for training from scratch - only # important: this ported version of LlavaNextVideo isn't meant for training from scratch - only
......
...@@ -127,6 +127,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): ...@@ -127,6 +127,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = False _supports_flash_attn_2 = False
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
# important: this ported version of PaliGemma isn't meant for training from scratch - only # important: this ported version of PaliGemma isn't meant for training from scratch - only
......
...@@ -126,6 +126,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): ...@@ -126,6 +126,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["VideoLlavaVisionAttention"] _no_split_modules = ["VideoLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
std = ( std = (
......
...@@ -135,6 +135,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel): ...@@ -135,6 +135,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["VipLlavaVisionAttention"] _no_split_modules = ["VipLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module): def _init_weights(self, module):
# important: this ported version of VipLlava isn't meant for training from scratch - only # important: this ported version of VipLlava isn't meant for training from scratch - only
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment