"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "986526a0e4f5ab803581074e9e4069c3edcff1dc"
Unverified commit 43ee5858, authored by Yoach Lacombe and committed by GitHub

Fix MusicGen SDPA (#31208)

* fix sdpa musicgen

* make style

* remove copied from statement from Musicgen SDPA
parent 833fc17a
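For context, the fix targets the failure reported in https://github.com/huggingface/transformers/issues/31189: unconditional generation (or classifier-free guidance with `guidance_scale>1`) produces a batch item whose attention mask is entirely masked out, which the SDPA code path could not handle. A minimal reproduction sketch following the documented MusicGen API; the checkpoint and generation arguments are illustrative, not taken from this commit:

import torch
from transformers import MusicgenForConditionalGeneration

# Load MusicGen with the SDPA attention implementation, the code path patched here.
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small", attn_implementation="sdpa"
)

# get_unconditional_inputs builds a fully masked text branch, the "empty"
# attention mask case that previously broke SDPA.
unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=16)

# Before this fix, the SDPA path could yield NaN audio here; with the eager
# fallback it stays finite.
print(torch.isnan(audio_values).any())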
@@ -545,7 +545,6 @@ class MusicgenFlashAttention2(MusicgenAttention):
         )
 
 
-# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
 class MusicgenSdpaAttention(MusicgenAttention):
     def forward(
         self,
@@ -572,6 +571,23 @@ class MusicgenSdpaAttention(MusicgenAttention):
                 output_attentions=output_attentions,
             )
 
+        if (
+            attention_mask is not None
+            and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any()
+        ):
+            logger.warning_once(
+                '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                " Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information."
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
...
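To illustrate the guard added above (shapes and values here are illustrative, not from the commit): Musicgen builds a 4D additive float mask that marks blocked positions with `torch.finfo(dtype).min`, so a per-item mean equal to that minimum means every position of that item is masked, and forward falls back to the eager implementation via `super().forward(...)`.

import torch

batch, heads, q_len, kv_len = 2, 4, 1, 8
dtype = torch.float32
min_val = torch.finfo(dtype).min

# Item 0 attends everywhere (zeros); item 1 is fully masked ("empty" mask).
mask = torch.zeros(batch, heads, q_len, kv_len, dtype=dtype)
mask[1].fill_(min_val)

# Mirrors the committed check: mean over (heads, q_len, kv_len) per batch item.
is_empty = mask.mean(dim=[1, 2, 3]) <= min_val
print(is_empty)        # tensor([False,  True])
print(is_empty.any())  # tensor(True) -> take the eager fallback for the batch

Note that the fallback is batch-global: the `.any()` in the committed check routes the entire batch through eager attention as soon as one item has an empty mask.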