Commit bb01f291, authored by Michael Goin, committed by GitHub

[Bugfix][Model] Fix Mllama SDPA illegal memory access for batched multi-image (#9626)


Signed-off-by: mgoin <michael@neuralmagic.com>
parent b548d7a5
@@ -795,17 +795,19 @@ class MllamaTextCrossAttention(nn.Module):
         kv_len = k.shape[0]
         q = q.transpose(0, 1).view(self.num_local_key_value_heads,
                                    self.num_key_value_groups, q_len,
-                                   self.head_dim)
+                                   self.head_dim).contiguous()
         k = k.transpose(0,
                         1)[:,
                            None, :, :].expand(self.num_local_key_value_heads,
                                               self.num_key_value_groups,
-                                              kv_len, self.head_dim)
+                                              kv_len,
+                                              self.head_dim).contiguous()
         v = v.transpose(0,
                         1)[:,
                            None, :, :].expand(self.num_local_key_value_heads,
                                               self.num_key_value_groups,
-                                              kv_len, self.head_dim)
+                                              kv_len,
+                                              self.head_dim).contiguous()
         attention_mask = attention_mask.view(1, 1, q_len, kv_len)
         output = F.scaled_dot_product_attention(q,
                                                 k,
......
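
The effect of the three added `.contiguous()` calls can be illustrated with a small standalone sketch. It is not part of the commit: the tensor sizes and variable names are hypothetical, it runs on CPU, and it does not reproduce the illegal memory access itself. It only shows that `expand` returns a stride-0 view that shares storage across the group dimension, and that `.contiguous()` materializes that view into a densely laid-out tensor holding the same values before it reaches `F.scaled_dot_product_attention`.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration.
num_kv_heads, num_groups, q_len, kv_len, head_dim = 2, 4, 3, 5, 8

torch.manual_seed(0)
q = torch.randn(q_len, num_kv_heads * num_groups, head_dim)
k = torch.randn(kv_len, num_kv_heads, head_dim)
v = torch.randn(kv_len, num_kv_heads, head_dim)

# Same reshaping pattern as the diff, without .contiguous().
q_view = q.transpose(0, 1).view(num_kv_heads, num_groups, q_len, head_dim)
k_view = k.transpose(0, 1)[:, None, :, :].expand(
    num_kv_heads, num_groups, kv_len, head_dim)
v_view = v.transpose(0, 1)[:, None, :, :].expand(
    num_kv_heads, num_groups, kv_len, head_dim)

# expand() copies nothing: the group dimension has stride 0, so every
# group reads the same underlying memory.
group_stride = k_view.stride()[1]
view_is_contiguous = k_view.is_contiguous()

# With .contiguous(), as in the patched code: each tensor gets its own
# densely packed storage.
q_fixed = q_view.contiguous()
k_fixed = k_view.contiguous()
v_fixed = v_view.contiguous()

# Additive float mask of zeros, shaped like the mask in the diff.
attention_mask = torch.zeros(q_len, kv_len).view(1, 1, q_len, kv_len)

out_view = F.scaled_dot_product_attention(
    q_view, k_view, v_view, attn_mask=attention_mask)
out_fixed = F.scaled_dot_product_attention(
    q_fixed, k_fixed, v_fixed, attn_mask=attention_mask)

print("group stride before contiguous():", group_stride)
print("strides after contiguous():", k_fixed.stride())
print("outputs match:", torch.allclose(out_view, out_fixed, atol=1e-5))
```

The values are unchanged by `.contiguous()`; only the memory layout differs, at the cost of one extra copy each for `q`, `k`, and `v`.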