Unverified Commit e2b6df79 authored by Adilzhan Ismailov, committed by GitHub

[LLaVa] Add past_key_values to _skip_keys_device_placement to fix multi-GPU dispatch (#28051)

Add past_key_values to _skip_keys_device_placement for LLaVa
parent deb72cb6
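
Context for the change: when a model is sharded across devices with device_map="auto", accelerate wraps each dispatched submodule in a hook that moves that submodule's forward inputs to its device. The past_key_values cache, however, is a tuple of per-layer tensors that already sit on the devices of their respective layers, so force-moving the whole cache breaks generation on multi-GPU setups. Listing it in _skip_keys_device_placement tells the dispatch hooks to leave it where it is. Below is a minimal sketch of the scenario this commit fixes; the llava-hf/llava-1.5-7b-hf checkpoint, the image URL, the prompt, and a multi-GPU machine are illustrative assumptions, not part of the commit.

    # Hedged sketch: multi-GPU generation with LLaVa, the scenario this
    # commit fixes. Checkpoint, image URL, and prompt are illustrative.
    import requests
    import torch
    from PIL import Image
    from transformers import AutoProcessor, LlavaForConditionalGeneration

    model_id = "llava-hf/llava-1.5-7b-hf"  # assumed checkpoint
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",  # shards the model across all visible GPUs
    )
    processor = AutoProcessor.from_pretrained(model_id)

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(0, torch.float16)
    # Before this fix, cached generation could fail on multi-GPU setups
    # because the per-layer past_key_values tensors were force-moved to
    # a single device by the dispatch hooks.
    output = model.generate(**inputs, max_new_tokens=30)
    print(processor.decode(output[0], skip_special_tokens=True))
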
@@ -130,6 +130,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlavaVisionAttention"]
+    _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True

     def _init_weights(self, module):
...
@@ -137,6 +137,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["VipLlavaVisionAttention"]
+    _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True

     def _init_weights(self, module):
...
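
Mechanically, the new class attribute is what transformers forwards to accelerate when it dispatches a sharded model. Below is a minimal sketch of the equivalent direct call, using a trivial single-device map for brevity; this is an illustration of the mechanism, not the literal transformers internals.

    # Hedged sketch of the mechanism: transformers passes
    # _skip_keys_device_placement to accelerate's dispatch_model as
    # skip_keys, so the attached AlignDevicesHook leaves matching forward
    # kwargs (here the past_key_values cache) on their current devices
    # instead of moving them to the hooked submodule's device.
    from accelerate import dispatch_model
    from transformers import LlavaForConditionalGeneration

    model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    device_map = {"": 0}  # trivial map for illustration; real maps split submodules
    model = dispatch_model(
        model,
        device_map=device_map,
        skip_keys=model._skip_keys_device_placement,  # "past_key_values" after this commit
    )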