Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2972a054
Unverified
Commit
2972a054
authored
Jan 08, 2026
by
Isotr0py
Committed by
GitHub
Jan 08, 2026
Browse files
[MM Encoder]: Make MMEncoderAttention's `scale` takes effect properly (#31950)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
5576227b
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
32 additions
and
8 deletions
+32
-8
vllm/attention/layers/mm_encoder_attention.py
vllm/attention/layers/mm_encoder_attention.py
+2
-0
vllm/attention/ops/vit_attn_wrappers.py
vllm/attention/ops/vit_attn_wrappers.py
+21
-8
vllm/model_executor/models/dots_ocr.py
vllm/model_executor/models/dots_ocr.py
+1
-0
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/ernie45_vl.py
+1
-0
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+1
-0
vllm/model_executor/models/glmasr.py
vllm/model_executor/models/glmasr.py
+1
-0
vllm/model_executor/models/isaac.py
vllm/model_executor/models/isaac.py
+1
-0
vllm/model_executor/models/moonvit.py
vllm/model_executor/models/moonvit.py
+1
-0
vllm/model_executor/models/paddleocr_vl.py
vllm/model_executor/models/paddleocr_vl.py
+1
-0
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+1
-0
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+1
-0
No files found.
vllm/attention/layers/mm_encoder_attention.py
View file @
2972a054
...
@@ -133,6 +133,7 @@ class MMEncoderAttention(CustomOp):
...
@@ -133,6 +133,7 @@ class MMEncoderAttention(CustomOp):
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
scale
=
self
.
scale
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
)
)
if
is_reshaped
:
if
is_reshaped
:
...
@@ -167,6 +168,7 @@ class MMEncoderAttention(CustomOp):
...
@@ -167,6 +168,7 @@ class MMEncoderAttention(CustomOp):
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
scale
=
self
.
scale
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
batch_size
=
bsz
,
batch_size
=
bsz
,
...
...
vllm/attention/ops/vit_attn_wrappers.py
View file @
2972a054
...
@@ -27,6 +27,7 @@ def flash_attn_maxseqlen_wrapper(
...
@@ -27,6 +27,7 @@ def flash_attn_maxseqlen_wrapper(
batch_size
:
int
,
batch_size
:
int
,
is_rocm_aiter
:
bool
,
is_rocm_aiter
:
bool
,
fa_version
:
int
|
None
,
fa_version
:
int
|
None
,
scale
:
float
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -57,6 +58,7 @@ def flash_attn_maxseqlen_wrapper(
...
@@ -57,6 +58,7 @@ def flash_attn_maxseqlen_wrapper(
max_seqlen_k
=
max_seqlen
,
max_seqlen_k
=
max_seqlen
,
dropout_p
=
0.0
,
dropout_p
=
0.0
,
causal
=
False
,
causal
=
False
,
softmax_scale
=
scale
,
**
kwargs
,
**
kwargs
,
)
)
context_layer
=
einops
.
rearrange
(
output
,
"(b s) h d -> b s h d"
,
b
=
batch_size
)
context_layer
=
einops
.
rearrange
(
output
,
"(b s) h d -> b s h d"
,
b
=
batch_size
)
...
@@ -67,11 +69,12 @@ def flash_attn_maxseqlen_wrapper_fake(
...
@@ -67,11 +69,12 @@ def flash_attn_maxseqlen_wrapper_fake(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
batch_size
:
int
,
batch_size
:
int
,
is_rocm_aiter
:
bool
,
is_rocm_aiter
:
bool
,
fa_version
:
int
|
None
,
fa_version
:
int
|
None
,
scale
:
float
|
None
,
cu_seqlens
:
torch
.
Tensor
|
None
,
max_seqlen
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
)
return
torch
.
empty_like
(
q
)
...
@@ -90,6 +93,7 @@ def vit_flash_attn_wrapper(
...
@@ -90,6 +93,7 @@ def vit_flash_attn_wrapper(
batch_size
:
int
,
batch_size
:
int
,
is_rocm_aiter
:
bool
,
is_rocm_aiter
:
bool
,
fa_version
:
int
|
None
,
fa_version
:
int
|
None
,
scale
:
float
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
max_seqlen
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -100,18 +104,24 @@ def vit_flash_attn_wrapper(
...
@@ -100,18 +104,24 @@ def vit_flash_attn_wrapper(
batch_size
,
batch_size
,
is_rocm_aiter
,
is_rocm_aiter
,
fa_version
,
fa_version
,
scale
,
cu_seqlens
,
cu_seqlens
,
max_seqlen
,
max_seqlen
,
)
)
def
apply_sdpa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply_sdpa
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
float
|
None
=
None
,
)
->
torch
.
Tensor
:
"""
"""
Input shape:
Input shape:
(batch_size x seq_len x num_heads x head_size)
(batch_size x seq_len x num_heads x head_size)
"""
"""
q
,
k
,
v
=
(
einops
.
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
])
q
,
k
,
v
=
(
einops
.
rearrange
(
x
,
"b s h d -> b h s d"
)
for
x
in
[
q
,
k
,
v
])
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
dropout_p
=
0.0
)
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
dropout_p
=
0.0
,
scale
=
scale
)
output
=
einops
.
rearrange
(
output
,
"b h s d -> b s h d "
)
output
=
einops
.
rearrange
(
output
,
"b h s d -> b s h d "
)
return
output
return
output
...
@@ -122,6 +132,7 @@ def torch_sdpa_wrapper(
...
@@ -122,6 +132,7 @@ def torch_sdpa_wrapper(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
float
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Never remove the contiguous logic for ROCm
# Never remove the contiguous logic for ROCm
...
@@ -132,7 +143,7 @@ def torch_sdpa_wrapper(
...
@@ -132,7 +143,7 @@ def torch_sdpa_wrapper(
v
=
v
.
contiguous
()
v
=
v
.
contiguous
()
if
cu_seqlens
is
None
:
if
cu_seqlens
is
None
:
return
apply_sdpa
(
q
,
k
,
v
)
return
apply_sdpa
(
q
,
k
,
v
,
scale
=
scale
)
outputs
=
[]
outputs
=
[]
...
@@ -141,7 +152,7 @@ def torch_sdpa_wrapper(
...
@@ -141,7 +152,7 @@ def torch_sdpa_wrapper(
k_chunks
=
torch
.
split
(
k
,
lens
,
dim
=
1
)
k_chunks
=
torch
.
split
(
k
,
lens
,
dim
=
1
)
v_chunks
=
torch
.
split
(
v
,
lens
,
dim
=
1
)
v_chunks
=
torch
.
split
(
v
,
lens
,
dim
=
1
)
for
q_i
,
k_i
,
v_i
in
zip
(
q_chunks
,
k_chunks
,
v_chunks
):
for
q_i
,
k_i
,
v_i
in
zip
(
q_chunks
,
k_chunks
,
v_chunks
):
output_i
=
apply_sdpa
(
q_i
,
k_i
,
v_i
)
output_i
=
apply_sdpa
(
q_i
,
k_i
,
v_i
,
scale
=
scale
)
outputs
.
append
(
output_i
)
outputs
.
append
(
output_i
)
context_layer
=
torch
.
cat
(
outputs
,
dim
=
1
)
context_layer
=
torch
.
cat
(
outputs
,
dim
=
1
)
return
context_layer
return
context_layer
...
@@ -151,7 +162,8 @@ def torch_sdpa_wrapper_fake(
...
@@ -151,7 +162,8 @@ def torch_sdpa_wrapper_fake(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
scale
:
float
|
None
,
cu_seqlens
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
)
return
torch
.
empty_like
(
q
)
...
@@ -167,6 +179,7 @@ def vit_torch_sdpa_wrapper(
...
@@ -167,6 +179,7 @@ def vit_torch_sdpa_wrapper(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scale
:
float
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
cu_seqlens
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
torch_sdpa_wrapper
(
q
,
k
,
v
,
cu_seqlens
)
return
torch
.
ops
.
vllm
.
torch_sdpa_wrapper
(
q
,
k
,
v
,
scale
,
cu_seqlens
)
vllm/model_executor/models/dots_ocr.py
View file @
2972a054
...
@@ -271,6 +271,7 @@ class DotsVisionAttention(nn.Module):
...
@@ -271,6 +271,7 @@ class DotsVisionAttention(nn.Module):
self
.
attn
=
MMEncoderAttention
(
self
.
attn
=
MMEncoderAttention
(
num_heads
=
self
.
num_attention_heads_per_partition
,
num_heads
=
self
.
num_attention_heads_per_partition
,
head_size
=
self
.
hidden_size_per_attention_head
,
head_size
=
self
.
hidden_size_per_attention_head
,
scale
=
self
.
hidden_size_per_attention_head
**-
0.5
,
multimodal_config
=
multimodal_config
,
multimodal_config
=
multimodal_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
)
...
...
vllm/model_executor/models/ernie45_vl.py
View file @
2972a054
...
@@ -152,6 +152,7 @@ class Ernie4_5_VisionAttention(nn.Module):
...
@@ -152,6 +152,7 @@ class Ernie4_5_VisionAttention(nn.Module):
self
.
attn
=
MMEncoderAttention
(
self
.
attn
=
MMEncoderAttention
(
num_heads
=
self
.
num_attention_heads_per_partition
,
num_heads
=
self
.
num_attention_heads_per_partition
,
head_size
=
self
.
hidden_size_per_attention_head
,
head_size
=
self
.
hidden_size_per_attention_head
,
scale
=
self
.
hidden_size_per_attention_head
**-
0.5
,
multimodal_config
=
multimodal_config
,
multimodal_config
=
multimodal_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
)
...
...
vllm/model_executor/models/glm4_1v.py
View file @
2972a054
...
@@ -304,6 +304,7 @@ class Glm4vVisionAttention(nn.Module):
...
@@ -304,6 +304,7 @@ class Glm4vVisionAttention(nn.Module):
self
.
attn
=
MMEncoderAttention
(
self
.
attn
=
MMEncoderAttention
(
num_heads
=
self
.
num_attention_heads_per_partition
,
num_heads
=
self
.
num_attention_heads_per_partition
,
head_size
=
self
.
hidden_size_per_attention_head
,
head_size
=
self
.
hidden_size_per_attention_head
,
scale
=
self
.
hidden_size_per_attention_head
**-
0.5
,
multimodal_config
=
multimodal_config
,
multimodal_config
=
multimodal_config
,
)
)
...
...
vllm/model_executor/models/glmasr.py
View file @
2972a054
...
@@ -188,6 +188,7 @@ class GlmAsrEncoderAttention(nn.Module):
...
@@ -188,6 +188,7 @@ class GlmAsrEncoderAttention(nn.Module):
self
.
attn
=
MMEncoderAttention
(
self
.
attn
=
MMEncoderAttention
(
num_heads
=
self
.
num_heads_per_rank
,
num_heads
=
self
.
num_heads_per_rank
,
head_size
=
self
.
head_dim
,
head_size
=
self
.
head_dim
,
scale
=
self
.
head_dim
**-
0.5
,
num_kv_heads
=
self
.
num_kv_heads_per_rank
,
num_kv_heads
=
self
.
num_kv_heads_per_rank
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
)
...
...
vllm/model_executor/models/isaac.py
View file @
2972a054
...
@@ -984,6 +984,7 @@ class Siglip2VisionAttention(nn.Module):
...
@@ -984,6 +984,7 @@ class Siglip2VisionAttention(nn.Module):
self
.
attn
=
MMEncoderAttention
(
self
.
attn
=
MMEncoderAttention
(
num_heads
=
self
.
num_attention_heads_per_partition
,
num_heads
=
self
.
num_attention_heads_per_partition
,
head_size
=
self
.
hidden_size_per_attention_head
,
head_size
=
self
.
hidden_size_per_attention_head
,
scale
=
self
.
hidden_size_per_attention_head
**-
0.5
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
multimodal_config
=
multimodal_config
,
multimodal_config
=
multimodal_config
,
)
)
...
...
vllm/model_executor/models/moonvit.py
View file @
2972a054
...
@@ -390,6 +390,7 @@ class MoonVitEncoderLayer(nn.Module):
...
@@ -390,6 +390,7 @@ class MoonVitEncoderLayer(nn.Module):
self
.
attn
=
MMEncoderAttention
(
self
.
attn
=
MMEncoderAttention
(
num_heads
=
self
.
num_attention_heads_per_partition
,
num_heads
=
self
.
num_attention_heads_per_partition
,
head_size
=
self
.
hidden_size_per_attention_head
,
head_size
=
self
.
hidden_size_per_attention_head
,
scale
=
self
.
hidden_size_per_attention_head
**-
0.5
,
multimodal_config
=
multimodal_config
,
multimodal_config
=
multimodal_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
)
...
...
vllm/model_executor/models/paddleocr_vl.py
View file @
2972a054
...
@@ -564,6 +564,7 @@ class SiglipAttention(nn.Module):
...
@@ -564,6 +564,7 @@ class SiglipAttention(nn.Module):
self
.
attn
=
MMEncoderAttention
(
self
.
attn
=
MMEncoderAttention
(
num_heads
=
self
.
num_attention_heads_per_partition
,
num_heads
=
self
.
num_attention_heads_per_partition
,
head_size
=
self
.
hidden_size_per_attention_head
,
head_size
=
self
.
hidden_size_per_attention_head
,
scale
=
self
.
hidden_size_per_attention_head
**-
0.5
,
multimodal_config
=
multimodal_config
,
multimodal_config
=
multimodal_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
)
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
2972a054
...
@@ -352,6 +352,7 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -352,6 +352,7 @@ class Qwen2_5_VisionAttention(nn.Module):
self
.
attn
=
MMEncoderAttention
(
self
.
attn
=
MMEncoderAttention
(
num_heads
=
self
.
num_attention_heads_per_partition
,
num_heads
=
self
.
num_attention_heads_per_partition
,
head_size
=
self
.
hidden_size_per_attention_head
,
head_size
=
self
.
hidden_size_per_attention_head
,
scale
=
self
.
hidden_size_per_attention_head
**-
0.5
,
multimodal_config
=
multimodal_config
,
multimodal_config
=
multimodal_config
,
)
)
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
2972a054
...
@@ -327,6 +327,7 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -327,6 +327,7 @@ class Qwen2VisionAttention(nn.Module):
self
.
attn
=
MMEncoderAttention
(
self
.
attn
=
MMEncoderAttention
(
num_heads
=
self
.
num_attention_heads_per_partition
,
num_heads
=
self
.
num_attention_heads_per_partition
,
head_size
=
self
.
hidden_size_per_attention_head
,
head_size
=
self
.
hidden_size_per_attention_head
,
scale
=
self
.
hidden_size_per_attention_head
**-
0.5
,
multimodal_config
=
multimodal_config
,
multimodal_config
=
multimodal_config
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment