Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
110bbdd5
Commit
110bbdd5
authored
Mar 06, 2026
by
王敏
Browse files
[perf]glm4_moe模型适配rmsquant和silu_quant融合算子
parent
f1a7696f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
181 additions
and
26 deletions
+181
-26
vllm/envs.py
vllm/envs.py
+6
-0
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+92
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+12
-2
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+17
-7
vllm/model_executor/models/glm4_moe.py
vllm/model_executor/models/glm4_moe.py
+54
-17
No files found.
vllm/envs.py
View file @
110bbdd5
...
...
@@ -281,6 +281,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_ALIGN
:
bool
=
False
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
USE_FUSED_SILU_MUL_QUANT
:
bool
=
False
VLLM_USE_PD_SPLIT
:
bool
=
False
VLLM_USE_PP_SYNC
:
bool
=
False
VLLM_USE_PIECEWISE
:
bool
=
False
...
...
@@ -1804,6 +1805,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set to 1, vLLM uses the fused RMSNorm + quantization op (default: 0, disabled)
"USE_FUSED_RMS_QUANT"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"USE_FUSED_RMS_QUANT"
,
"0"
))),
# If set to 1, vLLM uses the fused SiLU-and-multiply + quantization op (default: 0, disabled)
"USE_FUSED_SILU_MUL_QUANT"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"USE_FUSED_SILU_MUL_QUANT"
,
"0"
))),
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PD_SPLIT"
,
"True"
).
lower
()
in
...
...
vllm/model_executor/layers/layernorm.py
View file @
110bbdd5
...
...
@@ -6,6 +6,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
typing
import
Optional
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.model_executor.custom_op
import
CustomOp
...
...
@@ -14,6 +15,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm
import
envs
...
...
@@ -298,6 +300,96 @@ class RMSNorm(CustomOp):
return
s
class
FusedRMSNormQuant
(
nn
.
Module
):
"""Fuse Root mean square normalization and int8 quant.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Refer to https://arxiv.org/abs/1910.07467
"""
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
=
1e-6
,
var_hidden_size
:
int
|
None
=
None
,
has_weight
:
bool
=
True
,
dtype
:
torch
.
dtype
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
variance_epsilon
=
eps
self
.
variance_size_override
=
(
None
if
var_hidden_size
==
hidden_size
else
var_hidden_size
)
weight_dtype
=
dtype
or
torch
.
get_default_dtype
()
self
.
has_weight
=
has_weight
self
.
weight
=
torch
.
ones
(
hidden_size
,
dtype
=
weight_dtype
)
if
self
.
has_weight
:
self
.
weight
=
nn
.
Parameter
(
self
.
weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
=
None
,
quant_dtype
:
torch
.
dtype
=
torch
.
int8
,
update_input
:
Optional
[
bool
]
=
True
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
i_q
,
i_s
=
torch
.
ops
.
vllm
.
fused_rmsquant
(
input
=
x
,
weight
=
self
.
weight
,
epsilon
=
self
.
variance_epsilon
,
quant_dtype
=
quant_dtype
,
residual
=
residual
,
update_input
=
update_input
)
return
i_q
,
i_s
,
residual
def fused_rmsquant_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Real implementation of the ``fused_rmsquant`` custom op.

    Allocates the result buffers and delegates the computation to lightop's
    ``rms_norm_dynamic_per_token_quant`` kernel.

    Args:
        input: Activations to normalize and quantize.
        weight: RMSNorm weight, passed straight to the kernel.
        epsilon: RMSNorm epsilon, passed straight to the kernel.
        quant_dtype: Dtype of the quantized output buffer.
        residual: Optional residual tensor, passed straight to the kernel.
        update_input: Flag passed straight to the kernel.

    Returns:
        ``(output, scales)``: ``output`` has the shape of ``input`` with
        dtype ``quant_dtype``; ``scales`` is float32 with one row per
        last-dimension vector of ``input``, shape ``(rows, 1)``.
    """
    # One scale per row, where a row is a slice along the last dimension.
    row_count = input.numel() // input.shape[-1]
    quantized = torch.empty_like(input, dtype=quant_dtype)
    row_scales = torch.empty(
        (row_count, 1), dtype=torch.float32, device=input.device
    )

    # Imported lazily so the module loads even where lightop is absent.
    from lightop.op import (
        rms_norm_dynamic_per_token_quant as lightop_rms_norm_quant,
    )

    # The kernel fills `quantized` and `row_scales` in place.
    # NOTE(review): it presumably also writes to `residual` (and to `input`
    # when update_input is set) -- confirm against the lightop docs.
    lightop_rms_norm_quant(
        quantized, input, weight, row_scales, epsilon, residual, update_input
    )
    return quantized, row_scales
def fused_rmsquant_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (shape-only) implementation of ``fused_rmsquant`` for torch.compile.

    No computation is performed; the function only returns uninitialized
    tensors with the shapes and dtypes of the real op's results.

    Returns:
        ``(output, scales)``: ``output`` matches ``input``'s shape with dtype
        ``quant_dtype``; ``scales`` is float32 of shape ``(rows, 1)`` where
        ``rows = input.numel() // input.shape[-1]``.
    """
    row_count = input.numel() // input.shape[-1]
    fake_scales = torch.empty(
        (row_count, 1), dtype=torch.float32, device=input.device
    )
    fake_quantized = torch.empty_like(input, dtype=quant_dtype)
    return fake_quantized, fake_scales
# from torch.library import Library
# customer_lib = Library("customer_", "FRAGMENT")
# Register the fused RMSNorm + quant implementation as a custom op;
# FusedRMSNormQuant.forward invokes it as torch.ops.vllm.fused_rmsquant.
# fused_rmsquant_fake provides the shape-only implementation for tracing.
# NOTE(review): mutates_args is empty, but callers reuse the `residual`
# tensor they passed in as if the kernel updated it in place. If the lightop
# kernel writes to `residual` (or `input`), those arguments probably need to
# be listed here -- confirm before relying on this op under torch.compile.
direct_register_custom_op(
    op_name="fused_rmsquant",
    op_func=fused_rmsquant_impl,
    mutates_args=[],
    fake_impl=fused_rmsquant_fake,
)
# --8<-- [start:gemma_rms_norm]
@
CustomOp
.
register
(
"gemma_rms_norm"
)
class
GemmaRMSNorm
(
CustomOp
):
...
...
vllm/model_executor/layers/linear.py
View file @
110bbdd5
...
...
@@ -654,11 +654,16 @@ class ColumnParallelLinear(LinearBase):
def
forward
(
self
,
input_
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
Parameter
|
None
]:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
assert
self
.
quant_method
is
not
None
if
iqis
is
not
None
:
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
=
iqis
)
else
:
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
if
self
.
gather_output
and
self
.
tp_size
>
1
:
...
...
@@ -1523,6 +1528,8 @@ class RowParallelLinear(LinearBase):
def
forward
(
self
,
input_
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
Parameter
|
None
]:
if
self
.
input_is_parallel
:
input_parallel
=
input_
...
...
@@ -1537,6 +1544,9 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_
=
None
if
(
self
.
tp_rank
>
0
or
self
.
skip_bias_add
)
else
self
.
bias
if
iqis
is
not
None
:
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias_
,
input_quant_args
=
iqis
)
else
:
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias_
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
110bbdd5
...
...
@@ -7,6 +7,7 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm
import
envs
try
:
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
...
...
@@ -167,6 +168,15 @@ def apply_int8_linear(
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
if
envs
.
USE_FUSED_RMS_QUANT
and
input_quant_args
is
not
None
:
assert
len
(
input_quant_args
)
==
2
x_zp
=
None
x_q
,
x_scale
=
input_quant_args
elif
envs
.
USE_FUSED_RMS_QUANT
and
silu_quant_args
is
not
None
:
assert
len
(
silu_quant_args
)
==
2
x_zp
=
None
x_q
,
x_scale
=
silu_quant_args
else
:
symmetric
=
azp_adj
is
None
if
input_scale
is
None
and
input_zero_point
is
None
and
symmetric
is
True
:
x_q
,
x_scale
=
per_token_quant_int8
(
input
)
...
...
vllm/model_executor/models/glm4_moe.py
View file @
110bbdd5
...
...
@@ -44,7 +44,7 @@ from vllm.distributed import (
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
,
FusedRMSNormQuant
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
...
...
@@ -108,8 +108,14 @@ class Glm4MoeMLP(nn.Module):
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
def forward(
    self,
    x,
    iqis: tuple[torch.Tensor, torch.Tensor] | None = None,
):
    """Apply the gated MLP: gate_up_proj -> SiLU-and-mul -> down_proj.

    Args:
        x: Input activations for ``gate_up_proj``.
        iqis: Optional pre-quantized ``(values, scales)`` pair for ``x``,
            forwarded to ``gate_up_proj``.

    Returns:
        The output of ``down_proj``.
    """
    gate_up, _ = self.gate_up_proj(x, iqis=iqis)

    if not envs.USE_FUSED_SILU_MUL_QUANT:
        # Unfused path: activation first, then down_proj quantizes on its own.
        out, _ = self.down_proj(self.act_fn(gate_up))
        return out

    # Fused path: one kernel produces the quantized activation and scales.
    from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant

    act_q, act_scales = lm_fuse_silu_mul_quant(gate_up)
    # NOTE(review): `gate_up` (pre-activation) is still passed as the
    # positional input. apply_int8_linear in w8a8_utils.py appears to use the
    # pre-quantized pair only when USE_FUSED_RMS_QUANT is also set; with only
    # USE_FUSED_SILU_MUL_QUANT enabled it looks like it would quantize
    # `gate_up` itself -- confirm the flag check there.
    out, _ = self.down_proj(gate_up, iqis=(act_q, act_scales))
    return out
...
...
@@ -321,8 +327,9 @@ class Glm4MoeAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
,
iqis
=
iqis
)
if
not
envs
.
VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE
:
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
...
...
@@ -408,7 +415,16 @@ class Glm4MoeDecoderLayer(nn.Module):
prefix
=
f
"
{
prefix
}
.mlp"
,
)
if
envs
.
USE_FUSED_RMS_QUANT
:
self
.
input_layernorm
=
FusedRMSNormQuant
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
if
envs
.
USE_FUSED_RMS_QUANT
and
isinstance
(
self
.
mlp
,
Glm4MoeMLP
):
self
.
post_attention_layernorm
=
FusedRMSNormQuant
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
@@ -420,12 +436,33 @@ class Glm4MoeDecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
envs
.
USE_FUSED_RMS_QUANT
:
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
)
else
:
if
residual
is
None
:
residual
=
hidden_states
.
clone
()
i_q
,
i_s
,
_
=
self
.
input_layernorm
(
x
=
hidden_states
,
residual
=
None
,
quant_dtype
=
torch
.
int8
,
update_input
=
False
)
else
:
i_q
,
i_s
,
residual
=
self
.
input_layernorm
(
x
=
hidden_states
,
residual
=
residual
,
quant_dtype
=
torch
.
int8
,
update_input
=
False
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
iqis
=
(
i_q
,
i_s
))
if
envs
.
USE_FUSED_RMS_QUANT
and
isinstance
(
self
.
mlp
,
Glm4MoeMLP
):
i_q
,
i_s
,
_
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
,
iqis
=
(
i_q
,
i_s
))
else
:
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment