Unverified Commit 453e7488 authored by Raushan Turganbay, committed by GitHub
Browse files

LLaVa: add cache class attribute (#32278)

cache class flag
parent 14ee2326
......@@ -126,6 +126,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module):
# important: this ported version of Llava isn't meant for training from scratch - only
......
......@@ -232,6 +232,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaNextVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module):
# important: this ported version of LlavaNext isn't meant for training from scratch - only
......
......@@ -272,6 +272,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaNextVideoVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module):
# important: this ported version of LlavaNextVideo isn't meant for training from scratch - only
......
......@@ -127,6 +127,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = False
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
# important: this ported version of PaliGemma isn't meant for training from scratch - only
......
......@@ -126,6 +126,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["VideoLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module):
std = (
......
......@@ -135,6 +135,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["VipLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
def _init_weights(self, module):
# important: this ported version of VipLlava isn't meant for training from scratch - only
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment