Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f44afef6
Unverified
Commit
f44afef6
authored
Apr 10, 2026
by
Richard Zou
Committed by
GitHub
Apr 10, 2026
Browse files
[compile] Allow strings in custom ops without regressing compilation times (#38123)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
447ce222
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
565 additions
and
168 deletions
+565
-168
tests/compile/passes/test_rope_kvcache_fusion.py
tests/compile/passes/test_rope_kvcache_fusion.py
+2
-1
vllm/compilation/passes/fusion/attn_quant_fusion.py
vllm/compilation/passes/fusion/attn_quant_fusion.py
+182
-50
vllm/compilation/passes/fusion/mla_attn_quant_fusion.py
vllm/compilation/passes/fusion/mla_attn_quant_fusion.py
+166
-36
vllm/compilation/passes/fusion/rope_kvcache_fusion.py
vllm/compilation/passes/fusion/rope_kvcache_fusion.py
+72
-25
vllm/envs.py
vllm/envs.py
+4
-0
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+18
-9
vllm/model_executor/layers/attention/kv_transfer_utils.py
vllm/model_executor/layers/attention/kv_transfer_utils.py
+2
-1
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+18
-7
vllm/model_executor/layers/attention/static_sink_attention.py
.../model_executor/layers/attention/static_sink_attention.py
+12
-4
vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py
...model_executor/layers/fused_moe/runner/moe_runner_base.py
+11
-11
vllm/model_executor/layers/mamba/gdn_linear_attn.py
vllm/model_executor/layers/mamba/gdn_linear_attn.py
+10
-4
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+10
-4
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+10
-4
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+11
-5
vllm/utils/torch_utils.py
vllm/utils/torch_utils.py
+31
-5
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+6
-2
No files found.
tests/compile/passes/test_rope_kvcache_fusion.py
View file @
f44afef6
...
...
@@ -28,6 +28,7 @@ from vllm.forward_context import get_forward_context, set_forward_context
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
_encode_layer_name
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
CommonAttentionMetadata
,
...
...
@@ -170,7 +171,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
k
=
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
v
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
kv_cache_dummy_dep
=
torch
.
ops
.
vllm
.
unified_kv_cache_update
(
k
,
v
,
self
.
layer_name
k
,
v
,
_encode_layer_name
(
self
.
layer_name
)
)
return
q
,
k
,
v
,
kv_cache_dummy_dep
...
...
vllm/compilation/passes/fusion/attn_quant_fusion.py
View file @
f44afef6
...
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.torch_utils
import
_USE_LAYERNAME
,
_encode_layer_name
from
..vllm_inductor_pass
import
VllmFusionPatternMatcherPass
,
VllmPatternReplacement
from
.matcher_utils
import
MatcherQuantFP8
...
...
@@ -53,21 +54,43 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@
property
def
pattern
(
self
)
->
Callable
[...,
torch
.
Tensor
]:
def
_pattern
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# When _USE_LAYERNAME is enabled (torch >= 2.11), layer_name is
# passed as an explicit pattern input so the pattern matcher
# treats it as a wildcard matching hoisted LayerName placeholders.
# Otherwise it stays as a closure constant (original behavior).
_ln
=
_encode_layer_name
(
self
.
_layer_name
)
if
_USE_LAYERNAME
:
def
_pattern_with_ln
(
# type: ignore[misc]
q
,
k
,
v
,
output_attn
,
scale
,
kv_cache_dummy_dep
,
layer_name
):
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
layer_name
,
output_scale
=
None
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
attn_out_view
=
RESHAPE_OP
(
at1
[
1
],
[
q
.
shape
[
0
],
self
.
_num_heads
*
self
.
_head_size
]
)
return
self
.
_quant_matcher
(
attn_out_view
,
scale
)[
0
]
return
_pattern_with_ln
def
_pattern
(
q
,
k
,
v
,
output_attn
,
scale
,
kv_cache_dummy_dep
):
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
self
.
_layer_name
,
layer_name
=
_ln
,
output_scale
=
None
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
...
...
@@ -81,14 +104,34 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@
property
def
replacement
(
self
)
->
Callable
[...,
torch
.
Tensor
]:
def
_replacement
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
_ln
=
_encode_layer_name
(
self
.
_layer_name
)
if
_USE_LAYERNAME
:
def
_replacement_with_ln
(
# type: ignore[misc]
q
,
k
,
v
,
output_attn
,
scale
,
kv_cache_dummy_dep
,
layer_name
):
output_attn
=
torch
.
empty
(
[
q
.
shape
[
0
],
self
.
_num_heads
,
self
.
_head_size
],
dtype
=
FP8_DTYPE
,
device
=
q
.
device
,
)
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
layer_name
,
output_scale
=
scale
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
return
RESHAPE_OP
(
at1
[
1
],
[
-
1
,
self
.
_num_heads
*
self
.
_head_size
])
return
_replacement_with_ln
def
_replacement
(
q
,
k
,
v
,
output_attn
,
scale
,
kv_cache_dummy_dep
):
output_attn
=
torch
.
empty
(
[
q
.
shape
[
0
],
self
.
_num_heads
,
self
.
_head_size
],
dtype
=
FP8_DTYPE
,
...
...
@@ -100,7 +143,7 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
self
.
_layer_name
,
layer_name
=
_ln
,
output_scale
=
scale
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
...
...
@@ -113,7 +156,7 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
dtype
=
self
.
_dtype
num_heads
=
self
.
_num_heads
head_size
=
self
.
_head_size
return
[
inputs
:
list
=
[
self
.
empty
(
5
,
num_heads
,
head_size
,
dtype
=
dtype
),
# q
self
.
empty
(
5
,
num_heads
,
head_size
,
dtype
=
dtype
),
# k
self
.
empty
(
5
,
num_heads
,
head_size
,
dtype
=
dtype
),
# v
...
...
@@ -121,6 +164,9 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
self
.
empty_fp32
(
1
,
1
),
# scale
self
.
empty
(
0
,
dtype
=
dtype
),
# kv_cache_dummy_dep
]
if
_USE_LAYERNAME
:
inputs
.
append
(
_encode_layer_name
(
self
.
_layer_name
))
return
inputs
class
AttnNvfp4QuantPattern
(
...
...
@@ -144,23 +190,64 @@ class AttnNvfp4QuantPattern(
@
property
def
pattern
(
self
)
->
Callable
[...,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
_ln
=
_encode_layer_name
(
self
.
_layer_name
)
if
_USE_LAYERNAME
:
def
_pattern_with_ln
(
# type: ignore[misc]
q
,
k
,
v
,
output_attn
,
output_quant
,
output_scale
,
input_scale
,
kv_cache_dummy_dep
,
layer_name
,
):
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
layer_name
,
output_scale
=
None
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
attn_out_view
=
RESHAPE_OP
(
at1
[
1
],
[
q
.
shape
[
0
],
self
.
_num_heads
*
self
.
_head_size
]
)
at2
=
auto_functionalized
(
self
.
_QUANT_OP
,
input
=
attn_out_view
,
input_scale
=
input_scale
,
is_sf_swizzled_layout
=
True
,
output
=
output_quant
,
output_scale
=
output_scale
,
)
return
at2
[
1
],
torch
.
ops
.
aten
.
view
.
dtype
(
at2
[
2
],
FP8_DTYPE
)
return
_pattern_with_ln
def
_pattern
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
q
,
k
,
v
,
output_attn
,
output_quant
,
output_scale
,
input_scale
,
kv_cache_dummy_dep
,
):
at1
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
self
.
_layer_name
,
layer_name
=
_ln
,
output_scale
=
None
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
...
...
@@ -176,42 +263,80 @@ class AttnNvfp4QuantPattern(
output
=
output_quant
,
output_scale
=
output_scale
,
)
output_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
at2
[
2
],
FP8_DTYPE
)
return
at2
[
1
],
output_scale_view
return
at2
[
1
],
torch
.
ops
.
aten
.
view
.
dtype
(
at2
[
2
],
FP8_DTYPE
)
return
_pattern
@
property
def
replacement
(
self
)
->
Callable
[...,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
_ln
=
_encode_layer_name
(
self
.
_layer_name
)
if
_USE_LAYERNAME
:
def
_replacement_with_ln
(
# type: ignore[misc]
q
,
k
,
v
,
output_attn
,
_output_quant
,
output_scale
,
input_scale
,
kv_cache_dummy_dep
,
layer_name
,
):
output_attn
=
torch
.
empty
(
[
q
.
shape
[
0
],
self
.
_num_heads
,
self
.
_head_size
//
2
],
dtype
=
FP4_DTYPE
,
device
=
q
.
device
,
)
osv
=
torch
.
ops
.
aten
.
view
.
dtype
(
output_scale
,
FP8_DTYPE
)
at2
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
layer_name
,
output_scale
=
input_scale
,
output_block_scale
=
osv
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
return
RESHAPE_OP
(
at2
[
1
],
[
-
1
,
self
.
_num_heads
*
self
.
_head_size
//
2
]
),
at2
[
2
]
return
_replacement_with_ln
def
_replacement
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
_output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
q
,
k
,
v
,
output_attn
,
_output_quant
,
output_scale
,
input_scale
,
kv_cache_dummy_dep
,
):
output_attn
=
torch
.
empty
(
[
q
.
shape
[
0
],
self
.
_num_heads
,
self
.
_head_size
//
2
],
dtype
=
FP4_DTYPE
,
device
=
q
.
device
,
)
o
utput_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
output_scale
,
FP8_DTYPE
)
o
sv
=
torch
.
ops
.
aten
.
view
.
dtype
(
output_scale
,
FP8_DTYPE
)
at2
=
auto_functionalized
(
ATTN_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
output
=
output_attn
,
layer_name
=
self
.
_layer_name
,
layer_name
=
_ln
,
output_scale
=
input_scale
,
output_block_scale
=
o
utput_scale_view
,
output_block_scale
=
o
sv
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
output
=
RESHAPE_OP
(
at2
[
1
],
[
-
1
,
self
.
_num_heads
*
self
.
_head_size
//
2
])
return
output
,
at2
[
2
]
return
RESHAPE_OP
(
at2
[
1
],
[
-
1
,
self
.
_num_heads
*
self
.
_head_size
//
2
]
),
at2
[
2
]
return
_replacement
...
...
@@ -219,18 +344,19 @@ class AttnNvfp4QuantPattern(
dtype
=
self
.
_dtype
num_heads
=
self
.
_num_heads
head_size
=
self
.
_head_size
return
[
inputs
:
list
=
[
self
.
empty_bf16
(
5
,
num_heads
,
head_size
),
# q
self
.
empty_bf16
(
5
,
num_heads
,
head_size
),
# k
self
.
empty_bf16
(
5
,
num_heads
,
head_size
),
# v
self
.
empty_bf16
(
5
,
num_heads
,
head_size
),
# output_attn
self
.
empty
(
5
,
num_heads
*
head_size
//
2
,
dtype
=
FP4_DTYPE
),
# output_quant
self
.
empty_i32
(
128
,
round_up
(
num_heads
*
head_size
//
16
,
4
)
),
# output_scale
self
.
empty
(
5
,
num_heads
*
head_size
//
2
,
dtype
=
FP4_DTYPE
),
self
.
empty_i32
(
128
,
round_up
(
num_heads
*
head_size
//
16
,
4
)),
self
.
empty_fp32
(
1
,
1
),
# input_scale
self
.
empty
(
0
,
dtype
=
dtype
),
# kv_cache_dummy_dep
]
if
_USE_LAYERNAME
:
inputs
.
append
(
_encode_layer_name
(
self
.
_layer_name
))
return
inputs
class
AttnQuantFusionPass
(
VllmFusionPatternMatcherPass
):
...
...
@@ -259,13 +385,19 @@ class AttnQuantFusionPass(VllmFusionPatternMatcherPass):
"so no fusion patterns were registered."
)
# When _USE_LAYERNAME is enabled, layer_name is a wildcard so all
# layers produce the same pattern — register once then break.
for
layer
in
layers
:
if
layer
.
impl
.
fused_output_quant_supported
(
_FP8_QUANT_KEY
):
self
.
register
(
AttnFp8StaticQuantPattern
(
layer
,
dtype
))
if
_USE_LAYERNAME
:
break
if
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
for
layer
in
layers
:
if
layer
.
impl
.
fused_output_quant_supported
(
kNvfp4Dynamic
):
self
.
register
(
AttnNvfp4QuantPattern
(
layer
,
dtype
))
if
_USE_LAYERNAME
:
break
self
.
dump_patterns
(
config
,
self
.
pm_pass
)
vllm/compilation/passes/fusion/mla_attn_quant_fusion.py
View file @
f44afef6
...
...
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Dynamic
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
_USE_LAYERNAME
,
_encode_layer_name
from
..vllm_inductor_pass
import
VllmFusionPatternMatcherPass
,
VllmPatternReplacement
from
.matcher_utils
import
MatcherQuantFP8
...
...
@@ -49,21 +50,43 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@
property
def
pattern
(
self
)
->
Callable
[...,
torch
.
Tensor
]:
def
_pattern
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
_ln
=
_encode_layer_name
(
self
.
_layer_name
)
if
_USE_LAYERNAME
:
def
_pattern_with_ln
(
# type: ignore[misc]
q
,
kv_c_normed
,
k_pe
,
output_attn
,
scale
,
kv_cache_dummy_dep
,
layer_name
,
):
at1
=
auto_functionalized
(
MLA_ATTN_OP
,
q
=
q
,
kv_c_normed
=
kv_c_normed
,
k_pe
=
k_pe
,
output
=
output_attn
,
layer_name
=
layer_name
,
output_scale
=
None
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
# MLA output is already 2D (T, N*V), no reshape needed
return
self
.
_quant_matcher
(
at1
[
1
],
scale
)[
0
]
return
_pattern_with_ln
def
_pattern
(
q
,
kv_c_normed
,
k_pe
,
output_attn
,
scale
,
kv_cache_dummy_dep
):
at1
=
auto_functionalized
(
MLA_ATTN_OP
,
q
=
q
,
kv_c_normed
=
kv_c_normed
,
k_pe
=
k_pe
,
output
=
output_attn
,
layer_name
=
self
.
_layer_name
,
layer_name
=
_ln
,
output_scale
=
None
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
...
...
@@ -75,14 +98,41 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@
property
def
replacement
(
self
)
->
Callable
[...,
torch
.
Tensor
]:
def
_replacement
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
_ln
=
_encode_layer_name
(
self
.
_layer_name
)
if
_USE_LAYERNAME
:
def
_replacement_with_ln
(
# type: ignore[misc]
q
,
kv_c_normed
,
k_pe
,
output_attn
,
scale
,
kv_cache_dummy_dep
,
layer_name
,
):
# MLA output in quant_dtype
output_attn
=
torch
.
empty
(
[
q
.
shape
[
0
],
self
.
_output_dim
],
dtype
=
FP8_DTYPE
,
device
=
q
.
device
,
)
at1
=
auto_functionalized
(
MLA_ATTN_OP
,
q
=
q
,
kv_c_normed
=
kv_c_normed
,
k_pe
=
k_pe
,
output
=
output_attn
,
layer_name
=
layer_name
,
output_scale
=
scale
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
return
at1
[
1
]
return
_replacement_with_ln
def
_replacement
(
q
,
kv_c_normed
,
k_pe
,
output_attn
,
scale
,
kv_cache_dummy_dep
):
# MLA output in quant_dtype
output_attn
=
torch
.
empty
(
[
q
.
shape
[
0
],
self
.
_output_dim
],
...
...
@@ -95,7 +145,7 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
kv_c_normed
=
kv_c_normed
,
k_pe
=
k_pe
,
output
=
output_attn
,
layer_name
=
self
.
_layer_name
,
layer_name
=
_ln
,
output_scale
=
scale
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
...
...
@@ -105,7 +155,7 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
return
_replacement
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
return
[
inputs
:
list
=
[
self
.
empty
(
5
,
self
.
_num_heads
,
self
.
_qk_head_dim
,
dtype
=
self
.
_dtype
),
self
.
empty
(
5
,
self
.
_kv_lora_rank
,
dtype
=
self
.
_dtype
),
self
.
empty
(
5
,
1
,
self
.
_qk_rope_head_dim
,
dtype
=
self
.
_dtype
),
...
...
@@ -113,6 +163,9 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
self
.
empty_fp32
(
1
,
1
),
self
.
empty
(
0
,
dtype
=
self
.
_dtype
),
]
if
_USE_LAYERNAME
:
inputs
.
append
(
_encode_layer_name
(
self
.
_layer_name
))
return
inputs
class
MLAAttnNvfp4QuantPattern
(
...
...
@@ -141,21 +194,56 @@ class MLAAttnNvfp4QuantPattern(
def
pattern
(
self
,
)
->
Callable
[...,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
_ln
=
_encode_layer_name
(
self
.
_layer_name
)
if
_USE_LAYERNAME
:
def
_pattern_with_ln
(
# type: ignore[misc]
q
,
kv_c_normed
,
k_pe
,
output_attn
,
input_scale
,
kv_cache_dummy_dep
,
layer_name
,
):
at1
=
auto_functionalized
(
MLA_ATTN_OP
,
q
=
q
,
kv_c_normed
=
kv_c_normed
,
k_pe
=
k_pe
,
output
=
output_attn
,
layer_name
=
layer_name
,
output_scale
=
None
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
output_quant
,
output_scale
=
create_fp4_output_tensors
(
at1
[
1
].
shape
[
0
],
at1
[
1
].
shape
[
1
],
at1
[
1
].
device
,
True
)
at2
=
auto_functionalized
(
self
.
_QUANT_OP
,
input
=
at1
[
1
],
input_scale
=
input_scale
,
is_sf_swizzled_layout
=
True
,
output
=
output_quant
,
output_scale
=
output_scale
,
)
output_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
at2
[
2
],
FP8_DTYPE
)
return
at2
[
1
],
output_scale_view
return
_pattern_with_ln
def
_pattern
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
q
,
kv_c_normed
,
k_pe
,
output_attn
,
input_scale
,
kv_cache_dummy_dep
):
at1
=
auto_functionalized
(
MLA_ATTN_OP
,
q
=
q
,
kv_c_normed
=
kv_c_normed
,
k_pe
=
k_pe
,
output
=
output_attn
,
layer_name
=
self
.
_layer_name
,
layer_name
=
_ln
,
output_scale
=
None
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
...
...
@@ -182,14 +270,47 @@ class MLAAttnNvfp4QuantPattern(
def
replacement
(
self
,
)
->
Callable
[...,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
_ln
=
_encode_layer_name
(
self
.
_layer_name
)
if
_USE_LAYERNAME
:
def
_replacement_with_ln
(
# type: ignore[misc]
q
,
kv_c_normed
,
k_pe
,
output_attn
,
input_scale
,
kv_cache_dummy_dep
,
layer_name
,
):
# MLA output in quant_dtype (FP4 packed as uint8)
output_attn
=
torch
.
empty
(
[
q
.
shape
[
0
],
self
.
_output_dim
//
2
],
dtype
=
FP4_DTYPE
,
device
=
q
.
device
,
)
output_scale
=
create_fp4_output_tensors
(
q
.
shape
[
0
],
self
.
_output_dim
,
q
.
device
,
True
)[
1
]
output_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
output_scale
,
FP8_DTYPE
)
at2
=
auto_functionalized
(
MLA_ATTN_OP
,
q
=
q
,
kv_c_normed
=
kv_c_normed
,
k_pe
=
k_pe
,
output
=
output_attn
,
layer_name
=
layer_name
,
output_scale
=
input_scale
,
output_block_scale
=
output_scale_view
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
return
at2
[
1
],
at2
[
2
]
return
_replacement_with_ln
def
_replacement
(
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
q
,
kv_c_normed
,
k_pe
,
output_attn
,
input_scale
,
kv_cache_dummy_dep
):
# MLA output in quant_dtype (FP4 packed as uint8)
output_attn
=
torch
.
empty
(
[
q
.
shape
[
0
],
self
.
_output_dim
//
2
],
...
...
@@ -207,7 +328,7 @@ class MLAAttnNvfp4QuantPattern(
kv_c_normed
=
kv_c_normed
,
k_pe
=
k_pe
,
output
=
output_attn
,
layer_name
=
self
.
_layer_name
,
layer_name
=
_ln
,
output_scale
=
input_scale
,
output_block_scale
=
output_scale_view
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
...
...
@@ -217,7 +338,7 @@ class MLAAttnNvfp4QuantPattern(
return
_replacement
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
return
[
inputs
:
list
=
[
self
.
empty
(
5
,
self
.
_num_heads
,
self
.
_qk_head_dim
,
dtype
=
self
.
_dtype
),
self
.
empty
(
5
,
self
.
_kv_lora_rank
,
dtype
=
self
.
_dtype
),
self
.
empty
(
5
,
1
,
self
.
_qk_rope_head_dim
,
dtype
=
self
.
_dtype
),
...
...
@@ -225,6 +346,9 @@ class MLAAttnNvfp4QuantPattern(
self
.
empty_fp32
(
1
,
1
),
self
.
empty
(
0
,
dtype
=
self
.
_dtype
),
]
if
_USE_LAYERNAME
:
inputs
.
append
(
_encode_layer_name
(
self
.
_layer_name
))
return
inputs
class
MLAAttnQuantFusionPass
(
VllmFusionPatternMatcherPass
):
...
...
@@ -250,13 +374,19 @@ class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass):
"so no fusion patterns were registered."
)
# When _USE_LAYERNAME is enabled, layer_name is a wildcard so all
# layers produce the same pattern — register once then break.
for
layer
in
layers
:
if
layer
.
impl
.
fused_output_quant_supported
(
kFp8StaticTensorSym
):
self
.
register
(
MLAAttnFp8StaticQuantPattern
(
layer
,
dtype
))
if
_USE_LAYERNAME
:
break
if
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
for
layer
in
layers
:
if
layer
.
impl
.
fused_output_quant_supported
(
kNvfp4Dynamic
):
self
.
register
(
MLAAttnNvfp4QuantPattern
(
layer
,
dtype
))
if
_USE_LAYERNAME
:
break
self
.
dump_patterns
(
config
,
self
.
pm_pass
)
vllm/compilation/passes/fusion/rope_kvcache_fusion.py
View file @
f44afef6
...
...
@@ -15,7 +15,13 @@ from vllm.model_executor.layers.attention.attention import (
Attention
,
get_attention_context
,
)
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.utils.torch_utils
import
(
_USE_LAYERNAME
,
LayerNameType
,
_encode_layer_name
,
_resolve_layer_name
,
direct_register_custom_op
,
)
from
..inductor_pass
import
enable_fake_mode
from
..vllm_inductor_pass
import
VllmInductorPass
,
VllmPatternMatcherPass
...
...
@@ -37,7 +43,7 @@ def fused_rope_and_unified_kv_cache_update_impl(
positions
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
layer_name
:
str
=
""
,
layer_name
:
LayerNameType
,
)
->
torch
.
Tensor
:
"""
This impl fetches the KV cache and slot mapping from the forward context,
...
...
@@ -46,6 +52,7 @@ def fused_rope_and_unified_kv_cache_update_impl(
that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
layer_name
=
_resolve_layer_name
(
layer_name
)
_
,
attn_layer
,
kv_cache
,
layer_slot_mapping
=
get_attention_context
(
layer_name
)
if
layer_slot_mapping
is
not
None
:
attn_layer
.
impl
.
do_rope_and_kv_cache_update
(
...
...
@@ -70,7 +77,7 @@ def fused_rope_and_unified_kv_cache_update_fake(
positions
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
layer_name
:
str
=
""
,
layer_name
:
LayerNameType
,
)
->
torch
.
Tensor
:
return
torch
.
empty
(
0
,
device
=
query
.
device
,
dtype
=
query
.
dtype
)
...
...
@@ -120,38 +127,30 @@ class RopeReshapeKVCachePattern:
num_kv_heads
=
self
.
num_kv_heads
,
)
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]
:
def
get_inputs
(
self
)
->
list
:
# Sample inputs to help pattern tracing
T
=
5
L
=
4096
qkv
=
empty_bf16
(
T
,
self
.
q_size
+
self
.
k_size
+
self
.
v_size
)
positions
=
empty_i64
(
T
)
cos_sin_cache
=
empty_bf16
(
L
,
self
.
head_size
)
return
[
qkv
,
positions
,
cos_sin_cache
,
]
inputs
:
list
=
[
qkv
,
positions
,
cos_sin_cache
]
if
_USE_LAYERNAME
:
inputs
.
append
(
_encode_layer_name
(
self
.
layer_name
))
return
inputs
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
qkv
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
def
_mk_pattern_with_layer_name_input
(
self
,
_ln
):
"""Pattern/replacement with layer_name as an explicit input."""
def
pattern
(
qkv
,
positions
,
cos_sin_cache
,
layer_name
):
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
k_size
,
self
.
v_size
],
dim
=-
1
)
q
,
k
=
self
.
rope_matcher
(
positions
,
q
,
k
,
cos_sin_cache
)
q
=
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
k
=
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
v
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size_v
)
dummy
=
torch
.
ops
.
vllm
.
unified_kv_cache_update
(
k
,
v
,
self
.
layer_name
)
return
dummy
,
q
,
k
,
v
def
replacement
(
qkv
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
vllm
.
unified_kv_cache_update
(
k
,
v
,
layer_name
),
q
,
k
,
v
def
replacement
(
qkv
,
positions
,
cos_sin_cache
,
layer_name
):
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
k_size
,
self
.
v_size
],
dim
=-
1
)
q
=
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
k
=
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
@@ -164,10 +163,50 @@ class RopeReshapeKVCachePattern:
positions
=
positions
,
cos_sin_cache
=
cos_sin_cache
,
is_neox
=
self
.
is_neox
,
layer_name
=
self
.
layer_name
,
layer_name
=
layer_name
,
)
return
results
[
0
],
results
[
1
],
results
[
2
],
v
return
pattern
,
replacement
def
_mk_pattern_with_layer_name_closure
(
self
,
_ln
):
"""Pattern/replacement with layer_name as a closure constant."""
def
pattern
(
qkv
,
positions
,
cos_sin_cache
):
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
k_size
,
self
.
v_size
],
dim
=-
1
)
q
,
k
=
self
.
rope_matcher
(
positions
,
q
,
k
,
cos_sin_cache
)
q
=
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
k
=
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
v
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size_v
)
return
torch
.
ops
.
vllm
.
unified_kv_cache_update
(
k
,
v
,
_ln
),
q
,
k
,
v
def
replacement
(
qkv
,
positions
,
cos_sin_cache
):
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
k_size
,
self
.
v_size
],
dim
=-
1
)
q
=
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
k
=
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
v
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size_v
)
results
=
auto_functionalized
(
self
.
FUSED_OP
,
query
=
q
,
key
=
k
,
value
=
v
,
positions
=
positions
,
cos_sin_cache
=
cos_sin_cache
,
is_neox
=
self
.
is_neox
,
layer_name
=
_ln
,
)
return
results
[
0
],
results
[
1
],
results
[
2
],
v
return
pattern
,
replacement
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
_ln
=
_encode_layer_name
(
self
.
layer_name
)
if
_USE_LAYERNAME
:
pattern
,
replacement
=
self
.
_mk_pattern_with_layer_name_input
(
_ln
)
else
:
pattern
,
replacement
=
self
.
_mk_pattern_with_layer_name_closure
(
_ln
)
# NOTE: use view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities
def
fwd_and_view_to_reshape
(
*
args
,
**
kwargs
)
->
fx
.
GraphModule
:
...
...
@@ -176,7 +215,11 @@ class RopeReshapeKVCachePattern:
return
gm
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
fwd_and_view_to_reshape
,
pm_pass
pattern
,
replacement
,
self
.
get_inputs
(),
fwd_and_view_to_reshape
,
pm_pass
,
)
...
...
@@ -205,6 +248,8 @@ class RopeKVCacheFusionPass(VllmPatternMatcherPass):
self
.
max_token_num
=
cc
.
pass_config
.
rope_kvcache_fusion_max_token_num
attn_layers
=
get_layers_from_vllm_config
(
config
,
Attention
)
# When _USE_LAYERNAME is enabled, layer_name is a wildcard so all
# layers produce the same pattern — register once then break.
for
_
,
layer
in
attn_layers
.
items
():
if
layer
.
impl
.
fused_rope_kvcache_supported
():
for
is_neox
in
[
True
,
False
]:
...
...
@@ -212,6 +257,8 @@ class RopeKVCacheFusionPass(VllmPatternMatcherPass):
layer
=
layer
,
is_neox
=
is_neox
,
).
register
(
self
.
patterns
)
if
_USE_LAYERNAME
:
break
self
.
dump_patterns
(
config
,
self
.
patterns
)
...
...
vllm/envs.py
View file @
f44afef6
...
...
@@ -129,6 +129,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
VLLM_LOG_BATCHSIZE_INTERVAL
:
float
=
-
1
VLLM_DISABLE_COMPILE_CACHE
:
bool
=
False
VLLM_USE_LAYERNAME
:
bool
=
True
Q_SCALE_CONSTANT
:
int
=
200
K_SCALE_CONSTANT
:
int
=
200
V_SCALE_CONSTANT
:
int
=
100
...
...
@@ -1090,6 +1091,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
os
.
getenv
(
"VLLM_LOG_BATCHSIZE_INTERVAL"
,
"-1"
)
),
"VLLM_DISABLE_COMPILE_CACHE"
:
disable_compile_cache
,
# If set to "0", disable LayerName opaque type for layer_name
# parameters in custom ops. Defaults to enabled on torch >= 2.11.
"VLLM_USE_LAYERNAME"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_LAYERNAME"
,
"1"
))),
# If set, vllm will run in development mode, which will enable
# some additional endpoints for developing and debugging,
# e.g. `/reset_prefix_cache`
...
...
vllm/model_executor/layers/attention/attention.py
View file @
f44afef6
...
...
@@ -25,6 +25,9 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
(
LayerNameType
,
_encode_layer_name
,
_resolve_layer_name
,
direct_register_custom_op
,
kv_cache_dtype_str_to_dtype
,
)
...
...
@@ -414,7 +417,9 @@ class Attention(nn.Module, AttentionLayerBase):
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if
self
.
calculate_kv_scales
:
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
query
,
key
,
value
,
self
.
layer_name
)
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
query
,
key
,
value
,
_encode_layer_name
(
self
.
layer_name
)
)
output_dtype
=
query
.
dtype
if
self
.
query_quant
is
not
None
:
# quantizing with a simple torch operation enables
...
...
@@ -466,6 +471,7 @@ class Attention(nn.Module, AttentionLayerBase):
)
else
:
# Skip this if sharing KV cache with an earlier attention layer.
encoded
=
_encode_layer_name
(
self
.
layer_name
)
if
(
not
self
.
attn_backend
.
forward_includes_kv_cache_update
and
self
.
kv_sharing_target_layer_name
is
None
...
...
@@ -473,14 +479,14 @@ class Attention(nn.Module, AttentionLayerBase):
and
value
is
not
None
):
kv_cache_dummy_dep
=
torch
.
ops
.
vllm
.
unified_kv_cache_update
(
key
,
value
,
self
.
layer_name
key
,
value
,
encoded
)
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
,
encoded
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
return
output
.
view
(
-
1
,
hidden_size
)
...
...
@@ -553,8 +559,9 @@ def maybe_calc_kv_scales(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
None
:
layer_name
=
_resolve_layer_name
(
layer_name
)
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
...
...
@@ -570,7 +577,7 @@ def maybe_calc_kv_scales_fake(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
None
:
return
...
...
@@ -622,12 +629,13 @@ def get_attention_context(
def
unified_kv_cache_update
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
torch
.
Tensor
:
"""
Returns a dummy that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
layer_name
=
_resolve_layer_name
(
layer_name
)
_
,
attn_layer
,
kv_cache
,
layer_slot_mapping
=
get_attention_context
(
layer_name
)
if
layer_slot_mapping
is
not
None
:
assert
hasattr
(
attn_layer
.
impl
,
"do_kv_cache_update"
),
(
...
...
@@ -647,7 +655,7 @@ def unified_kv_cache_update(
def
unified_kv_cache_update_fake
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
torch
.
Tensor
:
return
torch
.
empty
(
0
,
device
=
key
.
device
,
dtype
=
key
.
dtype
)
...
...
@@ -666,7 +674,7 @@ def unified_attention_with_output(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
kv_cache_dummy_dep
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -675,6 +683,7 @@ def unified_attention_with_output(
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del
kv_cache_dummy_dep
layer_name
=
_resolve_layer_name
(
layer_name
)
attn_metadata
,
self
,
kv_cache
,
_
=
get_attention_context
(
layer_name
)
self
.
impl
.
forward
(
...
...
@@ -695,7 +704,7 @@ def unified_attention_with_output_fake(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
kv_cache_dummy_dep
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/layers/attention/kv_transfer_utils.py
View file @
f44afef6
...
...
@@ -9,6 +9,7 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group
,
is_v1_kv_transfer_group
,
)
from
vllm.utils.torch_utils
import
_resolve_layer_name
def
maybe_transfer_kv_layer
(
func
:
Callable
)
->
Callable
:
...
...
@@ -38,7 +39,7 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable:
if
not
has_kv_transfer_group
()
or
not
is_v1_kv_transfer_group
():
return
func
(
*
args
,
**
kwargs
)
layer_name
:
str
=
args
[
layer_name_index
]
layer_name
=
_resolve_layer_name
(
args
[
layer_name_index
]
)
# Extract attention context (metadata, layer, kv_cache, layer_slot_mapping)
attn_metadata
,
_
,
kv_cache
,
_
=
get_attention_context
(
layer_name
)
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
f44afef6
...
...
@@ -240,6 +240,9 @@ from vllm.platforms import current_platform
from
vllm.utils.flashinfer
import
has_flashinfer
,
has_nvidia_artifactory
from
vllm.utils.math_utils
import
cdiv
,
round_down
from
vllm.utils.torch_utils
import
(
LayerNameType
,
_encode_layer_name
,
_resolve_layer_name
,
direct_register_custom_op
,
is_quantized_kv_cache
,
kv_cache_dtype_str_to_dtype
,
...
...
@@ -473,7 +476,12 @@ class MLAAttention(nn.Module, AttentionLayerBase):
output_shape
:
torch
.
Size
|
None
=
None
,
)
->
torch
.
Tensor
:
if
self
.
calculate_kv_scales
:
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
q
,
kv_c_normed
,
k_pe
,
self
.
layer_name
)
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
q
,
kv_c_normed
,
k_pe
,
_encode_layer_name
(
self
.
layer_name
),
)
if
self
.
use_direct_call
:
forward_context
:
ForwardContext
=
get_forward_context
()
...
...
@@ -505,10 +513,11 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
return
output
else
:
encoded
=
_encode_layer_name
(
self
.
layer_name
)
kv_cache_dummy_dep
=
torch
.
ops
.
vllm
.
unified_mla_kv_cache_update
(
kv_c_normed
,
k_pe
,
self
.
layer_name
,
encoded
,
self
.
kv_cache_dtype
,
self
.
_k_scale
,
)
...
...
@@ -518,7 +527,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
kv_c_normed
,
k_pe
,
output
,
self
.
layer_name
,
encoded
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
return
output
...
...
@@ -900,7 +909,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
def
unified_mla_kv_cache_update
(
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
...
...
@@ -908,6 +917,7 @@ def unified_mla_kv_cache_update(
Returns a dummy that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
layer_name
=
_resolve_layer_name
(
layer_name
)
forward_context
=
get_forward_context
()
attn_layer
=
forward_context
.
no_compile_layers
[
layer_name
]
kv_cache
=
attn_layer
.
kv_cache
...
...
@@ -939,7 +949,7 @@ def unified_mla_kv_cache_update(
def
unified_mla_kv_cache_update_fake
(
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
...
...
@@ -959,7 +969,7 @@ def unified_mla_attention_with_output(
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
kv_cache_dummy_dep
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -968,6 +978,7 @@ def unified_mla_attention_with_output(
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del
kv_cache_dummy_dep
layer_name
=
_resolve_layer_name
(
layer_name
)
attn_metadata
,
layer
,
kv_cache
,
_
=
get_attention_context
(
layer_name
)
layer
.
forward_impl
(
q
,
...
...
@@ -986,7 +997,7 @@ def unified_mla_attention_with_output_fake(
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
kv_cache_dummy_dep
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/layers/attention/static_sink_attention.py
View file @
f44afef6
...
...
@@ -10,7 +10,12 @@ from vllm.logger import init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.utils.torch_utils
import
(
LayerNameType
,
_encode_layer_name
,
_resolve_layer_name
,
direct_register_custom_op
,
)
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionMetadata
,
...
...
@@ -170,7 +175,9 @@ class StaticSinkAttention(Attention, CustomOp):
)
if
not
self
.
sink_populated
:
self_kv_cache
=
self
.
kv_cache
torch
.
ops
.
vllm
.
maybe_populate_sink
(
self_kv_cache
,
self
.
layer_name
)
torch
.
ops
.
vllm
.
maybe_populate_sink
(
self_kv_cache
,
_encode_layer_name
(
self
.
layer_name
)
)
return
super
().
forward
(
query
,
key
,
value
,
output_shape
)
...
...
@@ -224,8 +231,9 @@ class StaticSinkAttention(Attention, CustomOp):
def
maybe_populate_sink
(
self_kv_cache
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
None
:
layer_name
=
_resolve_layer_name
(
layer_name
)
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
if
self
.
sink_populated
or
self_kv_cache
.
numel
()
==
0
:
...
...
@@ -235,7 +243,7 @@ def maybe_populate_sink(
def
maybe_populate_sink_fake
(
self_kv_cache
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
None
:
return
...
...
vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py
View file @
f44afef6
...
...
@@ -32,15 +32,15 @@ from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
)
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
(
HAS_OPAQUE_TYP
E
,
Module
Name
,
_USE_LAYERNAM
E
,
Layer
Name
,
direct_register_custom_op
,
)
def
get_layer_from_name
(
layer_name
:
str
)
->
torch
.
nn
.
Module
:
forward_context
:
ForwardContext
=
get_forward_context
()
if
not
HAS_OPAQUE_TYP
E
and
layer_name
==
"from_forward_context"
:
if
not
_USE_LAYERNAM
E
and
layer_name
==
"from_forward_context"
:
all_moe_layers
=
forward_context
.
all_moe_layers
assert
all_moe_layers
is
not
None
moe_layer_index
=
forward_context
.
moe_layer_index
...
...
@@ -55,21 +55,21 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module:
return
forward_context
.
no_compile_layers
[
layer_name
]
# On torch >= 2.11, layer_name is a hoisted
Module
Name opaque object;
# On torch >= 2.11, layer_name is a hoisted
Layer
Name opaque object;
# on older versions it remains a plain str.
if
TYPE_CHECKING
:
from
typing
import
TypeAlias
_layer_name_type
:
TypeAlias
=
str
|
Module
Name
_layer_name_type
:
TypeAlias
=
str
|
Layer
Name
else
:
_layer_name_type
=
Module
Name
if
HAS_OPAQUE_TYP
E
else
str
_layer_name_type
=
Layer
Name
if
_USE_LAYERNAM
E
else
str
@
torch
.
compiler
.
assume_constant_result
def
_resolve_layer_name
(
layer_name
:
str
|
Module
Name
)
->
str
:
def
_resolve_layer_name
(
layer_name
:
str
|
Layer
Name
)
->
str
:
from
torch._library.fake_class_registry
import
FakeScriptObject
if
isinstance
(
layer_name
,
Module
Name
):
if
isinstance
(
layer_name
,
Layer
Name
):
return
layer_name
.
value
elif
isinstance
(
layer_name
,
FakeScriptObject
):
return
layer_name
.
real_obj
.
value
...
...
@@ -331,9 +331,9 @@ class MoERunnerBase(MoERunner):
assert
len
(
trunc_sizes
)
==
1
return
func
(
states
,
trunc_sizes
[
0
])
def
_encode_layer_name
(
self
)
->
str
|
Module
Name
:
if
HAS_OPAQUE_TYP
E
:
return
Module
Name
(
self
.
layer_name
)
def
_encode_layer_name
(
self
)
->
str
|
Layer
Name
:
if
_USE_LAYERNAM
E
:
return
Layer
Name
(
self
.
layer_name
)
# Can be unavailable or None in unittests
if
(
is_forward_context_available
()
...
...
vllm/model_executor/layers/mamba/gdn_linear_attn.py
View file @
f44afef6
...
...
@@ -56,7 +56,12 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.configs.qwen3_next
import
Qwen3NextConfig
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.utils.torch_utils
import
(
LayerNameType
,
_encode_layer_name
,
_resolve_layer_name
,
direct_register_custom_op
,
)
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadata
...
...
@@ -568,7 +573,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
b
,
a
,
core_attn_out
,
self
.
prefix
,
_encode_layer_name
(
self
.
prefix
)
,
)
# ============================================================
...
...
@@ -1084,13 +1089,14 @@ def gdn_attention_core(
b
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
core_attn_out
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
None
:
"""
Custom op for the core attention computation.
Only handles the convolution + recurrent attention part.
Input/output projections are handled outside this op.
"""
layer_name
=
_resolve_layer_name
(
layer_name
)
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
_forward_core
(
...
...
@@ -1106,7 +1112,7 @@ def gdn_attention_core_fake(
b
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
core_attn_out
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
None
:
"""Fake implementation for torch.compile."""
return
...
...
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
f44afef6
...
...
@@ -36,7 +36,12 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.utils.torch_utils
import
(
LayerNameType
,
_encode_layer_name
,
_resolve_layer_name
,
direct_register_custom_op
,
)
from
vllm.v1.attention.backends.mamba1_attn
import
Mamba1AttentionMetadata
...
...
@@ -228,7 +233,7 @@ class MambaMixer(MambaBase, PluggableLayer):
torch
.
ops
.
vllm
.
mamba_mixer
(
hidden_states
,
output
,
self
.
prefix
,
_encode_layer_name
(
self
.
prefix
)
,
)
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
...
...
@@ -515,8 +520,9 @@ def split_batch_to_prefill_and_decode(
def
mamba_mixer
(
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
None
:
layer_name
=
_resolve_layer_name
(
layer_name
)
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
forward_impl
(
hidden_states
=
hidden_states
,
output
=
output
)
...
...
@@ -525,7 +531,7 @@ def mamba_mixer(
def
mamba_mixer_fake
(
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
None
:
return
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
f44afef6
...
...
@@ -44,7 +44,12 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.parameter
import
BasevLLMParameter
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.utils.torch_utils
import
(
LayerNameType
,
_encode_layer_name
,
_resolve_layer_name
,
direct_register_custom_op
,
)
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionMetadata
...
...
@@ -536,7 +541,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
torch
.
ops
.
vllm
.
mamba_mixer2
(
projected_states
,
ssm_output
,
self
.
prefix
,
_encode_layer_name
(
self
.
prefix
)
,
)
# 4. gated MLP
...
...
@@ -927,8 +932,9 @@ class MambaMixer2(MambaBase, PluggableLayer):
def
mamba_mixer2
(
projected_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
None
:
layer_name
=
_resolve_layer_name
(
layer_name
)
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
conv_ssm_forward
(
projected_states
=
projected_states
,
output
=
output
)
...
...
@@ -937,7 +943,7 @@ def mamba_mixer2(
def
mamba_mixer2_fake
(
projected_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
LayerNameType
,
)
->
None
:
return
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
f44afef6
...
...
@@ -11,7 +11,12 @@ from vllm.logger import init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
fp8_mqa_logits
,
fp8_paged_mqa_logits
,
has_deep_gemm
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.utils.torch_utils
import
(
LayerNameType
,
_encode_layer_name
,
_resolve_layer_name
,
direct_register_custom_op
,
)
from
vllm.v1.attention.backends.mla.indexer
import
(
DeepseekV32IndexerMetadata
,
)
...
...
@@ -30,7 +35,7 @@ RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024
def
sparse_attn_indexer
(
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
k_cache_prefix
:
LayerNameType
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
...
@@ -46,6 +51,7 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run
attn_metadata
=
get_forward_context
().
attn_metadata
fp8_dtype
=
current_platform
.
fp8_dtype
()
k_cache_prefix
=
_resolve_layer_name
(
k_cache_prefix
)
# assert isinstance(attn_metadata, dict)
if
not
isinstance
(
attn_metadata
,
dict
):
...
...
@@ -253,7 +259,7 @@ def sparse_attn_indexer(
def
sparse_attn_indexer_fake
(
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
k_cache_prefix
:
LayerNameType
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
...
@@ -342,7 +348,7 @@ class SparseAttnIndexer(CustomOp):
):
return
torch
.
ops
.
vllm
.
sparse_attn_indexer
(
hidden_states
,
self
.
k_cache
.
prefix
,
_encode_layer_name
(
self
.
k_cache
.
prefix
)
,
self
.
k_cache
.
kv_cache
,
q_fp8
,
k
,
...
...
@@ -366,7 +372,7 @@ class SparseAttnIndexer(CustomOp):
if
rocm_aiter_ops
.
is_enabled
():
return
torch
.
ops
.
vllm
.
rocm_aiter_sparse_attn_indexer
(
hidden_states
,
self
.
k_cache
.
prefix
,
_encode_layer_name
(
self
.
k_cache
.
prefix
)
,
self
.
k_cache
.
kv_cache
,
q_fp8
,
k
,
...
...
vllm/utils/torch_utils.py
View file @
f44afef6
...
...
@@ -15,6 +15,7 @@ from packaging import version
from
packaging.version
import
Version
from
torch.library
import
Library
,
infer_schema
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
...
...
@@ -709,37 +710,62 @@ def is_torch_equal(target: str) -> bool:
HAS_OPAQUE_TYPE
=
is_torch_equal_or_newer
(
"2.11.0.dev"
)
# Allow toggling LayerName usage via environment variable.
# Defaults to True on torch >= 2.11, False otherwise.
# Set VLLM_USE_LAYERNAME=0 to disable even on torch >= 2.11.
_USE_LAYERNAME
=
HAS_OPAQUE_TYPE
and
envs
.
VLLM_USE_LAYERNAME
if
HAS_OPAQUE_TYPE
:
from
torch._opaque_base
import
OpaqueBase
else
:
OpaqueBase
=
object
# type: ignore[misc, assignment]
class
Module
Name
(
OpaqueBase
):
# type: ignore[misc]
class
Layer
Name
(
OpaqueBase
):
# type: ignore[misc]
"""Wraps a module name string for use as a torch opaque type.
When torch >= 2.11, this is registered as a hoisted value-type opaque
object so that torch.compile lifts it as a graph input instead of baking
it as a constant. This avoids per-layer recompilation for MOE ops.
it as a constant. This avoids per-layer recompilation for custom ops
that accept layer name strings (attention, MOE, KV cache, etc.).
"""
def
__init__
(
self
,
value
:
str
):
self
.
value
=
value
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
Module
Name
)
and
self
.
value
==
other
.
value
return
isinstance
(
other
,
Layer
Name
)
and
self
.
value
==
other
.
value
def
__hash__
(
self
):
return
hash
(
self
.
value
)
def
__fx_repr__
(
self
):
return
(
f
"
Module
Name(
{
self
.
value
!
r
}
)"
,
{
Module
Name
})
return
(
f
"
Layer
Name(
{
self
.
value
!
r
}
)"
,
{
"LayerName"
:
Layer
Name
})
if
HAS_OPAQUE_TYPE
:
from
torch._library.opaque_object
import
register_opaque_type
register_opaque_type
(
ModuleName
,
typ
=
"value"
,
hoist
=
True
)
register_opaque_type
(
LayerName
,
typ
=
"value"
,
hoist
=
True
)
# On torch >= 2.11 (with VLLM_USE_LAYERNAME enabled), custom op
# layer_name parameters use LayerName; otherwise they remain plain str.
if
TYPE_CHECKING
:
from
typing
import
TypeAlias
LayerNameType
:
TypeAlias
=
str
|
LayerName
else
:
LayerNameType
=
LayerName
if
_USE_LAYERNAME
else
str
def _resolve_layer_name(layer_name: str | LayerName) -> str:
    """Return the plain string form of a layer name.

    A ``LayerName`` wrapper is unwrapped to its ``value`` attribute; any
    other argument is handed back unchanged.
    """
    if not isinstance(layer_name, LayerName):
        return layer_name
    return layer_name.value
def _encode_layer_name(layer_name: str) -> str | LayerName:
    """Prepare a layer name string for passing to a custom op.

    The string is wrapped in ``LayerName`` when ``_USE_LAYERNAME`` is
    truthy; otherwise it is returned as-is.
    """
    if _USE_LAYERNAME:
        return LayerName(layer_name)
    return layer_name
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
...
...
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
View file @
f44afef6
...
...
@@ -9,6 +9,7 @@ import torch
from
vllm.forward_context
import
get_forward_context
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
LayerNameType
from
vllm.v1.attention.backends.mla.indexer
import
DeepseekV32IndexerMetadata
from
vllm.v1.attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
...
...
@@ -459,7 +460,7 @@ def rocm_fp8_mqa_logits(
def
rocm_aiter_sparse_attn_indexer_fake
(
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
k_cache_prefix
:
LayerNameType
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
...
@@ -486,7 +487,7 @@ def rocm_aiter_sparse_attn_indexer_fake(
def
rocm_aiter_sparse_attn_indexer
(
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
k_cache_prefix
:
LayerNameType
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
...
@@ -502,6 +503,9 @@ def rocm_aiter_sparse_attn_indexer(
# careful! this will be None in dummy run
attn_metadata
=
get_forward_context
().
attn_metadata
fp8_dtype
=
current_platform
.
fp8_dtype
()
from
vllm.utils.torch_utils
import
_resolve_layer_name
k_cache_prefix
=
_resolve_layer_name
(
k_cache_prefix
)
# assert isinstance(attn_metadata, dict)
if
not
isinstance
(
attn_metadata
,
dict
):
return
rocm_aiter_sparse_attn_indexer_fake
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment