Unverified Commit b8378b65 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`Llava` / `Vip-Llava`] Add SDPA into llava (#28107)

add SDPA into llava
parent e6dcf8ab
......@@ -155,6 +155,14 @@ class LlavaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
LLAVA_INPUTS_DOCSTRING = r"""
Args:
......
......@@ -162,6 +162,14 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
VIPLLAVA_INPUTS_DOCSTRING = r"""
Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment