Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
89a8f88b
Commit
89a8f88b
authored
Mar 28, 2026
by
wanglong3
Committed by
zhangzbb
Mar 28, 2026
Browse files
feat: Support rms+quant fusion in minimax_m2 series models.
parent
ca158ae9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
11 deletions
+51
-11
vllm/model_executor/models/minimax_m2.py
vllm/model_executor/models/minimax_m2.py
+51
-11
No files found.
vllm/model_executor/models/minimax_m2.py
View file @
89a8f88b
...
...
@@ -39,7 +39,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce
,
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
,
FusedRMSNormQuant
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
...
...
@@ -58,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name
,
)
from
vllm.sequence
import
IntermediateTensors
from
vllm
import
envs
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
...
...
@@ -229,8 +230,10 @@ class MiniMaxM2Attention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
# iqis: (input_quant, input_scale)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
,
iqis
=
iqis
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
MiniMaxText01RMSNormTP
.
forward_qk
(
self
.
q_norm
,
self
.
k_norm
,
q
.
contiguous
(),
k
.
contiguous
()
...
...
@@ -282,27 +285,64 @@ class MiniMaxM2DecoderLayer(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
if
envs
.
USE_FUSED_RMS_QUANT
:
self
.
input_layernorm
=
FusedRMSNormQuant
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
def
rms_quant_fusion_
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
# Self Attention
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
residual
=
hidden_states
.
clone
()
input_quant
,
input_scale
,
_
=
self
.
input_layernorm
(
x
=
hidden_states
,
residual
=
None
,
quant_dtype
=
torch
.
int8
,
update_input
=
False
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
input_quant
,
input_scale
,
residual
=
self
.
input_layernorm
(
x
=
hidden_states
,
residual
=
residual
,
quant_dtype
=
torch
.
int8
,
update_input
=
False
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
positions
=
positions
,
hidden_states
=
hidden_states
,
iqis
=
(
input_quant
,
input_scale
)
)
return
hidden_states
,
residual
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
if
envs
.
USE_FUSED_RMS_QUANT
:
hidden_states
,
residual
=
self
.
rms_quant_fusion_forward
(
positions
,
hidden_states
,
residual
)
else
:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment