Unverified Commit 22f888b3 authored by Joshua Lochner's avatar Joshua Lochner Committed by GitHub
Browse files

[mistral] Fix FA2 attention reshape for Mistral Nemo (#32065)

* [mistral] Fix FA2 attention reshape

* [run-slow] mistral
parent cd48553f
...@@ -387,7 +387,7 @@ class MistralFlashAttention2(MistralAttention): ...@@ -387,7 +387,7 @@ class MistralFlashAttention2(MistralAttention):
is_causal=self.is_causal, is_causal=self.is_causal,
) )
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
if not output_attentions: if not output_attentions:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment