Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bb01f291
Unverified
Commit
bb01f291
authored
Oct 23, 2024
by
Michael Goin
Committed by
GitHub
Oct 24, 2024
Browse files
[Bugfix][Model] Fix Mllama SDPA illegal memory access for batched multi-image (#9626)
Signed-off-by:
mgoin
<
michael@neuralmagic.com
>
parent
b548d7a5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
3 deletions
+5
-3
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+5
-3
No files found.
vllm/model_executor/models/mllama.py
View file @
bb01f291
...
@@ -795,17 +795,19 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -795,17 +795,19 @@ class MllamaTextCrossAttention(nn.Module):
kv_len
=
k
.
shape
[
0
]
kv_len
=
k
.
shape
[
0
]
q
=
q
.
transpose
(
0
,
1
).
view
(
self
.
num_local_key_value_heads
,
q
=
q
.
transpose
(
0
,
1
).
view
(
self
.
num_local_key_value_heads
,
self
.
num_key_value_groups
,
q_len
,
self
.
num_key_value_groups
,
q_len
,
self
.
head_dim
)
self
.
head_dim
)
.
contiguous
()
k
=
k
.
transpose
(
0
,
k
=
k
.
transpose
(
0
,
1
)[:,
1
)[:,
None
,
:,
:].
expand
(
self
.
num_local_key_value_heads
,
None
,
:,
:].
expand
(
self
.
num_local_key_value_heads
,
self
.
num_key_value_groups
,
self
.
num_key_value_groups
,
kv_len
,
self
.
head_dim
)
kv_len
,
self
.
head_dim
).
contiguous
()
v
=
v
.
transpose
(
0
,
v
=
v
.
transpose
(
0
,
1
)[:,
1
)[:,
None
,
:,
:].
expand
(
self
.
num_local_key_value_heads
,
None
,
:,
:].
expand
(
self
.
num_local_key_value_heads
,
self
.
num_key_value_groups
,
self
.
num_key_value_groups
,
kv_len
,
self
.
head_dim
)
kv_len
,
self
.
head_dim
).
contiguous
()
attention_mask
=
attention_mask
.
view
(
1
,
1
,
q_len
,
kv_len
)
attention_mask
=
attention_mask
.
view
(
1
,
1
,
q_len
,
kv_len
)
output
=
F
.
scaled_dot_product_attention
(
q
,
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
k
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment