Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0ae89f18
Unverified
Commit
0ae89f18
authored
Mar 26, 2026
by
Bowen Bao
Committed by
GitHub
Mar 26, 2026
Browse files
[Refactor] Move FusedMoE hidden_size roundup to quant_method (#34285)
Signed-off-by:
Bowen Bao
<
bowenbao@amd.com
>
parent
c2b17d71
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
204 additions
and
222 deletions
+204
-222
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+12
-3
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
.../model_executor/layers/fused_moe/fused_moe_method_base.py
+33
-0
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+3
-13
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+63
-71
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
+1
-7
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+16
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+0
-4
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+0
-4
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+25
-37
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+51
-79
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
..._executor/layers/quantization/utils/flashinfer_fp4_moe.py
+0
-1
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
...el_executor/layers/quantization/utils/flashinfer_utils.py
+0
-1
No files found.
vllm/model_executor/layers/fused_moe/config.py
View file @
0ae89f18
...
...
@@ -229,9 +229,6 @@ class FusedMoEQuantConfig:
_w1
:
FusedMoEQuantDesc
_w2
:
FusedMoEQuantDesc
is_nvfp4_scale_swizzled
:
bool
=
True
# CK MXFP4 (gfx950) padding info for rocm_aiter_ops.fused_moe()
hidden_pad
:
int
=
0
intermediate_pad
:
int
=
0
def
__post_init__
(
self
):
assert
not
self
.
per_act_token_quant
or
self
.
block_shape
is
None
,
(
...
...
@@ -1172,6 +1169,11 @@ class FusedMoEConfig:
# Defaults to in_dtype if not specified.
router_logits_dtype
:
torch
.
dtype
|
None
=
None
# Defaults to hidden_dim if not specified.
hidden_dim_unpadded
:
int
|
None
=
None
# Defaults to intermediate_size_per_partition if not specified.
intermediate_size_per_partition_unpadded
:
int
|
None
=
None
moe_backend
:
str
=
"auto"
max_num_tokens
:
int
=
envs
.
VLLM_MOE_DP_CHUNK_SIZE
has_bias
:
bool
=
False
...
...
@@ -1195,6 +1197,13 @@ class FusedMoEConfig:
if
self
.
router_logits_dtype
is
None
:
self
.
router_logits_dtype
=
self
.
in_dtype
if
self
.
hidden_dim_unpadded
is
None
:
self
.
hidden_dim_unpadded
=
self
.
hidden_dim
if
self
.
intermediate_size_per_partition_unpadded
is
None
:
self
.
intermediate_size_per_partition_unpadded
=
(
self
.
intermediate_size_per_partition
)
@
property
def
tp_size
(
self
):
return
self
.
moe_parallel_config
.
tp_size
...
...
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
View file @
0ae89f18
...
...
@@ -9,6 +9,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
...
...
@@ -65,6 +66,38 @@ class FusedMoEMethodBase(QuantizeMethodBase):
"""
return
False
def
maybe_roundup_sizes
(
self
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
act_dtype
:
torch
.
dtype
,
moe_parallel_config
:
FusedMoEParallelConfig
,
)
->
tuple
[
int
,
int
]:
"""
Given layer hidden size and intermediate size per partition and MoE
configurations, round up hidden_size and intermediate_size_per_partition
if necessary.
Args:
hidden_size: Layer hidden-size
intermediate_size_per_partition: Intermediate size per partition for
the layer.
act_dtype: Data type of the layer activations.
moe_parallel_config: Fused MoE parallelization strategy configuration.
Return:
A tuple of (rounded_hidden_size, rounded_intermediate_size_per_partition),
where:
- rounded_hidden_size is the possibly rounded up hidden size.
- rounded_intermediate_size_per_partition is the possibly rounded
up intermediate size per partition.
"""
from
.all2all_utils
import
maybe_roundup_layer_hidden_size
return
maybe_roundup_layer_hidden_size
(
hidden_size
,
act_dtype
,
moe_parallel_config
),
intermediate_size_per_partition
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
...
...
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
0ae89f18
...
...
@@ -428,13 +428,9 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
assert
quant_config
.
w1_bias
is
None
or
quant_config
.
w1_bias
.
dtype
==
torch
.
float32
assert
quant_config
.
w2_bias
is
None
or
quant_config
.
w2_bias
.
dtype
==
torch
.
float32
# Shape check: when weights are padded (e.g. hidden_size padded for
# GFX950 swizzle), unpadded_K_w1 carries the original dimension.
expected_K_w1
=
unpadded_K_w1
if
unpadded_K_w1
is
not
None
else
w1
.
shape
[
-
2
]
assert
hidden_states
.
shape
[
-
1
]
==
expected_K_w1
,
(
f
"hidden_states K=
{
hidden_states
.
shape
[
-
1
]
}
!= "
f
"expected K=
{
expected_K_w1
}
(w1 K=
{
w1
.
shape
[
-
2
]
}
)"
)
# Shape check: weights are padded (e.g. hidden_size padded for
# GFX950 swizzle).
assert
hidden_states
.
shape
[
-
1
]
==
w1
.
shape
[
-
2
]
assert
w2
.
shape
[
-
1
]
==
w1
.
shape
[
1
]
E
,
_
,
N
=
w1
.
shape
...
...
@@ -494,12 +490,6 @@ def triton_kernel_fused_mxfp4_w4a8_experts(
unpadded_K
=
unpadded_K_w2
,
)
# When hidden_size was padded for alignment (e.g. GFX950 swizzle),
# the kernel output has the padded dimension. Slice back to the
# original hidden_size so downstream layers see the expected shape.
if
unpadded_N_w2
is
not
None
and
intermediate_cache3
.
shape
[
-
1
]
!=
unpadded_N_w2
:
intermediate_cache3
=
intermediate_cache3
[...,
:
unpadded_N_w2
].
contiguous
()
return
intermediate_cache3
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
0ae89f18
...
...
@@ -210,42 +210,6 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
)
# TODO(rob): move this down to the kernel.
def
maybe_roundup_hidden_size
(
hidden_size
:
int
,
act_dtype
:
torch
.
dtype
,
moe_parallel_config
:
FusedMoEParallelConfig
,
is_lora_enabled
:
bool
,
model_type
:
str
|
None
,
)
->
int
:
"""
Given layer hidden size and MoE configurations, round up hidden_size
if necessary.
Args:
hidden_size: Layer hidden-size
act_dtype: Data type of the layer activations.
moe_parallel_config: Fused MoE parallelization strategy configuration.
is_lora_enabled: True if the engine is enabled with LoRA. This
is used in the case of mxfp4 quantization in selecting the
MxFP4Backend.
model_type: for checking if gpt-oss
Return:
Rounded up hidden_size if rounding up is required based on the configs.
Original hidden size otherwise.
"""
from
vllm.model_executor.layers.fused_moe.all2all_utils
import
(
maybe_roundup_layer_hidden_size
,
)
hidden_size
=
maybe_roundup_layer_hidden_size
(
hidden_size
,
act_dtype
,
moe_parallel_config
)
return
hidden_size
# --8<-- [start:fused_moe]
@
CustomOp
.
register
(
"fused_moe"
)
class
FusedMoE
(
CustomOp
):
...
...
@@ -459,7 +423,7 @@ class FusedMoE(CustomOp):
),
"Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
assert
intermediate_size
%
self
.
tp_size
==
0
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
reduce_results
=
reduce_results
self
.
renormalize
=
renormalize
...
...
@@ -501,28 +465,13 @@ class FusedMoE(CustomOp):
)
self
.
routing_method_type
:
RoutingMethodType
=
self
.
router
.
routing_method_type
# Round up hidden size before creating moe_config.
# This way moe_config is created with the correct hidden_size from the start.
unpadded_hidden_size
=
hidden_size
self
.
model_type
=
(
self
.
vllm_config
.
model_config
.
hf_config
.
model_type
if
self
.
vllm_config
.
model_config
is
not
None
else
None
)
hidden_size
=
maybe_roundup_hidden_size
(
hidden_size
=
hidden_size
,
act_dtype
=
moe_in_dtype
,
moe_parallel_config
=
self
.
moe_parallel_config
,
is_lora_enabled
=
vllm_config
.
lora_config
is
not
None
,
model_type
=
self
.
model_type
,
)
self
.
hidden_size
=
hidden_size
self
.
moe_config
:
FusedMoEConfig
=
FusedMoEConfig
(
num_experts
=
self
.
global_num_experts
,
experts_per_token
=
top_k
,
hidden_dim
=
hidden_size
,
intermediate_size_per_partition
=
self
.
intermediate_size_per_partition
,
hidden_dim_unpadded
=
hidden_size
,
intermediate_size_per_partition
=
intermediate_size_per_partition
,
intermediate_size_per_partition_unpadded
=
intermediate_size_per_partition
,
num_local_experts
=
self
.
local_num_experts
,
num_logical_experts
=
self
.
logical_num_experts
,
moe_parallel_config
=
self
.
moe_parallel_config
,
...
...
@@ -567,13 +516,6 @@ class FusedMoE(CustomOp):
# for heuristic purposes, so it must be initialized first.
self
.
quant_method
:
FusedMoEMethodBase
=
_get_quant_method
()
# Quant methods (e.g. Mxfp4MoEMethod) may round up hidden_dim
# and intermediate_size in moe_config during __init__. Sync
# self.hidden_size so downstream consumers (e.g. LoRA) see the
# padded value.
if
self
.
moe_config
.
hidden_dim
!=
self
.
hidden_size
:
self
.
hidden_size
=
self
.
moe_config
.
hidden_dim
if
not
self
.
moe_config
.
is_act_and_mul
and
not
current_platform
.
is_cuda_alike
():
raise
NotImplementedError
(
"is_act_and_mul=False is supported only for CUDA and ROCm for now"
...
...
@@ -591,11 +533,24 @@ class FusedMoE(CustomOp):
f
"EPLB is not supported
{
self
.
quant_method
.
__class__
.
__name__
}
."
)
# Round up hidden size and update moe_config.
hidden_size
,
intermediate_size_per_partition
=
(
self
.
quant_method
.
maybe_roundup_sizes
(
hidden_size
,
intermediate_size_per_partition
,
moe_in_dtype
,
self
.
moe_parallel_config
,
)
)
self
.
moe_config
.
hidden_dim
=
hidden_size
self
.
moe_config
.
intermediate_size_per_partition
=
(
intermediate_size_per_partition
)
moe_quant_params
=
{
"num_experts"
:
self
.
local_num_experts
,
"hidden_size"
:
self
.
hidden_size
,
"unpadded_hidden_size"
:
unpadded_hidden_size
,
"intermediate_size_per_partition"
:
self
.
intermediate_size_per_partition
,
"hidden_size"
:
hidden_size
,
"intermediate_size_per_partition"
:
intermediate_size_per_partition
,
"params_dtype"
:
params_dtype
,
"weight_loader"
:
self
.
weight_loader
,
"global_num_experts"
:
self
.
global_num_experts
,
...
...
@@ -933,9 +888,17 @@ class FusedMoE(CustomOp):
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
# and we're not loading the full weight
if
not
load_full
and
loaded_weight
.
ndim
>
0
:
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
shard_size
*
tp_rank
,
shard_size
)
# Handle padding: loaded_weight might be smaller than shard_size on last
# TP rank
start_offset
=
shard_size
*
tp_rank
available
=
loaded_weight
.
shape
[
shard_dim
]
-
start_offset
if
available
<=
0
:
# If there is no available weight to load for this TP rank
# (can happen on last TP rank with padding), we can skip
# loading and return early
return
narrow_size
=
min
(
shard_size
,
available
)
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
start_offset
,
narrow_size
)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if
shard_id
==
"w1"
:
...
...
@@ -944,6 +907,13 @@ class FusedMoE(CustomOp):
else
:
assert
shard_id
==
"w3"
expert_data
=
expert_data
.
narrow
(
shard_dim
,
shard_size
,
shard_size
)
# Handle padding: if loaded_weight is smaller than expert_data (can happen
# on last TP shard with padding), copy to top-left corner
if
expert_data
.
shape
!=
loaded_weight
.
shape
:
expert_data
=
expert_data
[
:
loaded_weight
.
shape
[
0
],
:
loaded_weight
.
shape
[
1
]
]
expert_data
.
copy_
(
loaded_weight
)
def
_load_w2
(
...
...
@@ -961,10 +931,24 @@ class FusedMoE(CustomOp):
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
# and we're not loading the full weight
if
not
load_full
and
loaded_weight
.
ndim
>
0
:
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
shard_size
*
tp_rank
,
shard_size
)
# Handle padding: loaded_weight might be smaller than shard_size on last
# TP rank
start_offset
=
shard_size
*
tp_rank
available
=
loaded_weight
.
shape
[
shard_dim
]
-
start_offset
if
available
<=
0
:
# If there is no available weight to load for this TP rank
# (can happen on last TP rank with padding), we can skip
# loading and return early
return
narrow_size
=
min
(
shard_size
,
available
)
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
start_offset
,
narrow_size
)
# w2, down_proj: Load into only logical weight of w2.
# Handle padding: if loaded_weight is smaller than expert_data (can happen
# on last TP shard with padding), copy to top-left corner
if
expert_data
.
shape
!=
loaded_weight
.
shape
:
expert_data
=
expert_data
[
:
loaded_weight
.
shape
[
0
],
:
loaded_weight
.
shape
[
1
]
]
expert_data
.
copy_
(
loaded_weight
)
def
_load_single_value
(
...
...
@@ -1549,6 +1533,14 @@ class FusedMoE(CustomOp):
]
]
@
property
def
hidden_size
(
self
)
->
int
:
return
self
.
moe_config
.
hidden_dim
@
property
def
intermediate_size_per_partition
(
self
)
->
int
:
return
self
.
moe_config
.
intermediate_size_per_partition
def
extra_repr
(
self
)
->
str
:
s
=
(
f
"global_num_experts=
{
self
.
global_num_experts
}
, "
...
...
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
View file @
0ae89f18
...
...
@@ -779,8 +779,6 @@ def make_mxfp4_moe_quant_config(
w2_scale
:
Union
[
torch
.
Tensor
,
"PrecisionConfig"
],
w1_bias
:
torch
.
Tensor
|
None
=
None
,
w2_bias
:
torch
.
Tensor
|
None
=
None
,
hidden_pad
:
int
=
0
,
intermediate_pad
:
int
=
0
,
)
->
FusedMoEQuantConfig
|
None
:
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
if
mxfp4_backend
in
(
...
...
@@ -802,16 +800,12 @@ def make_mxfp4_moe_quant_config(
Mxfp4MoeBackend
.
FLASHINFER_CUTLASS_MXFP4_BF16
,
Mxfp4MoeBackend
.
CK
,
):
config
=
mxfp4_w4a16_moe_quant_config
(
return
mxfp4_w4a16_moe_quant_config
(
w1_bias
=
w1_bias
,
w2_bias
=
w2_bias
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
)
if
mxfp4_backend
==
Mxfp4MoeBackend
.
CK
:
config
.
hidden_pad
=
hidden_pad
config
.
intermediate_pad
=
intermediate_pad
return
config
else
:
return
ocp_mx_moe_quant_config
(
quant_dtype
=
"mxfp4"
,
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
0ae89f18
...
...
@@ -10,6 +10,7 @@ from vllm._aiter_ops import rocm_aiter_ops
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
...
...
@@ -186,6 +187,7 @@ def rocm_aiter_fused_experts(
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
moe_config
:
FusedMoEConfig
,
activation
:
MoEActivation
=
MoEActivation
.
SILU
,
apply_router_weight_on_input
:
bool
=
False
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -276,6 +278,17 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
# Compute padding on-the-fly for CK MXFP4 kernels
hidden_pad
=
0
intermediate_pad
=
0
assert
moe_config
.
hidden_dim_unpadded
is
not
None
assert
moe_config
.
intermediate_size_per_partition_unpadded
is
not
None
hidden_pad
=
hidden_states
.
shape
[
1
]
-
moe_config
.
hidden_dim_unpadded
intermediate_pad
=
(
moe_config
.
intermediate_size_per_partition
-
moe_config
.
intermediate_size_per_partition_unpadded
)
return
rocm_aiter_ops
.
fused_moe
(
hidden_states
,
w1
,
...
...
@@ -292,8 +305,8 @@ def rocm_aiter_fused_experts(
doweight_stage1
=
apply_router_weight_on_input
,
num_local_tokens
=
num_local_tokens
,
output_dtype
=
output_dtype
,
hidden_pad
=
quant_config
.
hidden_pad
,
intermediate_pad
=
quant_config
.
intermediate_pad
,
hidden_pad
=
hidden_pad
,
intermediate_pad
=
intermediate_pad
,
bias1
=
quant_config
.
w1_bias
if
quant_config
.
use_mxfp4_w4a16
else
None
,
bias2
=
quant_config
.
w2_bias
if
quant_config
.
use_mxfp4_w4a16
else
None
,
)
...
...
@@ -419,6 +432,7 @@ class AiterExperts(mk.FusedMoEExpertsModular):
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
quant_config
=
self
.
quant_config
,
moe_config
=
self
.
moe_config
,
a1q_scale
=
a1q_scale
,
num_local_tokens
=
num_local_tokens
,
output_dtype
=
output
.
dtype
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
0ae89f18
...
...
@@ -715,8 +715,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
...
...
@@ -2274,8 +2272,6 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
0ae89f18
...
...
@@ -672,8 +672,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
...
...
@@ -1011,8 +1009,6 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
0ae89f18
...
...
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (
)
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.oracle.mxfp4
import
(
...
...
@@ -107,18 +108,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self
.
_cache_permute_indices
:
dict
[
torch
.
Size
,
torch
.
Tensor
]
=
{}
self
.
moe_kernel
:
mk
.
FusedMoEKernel
|
None
=
None
# Round up dims once based on backend. This mutates the shared
# FusedMoEConfig in-place so that create_weights() and all
# downstream code see the padded dimensions. This must happen
# before create_weights() is called.
self
.
moe
.
hidden_dim
,
self
.
moe
.
intermediate_size_per_partition
=
(
mxfp4_round_up_hidden_size_and_intermediate_size
(
self
.
mxfp4_backend
,
self
.
moe
.
hidden_dim
,
self
.
moe
.
intermediate_size_per_partition
,
)
)
# Used for triton kernel precision configs
self
.
w13_precision_config
=
None
self
.
w2_precision_config
=
None
...
...
@@ -129,6 +118,23 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# so can skip the padding in the forward before applying the moe method
return
self
.
mxfp4_backend
==
Mxfp4MoeBackend
.
FLASHINFER_TRTLLM_MXFP4_MXFP8
def
maybe_roundup_sizes
(
self
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
act_dtype
:
torch
.
dtype
,
moe_parallel_config
:
FusedMoEParallelConfig
,
)
->
tuple
[
int
,
int
]:
hidden_size
,
intermediate_size_per_partition
=
super
().
maybe_roundup_sizes
(
hidden_size
=
hidden_size
,
intermediate_size_per_partition
=
intermediate_size_per_partition
,
act_dtype
=
act_dtype
,
moe_parallel_config
=
moe_parallel_config
,
)
return
mxfp4_round_up_hidden_size_and_intermediate_size
(
self
.
mxfp4_backend
,
hidden_size
,
intermediate_size_per_partition
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -143,32 +149,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
scale_dtype
=
torch
.
uint8
mxfp4_block
=
32
# Use pre-rounded sizes from config
self
.
intermediate_size
=
intermediate_size_per_partition_after_pad
=
(
self
.
moe
.
intermediate_size_per_partition
)
self
.
hidden_size
=
hidden_size
=
self
.
moe
.
hidden_dim
# Expose padded dimensions on the layer for LoRA and Marlin code
# that reads layer.hidden_size / layer.intermediate_size_per_partition.
layer
.
params_dtype
=
params_dtype
layer
.
num_experts
=
num_experts
layer
.
hidden_size
=
hidden_size
layer
.
intermediate_size_per_partition
=
(
intermediate_size_per_partition_after_pad
)
# CK (gfx950) padding info for rocm_aiter_ops.fused_moe()
self
.
hidden_pad
=
extra_weight_attrs
.
get
(
"hidden_pad"
,
0
)
self
.
intermediate_pad
=
(
intermediate_size_per_partition_after_pad
-
intermediate_size_per_partition
)
self
.
intermediate_size
=
intermediate_size_per_partition
self
.
hidden_size
=
hidden_size
# Fused gate_up_proj (column parallel)
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
_after_pad
,
2
*
intermediate_size_per_partition
,
hidden_size
//
2
,
dtype
=
weight_dtype
,
),
...
...
@@ -180,7 +170,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
_after_pad
,
2
*
intermediate_size_per_partition
,
hidden_size
//
mxfp4_block
,
dtype
=
scale_dtype
,
),
...
...
@@ -194,7 +184,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
torch
.
zeros
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
_after_pad
//
2
,
intermediate_size_per_partition
//
2
,
dtype
=
weight_dtype
,
),
requires_grad
=
False
,
...
...
@@ -206,7 +196,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
torch
.
zeros
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
_after_pad
//
mxfp4_block
,
intermediate_size_per_partition
//
mxfp4_block
,
dtype
=
scale_dtype
,
),
requires_grad
=
False
,
...
...
@@ -218,7 +208,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w13_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
_after_pad
,
2
*
intermediate_size_per_partition
,
dtype
=
torch
.
bfloat16
,
),
requires_grad
=
False
,
...
...
@@ -368,8 +358,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w2_scale
=
w2_scale
,
w1_bias
=
w1_bias
,
w2_bias
=
w2_bias
,
hidden_pad
=
self
.
hidden_pad
,
intermediate_pad
=
self
.
intermediate_pad
,
)
def
select_gemm_impl
(
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
0ae89f18
...
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe import (
MoEActivation
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
mxfp4_w4a8_moe_quant_config
,
...
...
@@ -27,13 +28,13 @@ from vllm.model_executor.layers.fused_moe.config import (
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
fused_marlin_moe
from
vllm.model_executor.layers.fused_moe.oracle.mxfp4
import
(
Mxfp4MoeBackend
,
mxfp4_round_up_hidden_size_and_intermediate_size
,
select_mxfp4_moe_backend
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_fp8_moe_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
(
CK_MXFP4_MOE_DIM_ALIGNMENT
,
_swizzle_mxfp4
,
)
from
vllm.model_executor.layers.quantization.utils.ocp_mx_utils
import
(
...
...
@@ -49,7 +50,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.utils.math_utils
import
round_up
logger
=
init_logger
(
__name__
)
...
...
@@ -173,8 +173,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
...
...
@@ -182,7 +180,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
...
...
@@ -194,7 +192,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
torch
.
zeros
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
...
...
@@ -461,6 +459,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
quant_config
=
self
.
moe_quant_config
,
moe_config
=
layer
.
moe_config
,
expert_map
=
layer
.
expert_map
,
)
elif
self
.
use_marlin
:
...
...
@@ -527,7 +526,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
):
params_dtype
=
torch
.
uint32
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
//
8
,
# INT32 packing for W4
...
...
@@ -536,7 +535,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
requires_grad
=
False
,
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
torch
.
zeros
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
//
8
,
# INT32 packing for W4
...
...
@@ -649,6 +648,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
quant_config
=
self
.
moe_quant_config
,
moe_config
=
layer
.
moe_config
,
expert_map
=
layer
.
expert_map
,
)
...
...
@@ -702,6 +702,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self
.
mxfp4_backend
:
Mxfp4MoeBackend
|
None
=
None
if
self
.
ocp_mx_scheme
==
"w_mxfp4"
:
self
.
mxfp4_backend
,
_
=
select_mxfp4_moe_backend
(
moe
)
elif
self
.
ocp_mx_scheme
.
startswith
(
"w_mxfp4"
):
# TODO(bowenbao): refactor and introduce backends for other OCP MX schemes.
self
.
mxfp4_backend
=
Mxfp4MoeBackend
.
NONE
if
self
.
input_quant
is
not
None
:
self
.
static_input_scales
=
not
self
.
input_quant
.
get
(
"is_dynamic"
)
...
...
@@ -734,36 +737,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self
.
emulate
=
(
not
current_platform
.
supports_mx
()
or
not
self
.
ocp_mx_scheme
.
startswith
(
"w_mxfp4"
)
)
and
(
self
.
mxfp4_backend
is
None
or
not
self
.
use_rocm_aiter_moe
)
# CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension
# alignment requirements. When violated (e.g. MiniMax-M2.1 with
# TP=4 yields intermediate_size_per_partition=384), AITER raises:
# "device_gemm ... does not support this GEMM problem".
# Fall back to emulation in that case.
# For gpt_oss models, create_weights rounds up the dimensions
# internally, so the alignment check is skipped.
if
(
not
self
.
emulate
and
self
.
use_rocm_aiter_moe
and
self
.
ocp_mx_scheme
is
not
None
and
self
.
ocp_mx_scheme
.
startswith
(
"w_mxfp4"
)
and
self
.
model_type
!=
"gpt_oss"
and
moe
.
intermediate_size_per_partition
%
CK_MXFP4_MOE_DIM_ALIGNMENT
!=
0
):
logger
.
warning_once
(
"AITER CK MXFP4 MoE GEMM does not support "
"intermediate_size_per_partition=%d (not a multiple of %d). "
"This typically happens when intermediate_size / "
"tensor_parallel_size produces an incompatible dimension. "
"Falling back to emulation mode. To avoid this overhead, "
"use a compatible tensor_parallel_size or set "
"VLLM_ROCM_USE_AITER_MOE=0."
,
moe
.
intermediate_size_per_partition
,
CK_MXFP4_MOE_DIM_ALIGNMENT
,
)
self
.
use_rocm_aiter_moe
=
False
self
.
emulate
=
True
)
and
(
self
.
mxfp4_backend
is
None
or
self
.
mxfp4_backend
is
Mxfp4MoeBackend
.
NONE
or
not
self
.
use_rocm_aiter_moe
)
if
self
.
emulate
:
logger
.
warning_once
(
...
...
@@ -780,6 +758,27 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
"The current mode supports native MoE MXFP4 computation"
)
def
maybe_roundup_sizes
(
self
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
act_dtype
:
torch
.
dtype
,
moe_parallel_config
:
FusedMoEParallelConfig
,
)
->
tuple
[
int
,
int
]:
hidden_size
,
intermediate_size_per_partition
=
super
().
maybe_roundup_sizes
(
hidden_size
=
hidden_size
,
intermediate_size_per_partition
=
intermediate_size_per_partition
,
act_dtype
=
act_dtype
,
moe_parallel_config
=
moe_parallel_config
,
)
if
self
.
mxfp4_backend
is
not
None
:
hidden_size
,
intermediate_size_per_partition
=
(
mxfp4_round_up_hidden_size_and_intermediate_size
(
self
.
mxfp4_backend
,
hidden_size
,
intermediate_size_per_partition
)
)
return
hidden_size
,
intermediate_size_per_partition
def
get_packed_dim
(
self
,
dim
:
int
,
quant_dtype
:
str
):
if
quant_dtype
==
"mxfp4"
:
assert
dim
%
2
==
0
...
...
@@ -805,40 +804,12 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
params_dtype
=
torch
.
uint8
self
.
intermediate_size_per_partition
=
intermediate_size_per_partition
if
self
.
model_type
==
"gpt_oss"
:
if
current_platform
.
is_rocm
():
intermediate_size_per_partition_after_pad
=
round_up
(
intermediate_size_per_partition
,
256
)
else
:
intermediate_size_per_partition_after_pad
=
round_up
(
intermediate_size_per_partition
,
64
)
else
:
intermediate_size_per_partition_after_pad
=
intermediate_size_per_partition
self
.
unpadded_hidden_size
=
extra_weight_attrs
.
get
(
"unpadded_hidden_size"
,
hidden_size
)
# On GFX950, the GFX950MXScaleLayout swizzle requires
# hidden_size to be a multiple of 256 (SCALE_K = hidden_size / 32
# must be divisible by 8). Pad hidden_size for weight/scale
# allocation; the original value is preserved in unpadded_hidden_size.
# Only applies to the native (non-emulated) CK path on GFX950.
if
(
self
.
model_type
==
"gpt_oss"
and
current_platform
.
is_rocm
()
and
not
self
.
emulate
):
hidden_size
=
round_up
(
hidden_size
,
256
)
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
_after_pad
,
2
*
intermediate_size_per_partition
,
self
.
get_packed_dim
(
hidden_size
,
self
.
weight_dtype
),
dtype
=
params_dtype
,
),
...
...
@@ -849,12 +820,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
torch
.
zeros
(
num_experts
,
hidden_size
,
self
.
get_packed_dim
(
intermediate_size_per_partition_after_pad
,
self
.
weight_dtype
),
self
.
get_packed_dim
(
intermediate_size_per_partition
,
self
.
weight_dtype
),
dtype
=
params_dtype
,
),
requires_grad
=
False
,
...
...
@@ -867,7 +836,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size_per_partition
_after_pad
,
2
*
intermediate_size_per_partition
,
hidden_size
//
OCP_MX_BLOCK_SIZE
,
dtype
=
params_dtype
,
),
...
...
@@ -877,7 +846,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
torch
.
ones
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
_after_pad
//
OCP_MX_BLOCK_SIZE
,
intermediate_size_per_partition
//
OCP_MX_BLOCK_SIZE
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
...
...
@@ -892,7 +861,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
w13_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
_after_pad
,
2
*
intermediate_size_per_partition
,
dtype
=
torch
.
float32
,
),
requires_grad
=
False
,
...
...
@@ -1072,6 +1041,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_ids
=
topk_ids
,
activation
=
layer
.
activation
,
quant_config
=
self
.
moe_quant_config
,
moe_config
=
layer
.
moe_config
,
expert_map
=
layer
.
expert_map
,
)
else
:
...
...
@@ -1204,6 +1174,8 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
triton_kernel_moe_forward
,
)
assert
self
.
moe
.
hidden_dim_unpadded
is
not
None
assert
self
.
moe
.
intermediate_size_per_partition_unpadded
is
not
None
return
triton_kernel_moe_forward
(
hidden_states
=
x
,
w1
=
self
.
w13_weight_triton_tensor
,
...
...
@@ -1215,8 +1187,8 @@ class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
expert_map
=
expert_map
,
quant_config
=
self
.
moe_quant_config
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
unpadded_N_w1
=
self
.
intermediate_size_per_partition
*
2
,
unpadded_K_w1
=
self
.
unpadded_hidden_size
,
unpadded_N_w2
=
self
.
unpadded_hidden_size
,
unpadded_K_w2
=
self
.
intermediate_size_per_partition
,
unpadded_N_w1
=
self
.
moe
.
intermediate_size_per_partition
_unpadded
*
2
,
unpadded_K_w1
=
self
.
moe
.
hidden_dim_unpadded
,
unpadded_N_w2
=
self
.
moe
.
hidden_dim_unpadded
,
unpadded_K_w2
=
self
.
moe
.
intermediate_size_per_partition
_unpadded
,
)
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
View file @
0ae89f18
...
...
@@ -254,7 +254,6 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
w13
,
w13_scale
,
w2
,
w2_scale
,
is_act_and_mul
,
min_alignment
)
)
layer
.
intermediate_size_per_partition
=
padded_intermediate
layer
.
moe_config
.
intermediate_size_per_partition
=
padded_intermediate
w13
,
w13_scale
,
w2
,
w2_scale
=
prepare_static_weights_for_trtllm_fp4_moe
(
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
View file @
0ae89f18
...
...
@@ -439,7 +439,6 @@ def prepare_fp8_moe_layer_for_fi(
layer
.
moe_config
.
is_act_and_mul
,
min_alignment
,
)
layer
.
intermediate_size_per_partition
=
new_intermediate
layer
.
moe_config
.
intermediate_size_per_partition
=
new_intermediate
# FI kernels require W31 layout rather than W13.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment