Commit a684c012, authored by Yan Ma, committed by GitHub

[bugfix] fix MHA for models like OpenGVLab/InternVL3_5-38B (#25146)


Signed-off-by: Yan Ma <yan.ma@intel.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent f2718d29
@@ -430,9 +430,11 @@ class MultiHeadAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> torch.Tensor:
-        """Input shape: batch_size x seq_len x hidden_size"""
-        # TODO(Isotr0py): Use existing backend implementations and support FA3
-        bsz, q_len, _ = query.size()
+        """Input shape:
+        (batch_size x seq_len x hidden_size) or
+        (batch_size x seq_len x num_heads x head_size)
+        """
+        bsz, q_len = query.size()[:2]
         kv_len = key.size(1)
         query = query.view(bsz, q_len, self.num_heads, self.head_size)
...
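The switch from three-way unpacking to slicing `query.size()[:2]` can be illustrated with a minimal sketch. It models shapes as plain tuples (a `torch.Size` is a tuple subclass, so unpacking and slicing behave the same way), and the helper names and example sizes are hypothetical, not part of the commit. It shows that the old pattern only accepts 3-D input, which is presumably the failure the commit addresses for models that pass a 4-D query.

```python
# Shapes are modelled as plain tuples; torch.Size is a tuple subclass,
# so unpacking and slicing behave the same on a real tensor's size().


def leading_dims_old(shape):
    """Pre-fix pattern: unpacking assumes exactly three dimensions."""
    bsz, q_len, _ = shape
    return bsz, q_len


def leading_dims_new(shape):
    """Post-fix pattern: take the first two dimensions, ignore the rest."""
    bsz, q_len = shape[:2]
    return bsz, q_len


# Hypothetical example sizes, not taken from any specific model config.
shape_3d = (2, 5, 64)     # batch_size x seq_len x hidden_size
shape_4d = (2, 5, 4, 16)  # batch_size x seq_len x num_heads x head_size

print(leading_dims_new(shape_3d))
print(leading_dims_new(shape_4d))
print(leading_dims_old(shape_3d))

try:
    leading_dims_old(shape_4d)
except ValueError as exc:
    # Three names cannot receive four values.
    print("old pattern fails on 4-D input:", exc)
```

Because both layouts hold the same number of elements when `hidden_size == num_heads * head_size`, the `query.view(bsz, q_len, self.num_heads, self.head_size)` call that follows in the diff yields the same 4-D shape in either case.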