Unverified commit 5e89b335 authored by hoshi-hiyouga, committed by GitHub

Fix Gemma2 4d attention mask (#31674)



Update modeling_gemma2.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 0142aab7
@@ -629,11 +629,13 @@ class Gemma2DecoderLayer(nn.Module):
         if (
             self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
         ):  # efficient SDPA and no padding
-            attention_mask = attention_mask * torch.tril(
-                torch.ones_like(attention_mask), diagonal=-self.sliding_window
-            )
-            if attention_mask.shape[1] <= 1:  # when decoding
-                attention_mask = attention_mask[:, -self.sliding_window :]
+            min_dtype = torch.finfo(hidden_states.dtype).min
+            sliding_window_mask = torch.tril(
+                torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
+            )
+            attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
+            if attention_mask.shape[-1] <= 1:  # when decoding
+                attention_mask = attention_mask[:, :, :, -self.sliding_window :]

         residual = hidden_states
...
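For context, the sketch below is not part of the commit; it illustrates on a toy 4D additive mask why the old elementwise multiplication destroyed the mask while torch.where applies the sliding window correctly. The tensor shapes, dtype, and sliding_window value are made up for illustration.

# Minimal sketch (not from the commit). A 4D additive mask holds 0.0 where
# attention is allowed and dtype-min where it is blocked; multiplying such a
# mask by a 0/1 tril zeroes the dtype-min entries instead of masking anything.
import torch

batch, heads, q_len, kv_len = 1, 1, 6, 6
sliding_window = 3
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Toy 4D additive causal mask: 0.0 = allowed, min_dtype = blocked.
causal = torch.triu(torch.ones(q_len, kv_len, dtype=torch.bool), diagonal=1)
attention_mask = torch.zeros(batch, heads, q_len, kv_len, dtype=dtype)
attention_mask.masked_fill_(causal, min_dtype)

# Old behaviour: the product is all zeros, so every position (including
# future tokens) becomes attendable.
old = attention_mask * torch.tril(
    torch.ones_like(attention_mask), diagonal=-sliding_window
)

# New behaviour: positions more than sliding_window behind the query are
# forced to min_dtype; the rest of the additive mask is left untouched.
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)
new = torch.where(sliding_window_mask, min_dtype, attention_mask)

print((old == 0).all())                        # tensor(True): causal mask lost
print((new[0, 0, -1, :3] == min_dtype).all())  # tokens outside the window masked

The commit also fixes the decode-time branch to check the key/value dimension (shape[-1]) and to slice the last axis of the 4D mask ([:, :, :, -sliding_window:]).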