Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a1628458
Commit
a1628458
authored
Mar 02, 2026
by
王敏
Browse files
[feat]添加rmsnorm+int8 quant融合module
parent
c9733a54
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
176 additions
and
14 deletions
+176
-14
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+98
-0
vllm/model_executor/models/glm4_moe.py
vllm/model_executor/models/glm4_moe.py
+78
-14
No files found.
vllm/model_executor/layers/layernorm.py
View file @
a1628458
...
...
@@ -6,6 +6,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
typing
import
Optional
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.model_executor.custom_op
import
CustomOp
...
...
@@ -14,8 +15,11 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm
import
envs
from
lightop
import
rms_norm_dynamic_per_token_quant
def
rms_norm
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
...
...
@@ -296,6 +300,100 @@ class RMSNorm(CustomOp):
s
=
f
"hidden_size=
{
self
.
weight
.
data
.
size
(
0
)
}
"
s
+=
f
", eps=
{
self
.
variance_epsilon
}
"
return
s
class
FusedRMSNormQuant
(
nn
.
Module
):
"""Fuse Root mean square normalization and int8 quant.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Refer to https://arxiv.org/abs/1910.07467
"""
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
=
1e-6
,
var_hidden_size
:
int
|
None
=
None
,
has_weight
:
bool
=
True
,
dtype
:
torch
.
dtype
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
variance_epsilon
=
eps
self
.
variance_size_override
=
(
None
if
var_hidden_size
==
hidden_size
else
var_hidden_size
)
weight_dtype
=
dtype
or
torch
.
get_default_dtype
()
self
.
has_weight
=
has_weight
self
.
weight
=
torch
.
ones
(
hidden_size
,
dtype
=
weight_dtype
)
if
self
.
has_weight
:
self
.
weight
=
nn
.
Parameter
(
self
.
weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
=
None
,
quant_dtype
:
torch
.
dtype
=
torch
.
int8
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
x
,
x_scales
=
fused_rmsquant
(
x
,
self
.
weight
,
self
.
variance_epsilon
,
quant_dtype
,
residual
)
return
x
,
x_scales
,
residual
def fused_rmsquant_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Real implementation of the fused RMSNorm + quantization custom op.

    Forwards every argument, positionally and in the same order, to
    ``rms_norm_dynamic_per_token_quant`` and returns its two results as a
    tuple ``(quantized_output, scales)``.

    NOTE(review): the meaning of ``update_input`` and whether ``input`` /
    ``residual`` are modified in place is defined by the lightop kernel and
    is not visible here -- confirm before relying on it.
    """
    kernel_args = (input, weight, epsilon, quant_dtype, residual, update_input)
    quantized, per_token_scales = rms_norm_dynamic_per_token_quant(*kernel_args)
    return quantized, per_token_scales
def fused_rmsquant_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation of the fused op, used by torch.compile.

    Produces placeholder tensors without running the kernel: a zero tensor
    shaped like ``input`` with dtype ``quant_dtype``, and a float32 tensor of
    ones with one row per slice along the last dimension of ``input``
    (shape ``(input.numel() // input.shape[-1], 1)``) on ``input``'s device.
    ``weight``, ``epsilon``, ``residual`` and ``update_input`` are ignored.
    """
    quantized = torch.zeros_like(input, dtype=quant_dtype)
    # Everything except the last dimension is flattened into rows.
    row_count = input.numel() // input.shape[-1]
    per_row_scales = torch.ones(
        (row_count, 1),
        dtype=torch.float32,
        device=input.device,
    )
    return quantized, per_row_scales
# Register the fused RMSNorm + quantization kernel as a custom op so it can be
# used under torch.compile: ``fused_rmsquant_impl`` is the real implementation
# and ``fused_rmsquant_fake`` the fake one. The op is registered under the name
# "rms_norm_dynamic_per_token_quant", so lookups must use exactly that name.
# NOTE(review): ``mutates_args`` is empty, i.e. the op is declared as not
# mutating any argument. If the lightop kernel writes to ``input`` or
# ``residual`` in place, those names should be listed here -- confirm.
direct_register_custom_op(
    op_name="rms_norm_dynamic_per_token_quant",
    op_func=fused_rmsquant_impl,
    mutates_args=[],
    fake_impl=fused_rmsquant_fake,
)
def fused_rmsquant(
    input: torch.Tensor,
    rms_weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run fused RMSNorm + quantization through the registered custom op.

    Args:
        input: Activations to normalize and quantize.
        rms_weight: RMSNorm weight, passed to the op as ``weight``.
        epsilon: Epsilon forwarded to the op.
        quant_dtype: Target dtype of the quantized output.
        residual: Optional residual tensor forwarded to the op.
        update_input: Forwarded unchanged to the op.

    Returns:
        ``(quantized_input, scales)`` as produced by the op.
    """
    # The op is registered in this module as "rms_norm_dynamic_per_token_quant";
    # it must be looked up under that exact name, otherwise torch.ops raises
    # because no op called "fused_rmsquant" is registered here.
    i_q, _scales = torch.ops.vllm.rms_norm_dynamic_per_token_quant(
        input=input,
        weight=rms_weight,
        epsilon=epsilon,
        quant_dtype=quant_dtype,
        residual=residual,
        update_input=update_input,
    )
    return i_q, _scales
# --8<-- [start:gemma_rms_norm]
...
...
vllm/model_executor/models/glm4_moe.py
View file @
a1628458
...
...
@@ -32,6 +32,8 @@ import torch
from
torch
import
nn
from
transformers.models.glm4_moe
import
Glm4MoeConfig
from
vllm
import
envs
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
,
get_current_vllm_config
...
...
@@ -43,7 +45,7 @@ from vllm.distributed import (
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
,
FusedRMSNormQuant
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
...
...
@@ -71,6 +73,10 @@ from .utils import (
make_layers
,
maybe_prefix
,
)
from
vllm.utils.torch_utils
import
direct_register_custom_op
if
envs
.
VLLM_USE_FUSED_RMS_QUANT
:
from
lightop
import
rms_norm_dynamic_per_token_quant
logger
=
init_logger
(
__name__
)
...
...
@@ -112,6 +118,44 @@ class Glm4MoeMLP(nn.Module):
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class Glm4MoeQuantedMLP(nn.Module):
    """Gated MLP whose input arrives already quantized.

    Same layer layout as a gate/up projection followed by SiLU-and-multiply
    and a down projection, except that ``forward`` takes the activations
    together with their quantization scales and hands both to the gate/up
    projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        """Build the projections and the activation.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()

        gate_up_prefix = f"{prefix}.gate_up_proj"
        down_prefix = f"{prefix}.down_proj"

        # Gate and up projections share one merged column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=gate_up_prefix,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=down_prefix,
        )

        silu_requested = hidden_act == "silu"
        if not silu_requested:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x, x_scales):
        """Apply the MLP to quantized activations ``x`` with scales ``x_scales``."""
        merged, _ = self.gate_up_proj(x, x_scales)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
class
Glm4MoE
(
nn
.
Module
):
...
...
@@ -342,6 +386,8 @@ class Glm4MoeDecoderLayer(nn.Module):
layer_idx
=
int
(
prefix
.
split
(
sep
=
"."
)[
-
1
])
self
.
layer_idx
=
layer_idx
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
self_attn
=
Glm4MoeAttention
(
config
=
config
,
hidden_size
=
self
.
hidden_size
,
...
...
@@ -368,18 +414,31 @@ class Glm4MoeDecoderLayer(nn.Module):
enable_eplb
=
enable_eplb
,
)
else
:
self
.
mlp
=
Glm4MoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
if
not
envs
.
VLLM_USE_FUSED_RMS_QUANT
:
self
.
mlp
=
Glm4MoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
else
:
self
.
mlp
=
Glm4MoeQuantedMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
if
not
envs
.
VLLM_USE_FUSED_RMS_QUANT
:
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
post_attention_layernorm
=
FusedRMSNormQuant
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
def
forward
(
...
...
@@ -394,8 +453,13 @@ class Glm4MoeDecoderLayer(nn.Module):
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
)
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
not
envs
.
VLLM_USE_FUSED_RMS_QUANT
:
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
else
:
hidden_states
,
scales
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
,
scales
)
return
hidden_states
,
residual
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment