Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6e98f6d8
Unverified
Commit
6e98f6d8
authored
Feb 05, 2026
by
Taeksang Kim
Committed by
GitHub
Feb 04, 2026
Browse files
Implement zero-copy GQA for multimodal and CPU (#33732)
Signed-off-by:
Taeksang Kim
<
ts.kim@hyperaccel.ai
>
parent
2f6d17cb
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
32 deletions
+18
-32
vllm/model_executor/layers/attention/mm_encoder_attention.py
vllm/model_executor/layers/attention/mm_encoder_attention.py
+4
-12
vllm/model_executor/models/molmo2.py
vllm/model_executor/models/molmo2.py
+1
-12
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+1
-4
vllm/v1/attention/ops/vit_attn_wrappers.py
vllm/v1/attention/ops/vit_attn_wrappers.py
+12
-4
No files found.
vllm/model_executor/layers/attention/mm_encoder_attention.py
View file @
6e98f6d8
...
@@ -80,7 +80,7 @@ class MMEncoderAttention(CustomOp):
...
@@ -80,7 +80,7 @@ class MMEncoderAttention(CustomOp):
def
enabled
(
cls
)
->
bool
:
def
enabled
(
cls
)
->
bool
:
return
True
return
True
def
maybe_reshape
_qkv_to_4d
(
def
view
_qkv_to_4d
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
@@ -97,11 +97,6 @@ class MMEncoderAttention(CustomOp):
...
@@ -97,11 +97,6 @@ class MMEncoderAttention(CustomOp):
key
=
key
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
if
(
num_repeat
:
=
self
.
num_queries_per_kv
)
>
1
:
# Handle MQA and GQA
key
=
torch
.
repeat_interleave
(
key
,
num_repeat
,
dim
=
2
)
value
=
torch
.
repeat_interleave
(
value
,
num_repeat
,
dim
=
2
)
return
query
,
key
,
value
return
query
,
key
,
value
def
_forward_sdpa
(
def
_forward_sdpa
(
...
@@ -119,9 +114,7 @@ class MMEncoderAttention(CustomOp):
...
@@ -119,9 +114,7 @@ class MMEncoderAttention(CustomOp):
kv_len
=
key
.
size
(
1
)
kv_len
=
key
.
size
(
1
)
is_reshaped
=
query
.
dim
()
!=
4
is_reshaped
=
query
.
dim
()
!=
4
query
,
key
,
value
=
self
.
maybe_reshape_qkv_to_4d
(
query
,
key
,
value
=
self
.
view_qkv_to_4d
(
query
,
key
,
value
,
bsz
,
q_len
,
kv_len
)
query
,
key
,
value
,
bsz
,
q_len
,
kv_len
)
output
=
vit_torch_sdpa_wrapper
(
output
=
vit_torch_sdpa_wrapper
(
q
=
query
,
q
=
query
,
...
@@ -129,6 +122,7 @@ class MMEncoderAttention(CustomOp):
...
@@ -129,6 +122,7 @@ class MMEncoderAttention(CustomOp):
v
=
value
,
v
=
value
,
scale
=
self
.
scale
,
scale
=
self
.
scale
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
enable_gqa
=
self
.
num_heads
>
self
.
num_kv_heads
,
)
)
if
is_reshaped
:
if
is_reshaped
:
output
=
output
.
reshape
(
bsz
,
q_len
,
-
1
)
output
=
output
.
reshape
(
bsz
,
q_len
,
-
1
)
...
@@ -154,9 +148,7 @@ class MMEncoderAttention(CustomOp):
...
@@ -154,9 +148,7 @@ class MMEncoderAttention(CustomOp):
kv_len
=
key
.
size
(
1
)
kv_len
=
key
.
size
(
1
)
is_reshaped
=
query
.
dim
()
!=
4
is_reshaped
=
query
.
dim
()
!=
4
query
,
key
,
value
=
self
.
maybe_reshape_qkv_to_4d
(
query
,
key
,
value
=
self
.
view_qkv_to_4d
(
query
,
key
,
value
,
bsz
,
q_len
,
kv_len
)
query
,
key
,
value
,
bsz
,
q_len
,
kv_len
)
output
=
vit_flash_attn_wrapper
(
output
=
vit_flash_attn_wrapper
(
q
=
query
,
q
=
query
,
...
...
vllm/model_executor/models/molmo2.py
View file @
6e98f6d8
...
@@ -628,18 +628,6 @@ class ImagePoolingAttention(nn.Module):
...
@@ -628,18 +628,6 @@ class ImagePoolingAttention(nn.Module):
key
=
key
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_dim
)
key
=
key
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_dim
)
value
=
value
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_dim
)
value
=
value
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_dim
)
if
self
.
num_heads
!=
self
.
num_kv_heads
:
key
=
torch
.
repeat_interleave
(
key
,
self
.
num_heads
//
self
.
num_kv_heads
,
dim
=
2
,
)
value
=
torch
.
repeat_interleave
(
value
,
self
.
num_heads
//
self
.
num_kv_heads
,
dim
=
2
,
)
query
,
key
,
value
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query
,
key
,
value
))
query
,
key
,
value
=
(
x
.
transpose
(
1
,
2
)
for
x
in
(
query
,
key
,
value
))
out
=
F
.
scaled_dot_product_attention
(
out
=
F
.
scaled_dot_product_attention
(
...
@@ -648,6 +636,7 @@ class ImagePoolingAttention(nn.Module):
...
@@ -648,6 +636,7 @@ class ImagePoolingAttention(nn.Module):
value
,
value
,
attn_mask
=
attn_mask
,
attn_mask
=
attn_mask
,
is_causal
=
False
,
is_causal
=
False
,
enable_gqa
=
self
.
num_heads
>
self
.
num_kv_heads
,
).
transpose
(
1
,
2
)
).
transpose
(
1
,
2
)
return
out
.
reshape
(
bsz
,
q_len
,
-
1
)
return
out
.
reshape
(
bsz
,
q_len
,
-
1
)
...
...
vllm/v1/attention/backends/cpu_attn.py
View file @
6e98f6d8
...
@@ -398,10 +398,6 @@ class CPUAttentionBackendImpl(AttentionImpl):
...
@@ -398,10 +398,6 @@ class CPUAttentionBackendImpl(AttentionImpl):
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=-
3
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=-
3
)
causal_attn
=
attn_type
==
AttentionType
.
DECODER
causal_attn
=
attn_type
==
AttentionType
.
DECODER
sdpa_start_loc
=
attn_metadata
.
sdpa_start_loc
.
numpy
()
# type: ignore
sdpa_start_loc
=
attn_metadata
.
sdpa_start_loc
.
numpy
()
# type: ignore
...
@@ -418,6 +414,7 @@ class CPUAttentionBackendImpl(AttentionImpl):
...
@@ -418,6 +414,7 @@ class CPUAttentionBackendImpl(AttentionImpl):
dropout_p
=
0.0
,
dropout_p
=
0.0
,
is_causal
=
causal_attn
and
mask
is
None
,
is_causal
=
causal_attn
and
mask
is
None
,
scale
=
self
.
scale
,
scale
=
self
.
scale
,
enable_gqa
=
self
.
num_heads
>
self
.
num_kv_heads
,
)
)
.
squeeze
(
0
)
.
squeeze
(
0
)
.
movedim
(
query
.
dim
()
-
2
,
0
)
.
movedim
(
query
.
dim
()
-
2
,
0
)
...
...
vllm/v1/attention/ops/vit_attn_wrappers.py
View file @
6e98f6d8
...
@@ -115,13 +115,16 @@ def apply_sdpa(
...
@@ -115,13 +115,16 @@ def apply_sdpa(
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
float
|
None
=
None
,
scale
:
float
|
None
=
None
,
enable_gqa
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Input shape:
Input shape:
(batch_size x seq_len x num_heads x head_size)
(batch_size x seq_len x num_heads x head_size)
"""
"""
q
,
k
,
v
=
(
einops
.
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
])
q
,
k
,
v
=
(
einops
.
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
])
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
dropout_p
=
0.0
,
scale
=
scale
)
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
dropout_p
=
0.0
,
scale
=
scale
,
enable_gqa
=
enable_gqa
)
output
=
einops
.
rearrange
(
output
,
"b h s d -> b s h d "
)
output
=
einops
.
rearrange
(
output
,
"b h s d -> b s h d "
)
return
output
return
output
...
@@ -134,6 +137,7 @@ def torch_sdpa_wrapper(
...
@@ -134,6 +137,7 @@ def torch_sdpa_wrapper(
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
float
|
None
=
None
,
scale
:
float
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
enable_gqa
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Never remove the contiguous logic for ROCm
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
# Without it, hallucinations occur with the backend
...
@@ -143,7 +147,7 @@ def torch_sdpa_wrapper(
...
@@ -143,7 +147,7 @@ def torch_sdpa_wrapper(
v
=
v
.
contiguous
()
v
=
v
.
contiguous
()
if
cu_seqlens
is
None
:
if
cu_seqlens
is
None
:
return
apply_sdpa
(
q
,
k
,
v
,
scale
=
scale
)
return
apply_sdpa
(
q
,
k
,
v
,
scale
=
scale
,
enable_gqa
=
enable_gqa
)
outputs
=
[]
outputs
=
[]
...
@@ -152,7 +156,7 @@ def torch_sdpa_wrapper(
...
@@ -152,7 +156,7 @@ def torch_sdpa_wrapper(
k_chunks
=
torch
.
split
(
k
,
lens
,
dim
=
1
)
k_chunks
=
torch
.
split
(
k
,
lens
,
dim
=
1
)
v_chunks
=
torch
.
split
(
v
,
lens
,
dim
=
1
)
v_chunks
=
torch
.
split
(
v
,
lens
,
dim
=
1
)
for
q_i
,
k_i
,
v_i
in
zip
(
q_chunks
,
k_chunks
,
v_chunks
):
for
q_i
,
k_i
,
v_i
in
zip
(
q_chunks
,
k_chunks
,
v_chunks
):
output_i
=
apply_sdpa
(
q_i
,
k_i
,
v_i
,
scale
=
scale
)
output_i
=
apply_sdpa
(
q_i
,
k_i
,
v_i
,
scale
=
scale
,
enable_gqa
=
enable_gqa
)
outputs
.
append
(
output_i
)
outputs
.
append
(
output_i
)
context_layer
=
torch
.
cat
(
outputs
,
dim
=
1
)
context_layer
=
torch
.
cat
(
outputs
,
dim
=
1
)
return
context_layer
return
context_layer
...
@@ -164,6 +168,7 @@ def torch_sdpa_wrapper_fake(
...
@@ -164,6 +168,7 @@ def torch_sdpa_wrapper_fake(
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
float
|
None
,
scale
:
float
|
None
,
cu_seqlens
:
torch
.
Tensor
|
None
,
cu_seqlens
:
torch
.
Tensor
|
None
,
enable_gqa
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
)
return
torch
.
empty_like
(
q
)
...
@@ -181,5 +186,8 @@ def vit_torch_sdpa_wrapper(
...
@@ -181,5 +186,8 @@ def vit_torch_sdpa_wrapper(
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
float
|
None
=
None
,
scale
:
float
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
enable_gqa
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
torch_sdpa_wrapper
(
q
,
k
,
v
,
scale
,
cu_seqlens
)
return
torch
.
ops
.
vllm
.
torch_sdpa_wrapper
(
q
,
k
,
v
,
scale
,
cu_seqlens
,
enable_gqa
=
enable_gqa
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment