Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ce12b407
Unverified
Commit
ce12b407
authored
Dec 16, 2025
by
Ming Yang
Committed by
GitHub
Dec 16, 2025
Browse files
[TRTLLM] Remove the MoE GEMM weight name change (#30713)
Signed-off-by:
Ming Yang
<
minos.future@gmail.com
>
parent
59bd5f6a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
40 deletions
+16
-40
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+4
-12
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+4
-12
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
..._executor/layers/quantization/utils/flashinfer_fp4_moe.py
+8
-16
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
ce12b407
...
...
@@ -469,16 +469,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
)
logger
.
debug_once
(
"Finished shuffling weights for TRT-LLM MOE"
)
layer
.
gemm1
_weight
s_fp4_shuffled
=
Parameter
(
layer
.
w13
_weight
=
Parameter
(
gemm1_weights_fp4_shuffled
,
requires_grad
=
False
)
layer
.
gemm2_weights_fp4_shuffled
=
Parameter
(
gemm2_weights_fp4_shuffled
,
requires_grad
=
False
)
layer
.
gemm1_scales_fp4_shuffled
=
Parameter
(
layer
.
w2_weight
=
Parameter
(
gemm2_weights_fp4_shuffled
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
Parameter
(
gemm1_scales_fp4_shuffled
,
requires_grad
=
False
)
layer
.
gemm2_scales_fp4_shuff
le
d
=
Parameter
(
layer
.
w2_weight_sca
le
=
Parameter
(
gemm2_scales_fp4_shuffled
,
requires_grad
=
False
)
...
...
@@ -487,12 +485,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
(
layer
.
w2_input_scale_quant
*
layer
.
g1_alphas
).
to
(
torch
.
float32
),
requires_grad
=
False
,
)
# Clean up weights that won't be used by TRT-LLM
del
layer
.
w2_weight
del
layer
.
w2_weight_scale
del
layer
.
w13_weight
del
layer
.
w13_weight_scale
else
:
# swizzle weight scales
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
ce12b407
...
...
@@ -1458,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
logger
.
debug_once
(
"Finished shuffling weights for TRT-LLM MOE"
)
layer
.
gemm1
_weight
s_fp4_shuffled
=
Parameter
(
layer
.
w13
_weight
=
Parameter
(
gemm1_weights_fp4_shuffled
,
requires_grad
=
False
)
layer
.
gemm2_weights_fp4_shuffled
=
Parameter
(
gemm2_weights_fp4_shuffled
,
requires_grad
=
False
)
layer
.
gemm1_scales_fp4_shuffled
=
Parameter
(
layer
.
w2_weight
=
Parameter
(
gemm2_weights_fp4_shuffled
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
Parameter
(
gemm1_scales_fp4_shuffled
,
requires_grad
=
False
)
layer
.
gemm2_scales_fp4_shuff
le
d
=
Parameter
(
layer
.
w2_weight_sca
le
=
Parameter
(
gemm2_scales_fp4_shuffled
,
requires_grad
=
False
)
...
...
@@ -1476,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
(
layer
.
w2_input_scale_quant
*
layer
.
g1_alphas
).
to
(
torch
.
float32
),
requires_grad
=
False
,
)
# Clean up weights that won't be used by TRT-LLM
del
layer
.
w2_weight
del
layer
.
w2_weight_scale
del
layer
.
w13_weight
del
layer
.
w13_weight_scale
elif
self
.
use_marlin
:
# Marlin processing
prepare_moe_fp4_layer_for_marlin
(
layer
)
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
View file @
ce12b407
...
...
@@ -301,18 +301,14 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_scale
=
hidden_states_scale_linear_fp4
.
view
(
torch
.
float8_e4m3fn
).
flatten
(),
gemm1_weights
=
layer
.
gemm1_weights_fp4_shuffled
.
data
,
gemm1_weights_scale
=
layer
.
gemm1_scales_fp4_shuffled
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm1_weights
=
layer
.
w13_weight
.
data
,
gemm1_weights_scale
=
layer
.
w13_weight_scale
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm1_bias
=
None
,
gemm1_alpha
=
None
,
gemm1_beta
=
None
,
gemm1_clamp_limit
=
None
,
gemm2_weights
=
layer
.
gemm2_weights_fp4_shuffled
.
data
,
gemm2_weights_scale
=
layer
.
gemm2_scales_fp4_shuffled
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm2_weights
=
layer
.
w2_weight
.
data
,
gemm2_weights_scale
=
layer
.
w2_weight_scale
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm2_bias
=
None
,
output1_scale_scalar
=
layer
.
g1_scale_c
.
data
,
output1_scale_gate_scalar
=
layer
.
g1_alphas
.
data
,
...
...
@@ -380,18 +376,14 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_scale
=
hidden_states_scale_linear_fp4
.
view
(
torch
.
float8_e4m3fn
).
flatten
(),
gemm1_weights
=
layer
.
gemm1_weights_fp4_shuffled
.
data
,
gemm1_weights_scale
=
layer
.
gemm1_scales_fp4_shuffled
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm1_weights
=
layer
.
w13_weight
.
data
,
gemm1_weights_scale
=
layer
.
w13_weight_scale
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm1_bias
=
None
,
gemm1_alpha
=
None
,
gemm1_beta
=
None
,
gemm1_clamp_limit
=
None
,
gemm2_weights
=
layer
.
gemm2_weights_fp4_shuffled
.
data
,
gemm2_weights_scale
=
layer
.
gemm2_scales_fp4_shuffled
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm2_weights
=
layer
.
w2_weight
.
data
,
gemm2_weights_scale
=
layer
.
w2_weight_scale
.
data
.
view
(
torch
.
float8_e4m3fn
),
gemm2_bias
=
None
,
output1_scale_scalar
=
layer
.
g1_scale_c
.
data
,
output1_scale_gate_scalar
=
layer
.
g1_alphas
.
data
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment