Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f1a7696f
Commit
f1a7696f
authored
Mar 06, 2026
by
王敏
Browse files
[perf]添加Module支持split qkv+rmsnorm+rope+kvcache融合算子,GLM4_MOE完成适配
parent
0786df31
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
261 additions
and
21 deletions
+261
-21
vllm/attention/layer.py
vllm/attention/layer.py
+202
-1
vllm/config/compilation.py
vllm/config/compilation.py
+1
-0
vllm/envs.py
vllm/envs.py
+6
-0
vllm/model_executor/models/glm4_moe.py
vllm/model_executor/models/glm4_moe.py
+52
-20
No files found.
vllm/attention/layer.py
View file @
f1a7696f
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from
typing
import
cast
from
typing
import
cast
,
Optional
import
torch
import
torch.nn
as
nn
...
...
@@ -195,6 +195,8 @@ class Attention(nn.Module, AttentionLayerBase):
block_size
=
64
if
envs
.
VLLM_USE_FLASH_ATTN_PA
and
envs
.
VLLM_USE_FLASH_MLA
else
16
calculate_kv_scales
=
False
self
.
block_size
=
block_size
# llm-compressor models need to set cache_dtype to "fp8" manually.
if
getattr
(
quant_config
,
"kv_cache_scheme"
,
None
)
is
not
None
:
kv_cache_dtype
=
"fp8"
...
...
@@ -494,6 +496,101 @@ class Attention(nn.Module, AttentionLayerBase):
dtype
=
self
.
kv_cache_torch_dtype
,
)
class FusedQkvSplitRmsNormRopeAttention(Attention):
    """Attention layer whose forward consumes the packed ``qkv`` projection.

    Instead of receiving separate query/key/value tensors, ``forward`` hands
    the packed tensor to the ``vllm::fused_qkv_split_rmsnorm_rope_kv_store``
    custom op and then runs ``vllm::unified_attention_with_output`` on the
    tensors that op returns.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """Forward every argument unchanged to ``Attention.__init__``.

        The arguments are passed positionally, so their order here has to
        stay in sync with the parent constructor's parameter order.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            use_alibi_sqrt,
            cache_config,
            quant_config,
            logits_soft_cap,
            per_layer_sliding_window,
            prefix,
            attn_type,
            kv_sharing_target_layer_name,
            attn_backend,
            head_size_v,
            **extra_impl_args,
        )

    def forward(
        self,
        qkv: torch.Tensor,
        positions: torch.Tensor,
        cos_sin_cache: torch.Tensor,
        weight_q_norm: torch.Tensor,
        weight_k_norm: torch.Tensor,
        epsilon: float,
        # Some alternate attention backends (e.g. MLA) produce an output
        # whose shape differs from the query shape, so the model definition
        # may optionally dictate the output tensor shape.
        output_shape: torch.Size | None = None,
        is_neox: bool = False,
    ) -> torch.Tensor:
        """Run the fused QKV pre-processing op followed by attention.

        The KV cache lives on this layer (``self.kv_cache``). Attention
        metadata is installed by the model runner's ``execute_model`` through
        a context manager and is read from
        ``vllm.forward_context.get_forward_context().attn_metadata``.

        Args:
            qkv: Packed projection output; dimension 0 is the token count.
            positions: Token positions forwarded to the fused op.
            cos_sin_cache: Rotary cos/sin table forwarded to the fused op.
            weight_q_norm: Norm weight applied to the query part.
            weight_k_norm: Norm weight applied to the key part.
            epsilon: Epsilon forwarded to the fused op.
            output_shape: Optional explicit shape of the output buffer;
                defaults to ``(num_tokens, num_heads * head_size_v)``.
            is_neox: Rotary layout flag forwarded to the fused op.

        Returns:
            The attention output viewed as ``(-1, output_shape[-1])``.
        """
        if output_shape is None:
            # Default: one row per token, all value heads laid out flat.
            token_count = qkv.shape[0]
            output_shape = torch.Size(
                (token_count, self.num_heads * self.head_size_v)
            )
        hidden_size = output_shape[-1]

        # The attention op writes into a per-head view of this buffer.
        output = torch.empty(
            output_shape, dtype=qkv.dtype, device=qkv.device
        ).view(-1, self.num_heads, self.head_size_v)

        # Widths of the query slice and of each key/value slice inside the
        # packed tensor.
        q_width = self.num_heads * self.head_size
        kv_width = self.num_kv_heads * self.head_size

        query, key, value = torch.ops.vllm.fused_qkv_split_rmsnorm_rope_kv_store(
            qkv=qkv,
            positions=positions,
            layer_name=self.layer_name,
            kv_cache_dtype=self.kv_cache_dtype,
            cos_sin_cache=cos_sin_cache,
            weight_q_norm=weight_q_norm,
            weight_k_norm=weight_k_norm,
            epsilon=epsilon,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            q_size=q_width,
            kv_size=kv_width,
            block_size=self.block_size,
            is_neox=is_neox,
        )

        # No separate KV-cache-update op precedes the attention op on this
        # path, so there is no dummy dependency tensor to thread through.
        torch.ops.vllm.unified_attention_with_output(
            query,
            key,
            value,
            output,
            self.layer_name,
            kv_cache_dummy_dep=None,
        )
        return output.view(-1, hidden_size)
class
MLAAttention
(
nn
.
Module
,
AttentionLayerBase
):
"""Multi-Head Latent Attention layer.
...
...
@@ -995,3 +1092,107 @@ direct_register_custom_op(
fake_impl
=
unified_mla_attention_with_output_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
def fused_qkv_split_rmsnorm_rope_kv_store_impl(
    qkv: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
    kv_cache_dtype: str,
    cos_sin_cache: torch.Tensor,
    weight_q_norm: torch.Tensor,
    weight_k_norm: torch.Tensor,
    epsilon: float,
    head_size: int,
    head_size_v: int,
    q_size: int,
    kv_size: int,
    block_size: int,
    is_neox: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Call the ``lightop`` fused split/RMSNorm/RoPE/KV-store kernel.

    Looks up this layer's slot mapping and KV cache in the current forward
    context, hands them together with the packed ``qkv`` tensor to
    ``split_qkv_rms_rotary_embedding_fuse_with_kv_store_quant`` and returns
    the resulting q/k/v reshaped to ``(num_tokens, num_heads, head_dim)``.

    Args:
        qkv: Packed projection output; dimension 0 is the token count.
        positions: Token positions passed to the kernel.
        layer_name: Key into the forward context's ``slot_mapping`` and
            ``no_compile_layers``.
        kv_cache_dtype: KV-cache dtype string; values starting with ``"fp8"``
            are translated via
            ``FlashAttentionBackend.get_fp8_dtype_for_flashattn``.
        cos_sin_cache: Rotary cos/sin table passed to the kernel.
        weight_q_norm: Norm weight for the query part.
        weight_k_norm: Norm weight for the key part.
        epsilon: Epsilon passed to the kernel.
        head_size: Head dimension used to reshape ``q`` (and passed to the
            kernel as ``head_dim``).
        head_size_v: Head dimension used to reshape ``k`` and ``v``.
        q_size: Width of the query slice in ``qkv``.
        kv_size: Width of each key/value slice in ``qkv``.
        block_size: Passed to the kernel as ``page_size``.
        is_neox: Rotary layout flag passed to the kernel.

    Returns:
        ``(q, k, v)`` as contiguous 3-D tensors.

    Raises:
        AssertionError: If the forward context's ``slot_mapping`` is not a
            dict.
    """
    num_tokens = qkv.shape[0]
    forward_context = get_forward_context()
    slot_mapping = forward_context.slot_mapping
    # Validate the container *before* calling ``.get`` on it; otherwise a
    # non-dict value fails with an opaque AttributeError and this message is
    # never shown.
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
    )
    layer_slot_mapping = slot_mapping.get(layer_name)

    attn_layer = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]

    if layer_slot_mapping is not None:
        if current_platform.is_rocm():
            # On ROCm the cache object unpacks directly into (key, value).
            key_cache, value_cache = kv_cache
        else:
            key_cache, value_cache = kv_cache.unbind(0)
        if kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
            from vllm.v1.attention.backends.flash_attn import (
                FlashAttentionBackend,
            )

            # NOTE(review): from here on ``kv_cache_dtype`` holds the value
            # returned by get_fp8_dtype_for_flashattn rather than the
            # original string, and that value is what the kernel receives
            # below — confirm the kernel accepts both forms.
            kv_cache_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                kv_cache_dtype
            )
            key_cache = key_cache.view(kv_cache_dtype)
            value_cache = value_cache.view(kv_cache_dtype)
    else:
        # No slot mapping for this layer: pass empty placeholder buffers
        # (and ``kv_cache_loc=None``) to the kernel.
        key_cache = torch.empty([0], device=qkv.device, dtype=qkv.dtype)
        value_cache = torch.empty([0], device=qkv.device, dtype=qkv.dtype)

    # Imported lazily so the module can be loaded without ``lightop``.
    from lightop import split_qkv_rms_rotary_embedding_fuse_with_kv_store_quant

    q, k, v = split_qkv_rms_rotary_embedding_fuse_with_kv_store_quant(
        positions,
        qkv.contiguous(),
        q_size,
        kv_size,
        cos_sin_cache,
        head_dim=head_size,
        page_size=block_size,
        k_buffer=key_cache,
        v_buffer=value_cache,
        kv_cache_loc=layer_slot_mapping,
        is_neox=is_neox,
        weight_q=weight_q_norm,
        weight_k=weight_k_norm,
        output_dtype=qkv.dtype,
        kv_cache_dtype=kv_cache_dtype,
        epsilon=epsilon,
        residual_q=None,
        residual_k=None,
        k_scale=None,
        v_scale=None,
    )

    q = q.contiguous().view(num_tokens, q_size // head_size, head_size)
    # NOTE(review): ``k`` is reshaped with ``head_size_v`` although the
    # kernel is given ``head_dim=head_size``; the two only agree when
    # head_size == head_size_v. Kept as-is to stay in sync with the fake
    # implementation — confirm the intended key head dimension.
    k = k.contiguous().view(num_tokens, kv_size // head_size_v, head_size_v)
    v = v.contiguous().view(num_tokens, kv_size // head_size_v, head_size_v)
    return q, k, v
def fused_qkv_split_rmsnorm_rope_kv_store_fake(
    qkv: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
    kv_cache_dtype: str,
    cos_sin_cache: torch.Tensor,
    weight_q_norm: torch.Tensor,
    weight_k_norm: torch.Tensor,
    epsilon: float,
    head_size: int,
    head_size_v: int,
    q_size: int,
    kv_size: int,
    block_size: int,
    is_neox: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only stand-in for the fused QKV op.

    Allocates uninitialised q/k/v tensors with the shapes, dtype and device
    the real implementation returns, without running any kernel. Only
    ``qkv`` and the four size arguments influence the result.

    Returns:
        ``(q, k, v)`` where ``q`` is
        ``(num_tokens, q_size // head_size, head_size)`` and ``k``/``v`` are
        ``(num_tokens, kv_size // head_size_v, head_size_v)``.
    """
    token_count = qkv.shape[0]
    # Outputs inherit placement and precision from the packed input.
    like_qkv = {"device": qkv.device, "dtype": qkv.dtype}

    q_shape = (token_count, q_size // head_size, head_size)
    kv_shape = (token_count, kv_size // head_size_v, head_size_v)

    return (
        torch.empty(q_shape, **like_qkv),
        torch.empty(kv_shape, **like_qkv),
        torch.empty(kv_shape, **like_qkv),
    )
# Register the fused split-QKV op under the name
# "fused_qkv_split_rmsnorm_rope_kv_store", pairing the real implementation
# with its shape-only fake implementation.
# NOTE(review): only ``qkv`` and ``positions`` are declared as mutated; the
# KV-cache buffers the impl passes to the kernel come from the forward
# context rather than from the op's arguments — confirm this is intended.
direct_register_custom_op(
    op_name="fused_qkv_split_rmsnorm_rope_kv_store",
    op_func=fused_qkv_split_rmsnorm_rope_kv_store_impl,
    mutates_args=["qkv", "positions"],
    fake_impl=fused_qkv_split_rmsnorm_rope_kv_store_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
)
\ No newline at end of file
vllm/config/compilation.py
View file @
f1a7696f
...
...
@@ -956,6 +956,7 @@ class CompilationConfig:
# https://github.com/vllm-project/vllm/issues/33267
if
not
self
.
use_inductor_graph_partition
:
self
.
splitting_ops
.
append
(
"vllm::unified_kv_cache_update"
)
self
.
splitting_ops
.
append
(
"vllm::fused_qkv_split_rmsnorm_rope_kv_store"
)
elif
len
(
self
.
splitting_ops
)
==
0
:
if
(
...
...
vllm/envs.py
View file @
f1a7696f
...
...
@@ -302,6 +302,7 @@ if TYPE_CHECKING:
VLLM_USE_MOE_W16A16_TRITON
:
bool
=
False
VLLM_V1_FAST_TOKEN_ID_COPY
:
bool
=
False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER
:
bool
=
False
VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -1897,6 +1898,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
).
lower
()
in
(
"true"
,
"1"
)
),
# If set to 1/True, enable the fused split qkv + rmsnorm + rope + kv cache update op, as used by glm4.7 moe attention.
"VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/models/glm4_moe.py
View file @
f1a7696f
...
...
@@ -32,7 +32,8 @@ import torch
from
torch
import
nn
from
transformers.models.glm4_moe
import
Glm4MoeConfig
from
vllm.attention.layer
import
Attention
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
,
FusedQkvSplitRmsNormRopeAttention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.distributed
import
(
...
...
@@ -290,15 +291,27 @@ class Glm4MoeAttention(nn.Module):
max_position
=
max_position_embeddings
,
rope_parameters
=
config
.
rope_parameters
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
if
not
envs
.
VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE
:
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
else
:
self
.
attn
=
FusedQkvSplitRmsNormRopeAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
if
self
.
use_qk_norm
:
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
...
...
@@ -310,17 +323,36 @@ class Glm4MoeAttention(nn.Module):
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
use_qk_norm
:
q
=
self
.
q_norm
(
q
.
reshape
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)).
reshape
(
q
.
shape
)
k
=
self
.
k_norm
(
k
.
reshape
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
)).
reshape
(
k
.
shape
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
if
not
envs
.
VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE
:
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
if
self
.
use_qk_norm
:
q
=
self
.
q_norm
(
q
.
reshape
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)).
reshape
(
q
.
shape
)
k
=
self
.
k_norm
(
k
.
reshape
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
)).
reshape
(
k
.
shape
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
else
:
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
(
cos_sin_cache
.
device
!=
qkv
.
device
or
cos_sin_cache
.
dtype
!=
qkv
.
dtype
):
cos_sin_cache
=
cos_sin_cache
.
to
(
qkv
.
device
,
dtype
=
qkv
.
dtype
,
non_blocking
=
True
)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
attn_output
=
self
.
attn
(
qkv
,
positions
,
cos_sin_cache
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
self
.
q_norm
.
variance_epsilon
,
is_neox
=
self
.
rotary_emb
.
is_neox_style
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment