Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
938772af
Unverified
Commit
938772af
authored
Nov 04, 2025
by
bnellnm
Committed by
GitHub
Nov 04, 2025
Browse files
[Kernels] Isolate modular kernel code from FusedMoEMethodBase subclasses. (#27123)
parent
e4ee6586
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
271 additions
and
311 deletions
+271
-311
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+2
-2
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+199
-62
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+6
-0
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+0
-2
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+1
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+0
-47
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+0
-2
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+9
-26
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+0
-2
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+0
-2
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+11
-39
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+0
-2
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+15
-90
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+25
-28
vllm/model_executor/layers/quantization/rtn.py
vllm/model_executor/layers/quantization/rtn.py
+0
-2
vllm/model_executor/warmup/deep_gemm_warmup.py
vllm/model_executor/warmup/deep_gemm_warmup.py
+3
-3
No files found.
vllm/distributed/device_communicators/base_device_communicator.py
View file @
938772af
...
...
@@ -266,14 +266,14 @@ class DeviceCommunicatorBase:
module
for
module
in
model
.
modules
()
# TODO(bnell): Should use isinstance but can't. Maybe search for
# presence of quant_method.
init_prepare_finalize
?
# presence of quant_method.
maybe_init_modular_kernel
?
if
(
module
.
__class__
.
__name__
==
"FusedMoE"
or
module
.
__class__
.
__name__
==
"SharedFusedMoE"
)
]
for
module
in
moe_modules
:
module
.
quant_method
.
init_prepare_finalize
(
module
)
module
.
maybe_init_modular_kernel
(
)
def
dispatch
(
self
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
938772af
...
...
@@ -117,10 +117,8 @@ class FusedMoeWeightScaleSupported(Enum):
class
FusedMoEMethodBase
(
QuantizeMethodBase
):
def
__init__
(
self
,
moe
:
FusedMoEConfig
):
super
().
__init__
()
self
.
moe
=
moe
self
.
moe
:
FusedMoEConfig
=
moe
self
.
moe_quant_config
:
FusedMoEQuantConfig
|
None
=
None
self
.
fused_experts
:
FusedMoEModularKernel
|
None
=
None
self
.
topk_indices_dtype
=
None
@
abstractmethod
def
create_weights
(
...
...
@@ -245,9 +243,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
else
:
return
None
# Note: init_prepar
e_
f
in
alize should only be called by
# prepare_communication_buffer_for_model.
def
init_prepare_finalize
(
self
,
layer
:
torch
.
nn
.
Module
)
:
def
mayb
e_in
it_modular_kernel
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEModularKernel
|
None
:
assert
self
.
moe
is
not
None
# We must get the quant config here so that the layer is
...
...
@@ -261,17 +259,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
logger
.
debug
(
"%s for %s(%s)"
,
prepare_finalize
.
__class__
.
__name__
,
self
,
id
(
self
)
)
assert
self
.
topk_indices_dtype
is
None
assert
self
.
fused_experts
is
None
,
(
f
"Attempt to override experts for
{
id
(
self
)
}
!"
)
self
.
topk_indices_dtype
=
prepare_finalize
.
topk_indices_dtype
()
experts
=
self
.
select_gemm_impl
(
prepare_finalize
,
layer
)
self
.
fused_experts
=
FusedMoEModularKernel
(
return
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
layer
.
shared_experts
,
)
else
:
return
None
def
select_gemm_impl
(
self
,
...
...
@@ -292,8 +287,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
@
property
def
using_modular_kernel
(
self
)
->
bool
:
return
self
.
fused_experts
is
not
None
def
topk_indices_dtype
(
self
)
->
torch
.
dtype
|
None
:
return
None
@
property
def
supports_eplb
(
self
)
->
bool
:
return
False
@
property
def
allow_inplace
(
self
)
->
bool
:
return
False
@
abstractmethod
def
apply
(
...
...
@@ -322,6 +325,138 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
@
CustomOp
.
register
(
"modular_fused_moe"
)
class
FusedMoEModularMethod
(
FusedMoEMethodBase
,
CustomOp
):
def
__init__
(
self
,
old_quant_method
:
FusedMoEMethodBase
,
fused_experts
:
FusedMoEModularKernel
,
):
super
().
__init__
(
old_quant_method
.
moe
)
# Find better way to copy attributes? Should we even copy attributes?
# self.__dict__.update(old_quant_method.__dict__)
self
.
moe_quant_config
=
old_quant_method
.
moe_quant_config
self
.
fused_experts
=
fused_experts
self
.
disable_expert_map
=
getattr
(
old_quant_method
,
"disable_expert_map"
,
not
fused_experts
.
supports_expert_map
(),
)
self
.
old_quant_method
=
old_quant_method
logger
.
debug
(
"Swapping out %s"
,
self
.
old_quant_method
.
__class__
.
__name__
)
@
property
def
topk_indices_dtype
(
self
)
->
torch
.
dtype
|
None
:
return
self
.
fused_experts
.
prepare_finalize
.
topk_indices_dtype
()
@
property
def
supports_eplb
(
self
)
->
bool
:
return
self
.
old_quant_method
.
supports_eplb
@
property
def
allow_inplace
(
self
)
->
bool
:
return
self
.
old_quant_method
.
allow_inplace
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
raise
NotImplementedError
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
return
self
.
moe_quant_config
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Is getattr needed?
zero_expert_num
=
getattr
(
layer
,
"zero_expert_num"
,
0
)
zero_expert_type
=
getattr
(
layer
,
"zero_expert_type"
,
None
)
if
enable_eplb
:
if
self
.
supports_eplb
:
assert
expert_load_view
is
not
None
assert
logical_to_physical_map
is
not
None
assert
logical_replica_count
is
not
None
assert
isinstance
(
layer
,
FusedMoE
)
else
:
raise
NotImplementedError
(
"EPLB is not supported for "
f
"
{
self
.
old_quant_method
.
__class__
.
__name__
}
."
)
topk_weights
,
topk_ids
,
zero_expert_result
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
,
enable_eplb
=
enable_eplb
,
expert_map
=
expert_map
,
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
global_num_experts
=
global_num_experts
,
zero_expert_num
=
zero_expert_num
,
zero_expert_type
=
zero_expert_type
,
)
result
=
self
.
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
self
.
allow_inplace
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
None
if
self
.
disable_expert_map
else
expert_map
,
)
if
zero_expert_num
!=
0
and
zero_expert_type
is
not
None
:
assert
not
isinstance
(
result
,
tuple
),
(
"Shared + zero experts are mutually exclusive not yet supported"
)
return
result
,
zero_expert_result
else
:
return
result
@
CustomOp
.
register
(
"unquantized_fused_moe"
)
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
"""MoE method without quantization."""
...
...
@@ -378,6 +513,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
self
.
flashinfer_cutlass_moe
=
None
# type: ignore
@
property
def
supports_eplb
(
self
)
->
bool
:
return
True
@
property
def
allow_inplace
(
self
)
->
bool
:
return
True
def
maybe_make_prepare_finalize
(
self
)
->
FusedMoEPrepareAndFinalize
|
None
:
if
self
.
rocm_aiter_moe_enabled
:
return
None
...
...
@@ -650,7 +793,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
if
self
.
rocm_aiter_moe_enabled
:
assert
self
.
fused_experts
is
None
result
=
self
.
rocm_aiter_fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
...
...
@@ -671,21 +813,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
elif
self
.
fused_experts
is
not
None
:
result
=
self
.
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
)
else
:
assert
fused_experts
is
not
None
result
=
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
...
...
@@ -1267,7 +1395,7 @@ class FusedMoE(CustomOp):
"Only softmax scoring function is supported for non-grouped topk."
)
moe
=
FusedMoEConfig
(
self
.
moe_config
:
FusedMoEConfig
=
FusedMoEConfig
(
num_experts
=
self
.
global_num_experts
,
experts_per_token
=
top_k
,
hidden_dim
=
hidden_size
,
...
...
@@ -1279,24 +1407,26 @@ class FusedMoE(CustomOp):
is_act_and_mul
=
is_act_and_mul
,
is_lora_enabled
=
vllm_config
.
lora_config
is
not
None
,
)
self
.
moe_config
:
FusedMoEConfig
=
moe
self
.
moe_quant_config
:
FusedMoEQuantConfig
|
None
=
None
self
.
quant_config
=
quant_config
# Note:
get_quant_method
will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method
:
QuantizeMethodBase
|
None
=
None
quant_method
=
(
UnquantizedFusedMoEMethod
(
moe
)
if
quant_
config
is
None
else
quant_config
.
get_quant_method
(
self
,
prefix
)
)
def
_
get_quant_method
()
->
FusedMoEMethodBase
:
"""
Helper method to ensure self.quant_method is never
None
and
of the proper type.
"""
quant_
method
=
None
if
self
.
quant_config
is
not
None
:
quant_method
=
self
.
quant_config
.
get_quant_method
(
self
,
prefix
)
if
quant_method
is
None
:
quant_method
=
UnquantizedFusedMoEMethod
(
moe
)
assert
quant_method
is
not
None
quant_method
=
UnquantizedFusedMoEMethod
(
self
.
moe_config
)
assert
isinstance
(
quant_method
,
FusedMoEMethodBase
)
self
.
quant_method
=
quant_method
return
quant_method
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
self
.
quant_method
:
FusedMoEMethodBase
=
_get_quant_method
()
if
not
self
.
moe_config
.
is_act_and_mul
:
# Avoid circular import
...
...
@@ -1305,7 +1435,7 @@ class FusedMoE(CustomOp):
)
if
not
isinstance
(
quant_method
,
(
UnquantizedFusedMoEMethod
,
ModelOptFp8MoEMethod
)
self
.
quant_method
,
(
UnquantizedFusedMoEMethod
,
ModelOptFp8MoEMethod
)
):
raise
NotImplementedError
(
"is_act_and_mul=False is supported only for unquantized "
...
...
@@ -1316,10 +1446,7 @@ class FusedMoE(CustomOp):
"is_act_and_mul=False is supported only for CUDA for now"
)
if
self
.
enable_eplb
:
from
vllm.model_executor.layers.quantization.fp8
import
Fp8MoEMethod
if
not
isinstance
(
quant_method
,
(
Fp8MoEMethod
,
UnquantizedFusedMoEMethod
)):
if
self
.
enable_eplb
and
not
self
.
quant_method
.
supports_eplb
:
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
...
...
@@ -1328,6 +1455,7 @@ class FusedMoE(CustomOp):
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise
NotImplementedError
(
f
"EPLB is not supported
{
self
.
quant_method
.
__class__
.
__name__
}
. "
"EPLB is only supported for FP8 quantization for now."
)
...
...
@@ -1353,6 +1481,15 @@ class FusedMoE(CustomOp):
self
.
batched_hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
batched_router_logits
:
torch
.
Tensor
|
None
=
None
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it
# should be safe to swap out the quant_method.
def
maybe_init_modular_kernel
(
self
)
->
None
:
mk
=
self
.
quant_method
.
maybe_init_modular_kernel
(
self
)
if
mk
is
not
None
:
self
.
quant_method
=
FusedMoEModularMethod
(
self
.
quant_method
,
mk
)
@
property
def
shared_experts
(
self
)
->
torch
.
nn
.
Module
|
None
:
return
None
...
...
@@ -2167,7 +2304,7 @@ class FusedMoE(CustomOp):
"""
assert
self
.
quant_method
is
not
None
return
(
self
.
quant_method
.
f
used
_experts
is
not
None
isinstance
(
self
.
quant_method
,
F
used
MoEModularMethod
)
and
self
.
quant_method
.
fused_experts
.
output_is_reduced
()
)
...
...
@@ -2403,7 +2540,7 @@ class FusedMoE(CustomOp):
self
.
ensure_dp_chunking_init
()
has_separate_shared_experts
=
(
not
isinstance
(
self
.
quant_method
.
fused_experts
,
FusedMoEModular
Kernel
)
not
isinstance
(
self
.
quant_method
,
FusedMoEModular
Method
)
and
self
.
shared_experts
is
not
None
)
...
...
@@ -2430,8 +2567,8 @@ class FusedMoE(CustomOp):
hidden_states
,
router_logits
,
has_separate_shared_experts
)
do_naive_dispatch_combine
:
bool
=
(
self
.
dp_size
>
1
and
not
self
.
quant_method
.
using_modular_kernel
do_naive_dispatch_combine
:
bool
=
self
.
dp_size
>
1
and
not
isinstance
(
self
.
quant_method
,
FusedMoEModularMethod
)
# If there are shared experts but we are not using a modular kernel, the
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
938772af
...
...
@@ -707,6 +707,12 @@ class FusedMoEModularKernel(torch.nn.Module):
f
"
{
fused_experts
.
activation_formats
[
0
]
}
"
)
def
supports_expert_map
(
self
)
->
bool
:
"""
A flag indicating whether or not this class supports expert maps.
"""
return
self
.
fused_experts
.
supports_expert_map
()
def
output_is_reduced
(
self
)
->
bool
:
"""
Indicates whether or not the output of fused MoE kernel
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
938772af
...
...
@@ -617,8 +617,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `AWQMoEMethod` yet."
)
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
938772af
...
...
@@ -518,12 +518,11 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `BitsAndBytesMoEMethod` yet."
)
topk_weights
,
topk_ids
,
_
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
938772af
...
...
@@ -462,12 +462,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
indices_type
=
self
.
topk_indices_dtype
,
)
#
# Note: the order here is important. self.fused_experts can override
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin.
#
if
self
.
use_marlin
:
assert
self
.
fused_experts
is
None
return
fused_marlin_moe
(
x
,
layer
.
w13_weight
,
...
...
@@ -488,24 +483,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
workspace
=
layer
.
workspace
,
)
elif
self
.
fused_experts
is
not
None
:
assert
is_valid_flashinfer_cutlass_fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
),
"Flashinfer CUTLASS Fused MoE not applicable!"
return
self
.
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
False
,
# TODO(shuw): fix later, now output is high prec
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
# FlashInfer fused experts path
elif
self
.
allow_flashinfer
:
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
# noqa: E501
...
...
@@ -1066,13 +1043,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
per_act_token
=
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
per_channel_quant
=
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
#
# Note: the order here is important. self.fused_experts can override
# cutlass fp8 or fused_experts but not marlin or rocm.
#
if
self
.
use_marlin
:
assert
activation
==
"silu"
,
f
"
{
activation
}
not supported for Marlin MoE."
assert
self
.
fused_experts
is
None
return
fused_marlin_moe
(
x
,
layer
.
w13_weight
,
...
...
@@ -1098,7 +1070,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
assert
per_act_token
==
per_channel_quant
assert
self
.
moe_quant_config
is
not
None
assert
self
.
fused_experts
is
None
return
rocm_aiter_fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
...
...
@@ -1111,18 +1082,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
quant_config
=
self
.
moe_quant_config
,
)
elif
self
.
fused_experts
is
not
None
:
return
self
.
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
None
if
self
.
disable_expert_map
else
expert_map
,
)
# cutlass path
elif
self
.
use_cutlass
:
assert
self
.
moe_quant_config
is
not
None
...
...
@@ -1318,8 +1277,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet."
...
...
@@ -1636,8 +1593,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet."
...
...
@@ -1901,8 +1856,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet."
...
...
vllm/model_executor/layers/quantization/experts_int8.py
View file @
938772af
...
...
@@ -158,8 +158,6 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `ExpertsInt8MoEMethod` yet."
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
938772af
...
...
@@ -703,9 +703,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self
.
quant_config
=
quant_config
self
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
self
.
block_quant
:
bool
=
self
.
weight_block_size
is
not
None
self
.
fused_experts
:
mk
.
FusedMoEModularKernel
|
None
=
None
# type: ignore
self
.
fp8_backend
=
get_fp8_moe_backend
(
self
.
block_quant
)
self
.
use_marlin
=
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
...
...
@@ -1181,6 +1178,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
block_shape
=
self
.
weight_block_size
,
)
@
property
def
supports_eplb
(
self
)
->
bool
:
return
True
@
property
def
allow_inplace
(
self
)
->
bool
:
return
True
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -1210,10 +1215,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert
logical_replica_count
is
not
None
assert
isinstance
(
layer
,
FusedMoE
)
if
(
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
and
self
.
fused_experts
is
None
):
if
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
assert
activation
==
"silu"
,
(
f
"Expected 'silu' activation but got
{
activation
}
"
)
...
...
@@ -1290,10 +1292,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_fused_shared_experts
=
layer
.
num_fused_shared_experts
,
)
#
# Note: the order of checks is important since self.fused_experts
# can override fused_experts or cutlass but not rocm or marlin.
#
topk_weights
,
topk_ids
,
zero_expert_result
=
select_result
if
self
.
rocm_aiter_moe_enabled
:
...
...
@@ -1301,7 +1299,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
rocm_aiter_fused_experts
,
)
assert
self
.
fused_experts
is
None
result
=
rocm_aiter_fused_experts
(
x
,
layer
.
w13_weight
,
...
...
@@ -1315,7 +1312,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
elif
self
.
use_marlin
:
assert
activation
==
"silu"
,
f
"
{
activation
}
not supported for Marlin MoE."
assert
self
.
fused_experts
is
None
result
=
fused_marlin_moe
(
x
,
layer
.
w13_weight
,
...
...
@@ -1333,19 +1329,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_map
=
expert_map
,
workspace
=
layer
.
workspace
,
)
elif
self
.
fused_experts
:
result
=
self
.
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
)
elif
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
CUTLASS
:
assert
not
self
.
block_quant
assert
not
renormalize
and
custom_routing_function
is
not
None
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
938772af
...
...
@@ -585,8 +585,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `GGUFMoEMethod` yet."
)
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
938772af
...
...
@@ -742,8 +742,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `GPTQMarlinMoEMethod` yet."
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
938772af
...
...
@@ -18,9 +18,6 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config
,
nvfp4_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
is_valid_flashinfer_cutlass_fused_moe
,
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
fused_marlin_moe
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
...
...
@@ -605,7 +602,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
if
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
assert
self
.
fused_experts
is
None
assert
activation
==
"silu"
,
(
f
"Expected 'silu' activation but got
{
activation
}
"
)
...
...
@@ -638,24 +634,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
indices_type
=
self
.
topk_indices_dtype
,
)
#
# Note: the order here is important. self.fused_experts can override
# cutlass or fused_experts.
#
if
self
.
fused_experts
is
not
None
:
return
self
.
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
inplace
=
False
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
elif
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
CUTLASS
:
if
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
CUTLASS
:
assert
not
renormalize
assert
activation
==
"silu"
,
(
f
"Expected 'silu' activation but got
{
activation
}
"
...
...
@@ -1647,8 +1626,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
from
vllm.model_executor.models.llama4
import
Llama4MoE
assert
self
.
fused_experts
is
None
a1_gscale
=
layer
.
w13_input_scale_quant
(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
)
=
(
flashinfer
.
fp4_quantize
(
...
...
@@ -1720,13 +1697,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
indices_type
=
self
.
topk_indices_dtype
,
)
#
# Note: the order here is important. self.fused_experts can override
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
# trtllm.
#
if
self
.
use_marlin
:
assert
self
.
fused_experts
is
None
return
fused_marlin_moe
(
x
,
layer
.
w13_weight
,
...
...
@@ -1747,23 +1718,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
workspace
=
layer
.
workspace
,
)
elif
self
.
fused_experts
is
not
None
:
assert
(
elif
(
self
.
allow_flashinfer
and
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
CUTLASS
):
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
# noqa: E501
flashinfer_cutlass_moe_fp4
,
)
assert
is_valid_flashinfer_cutlass_fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
),
"Flashinfer CUTLASS Fused MoE not applicable!"
assert
self
.
moe_quant_config
is
not
None
return
self
.
fused_experts
(
return
flashinfer_cutlass_moe_fp4
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
False
,
# TODO(shuw): fix later, now output is high prec
quant_config
=
self
.
moe_quant_config
,
inplace
=
False
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
938772af
...
...
@@ -226,7 +226,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
self
.
moe
=
layer
layer
.
quant_config
=
self
.
quant_config
bit8_pack_factor
=
self
.
quant_config
.
bit8_pack_factor
group_size
=
self
.
quant_config
.
group_size
...
...
@@ -381,7 +380,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `MoeWNA16Method` yet."
)
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
938772af
...
...
@@ -197,8 +197,6 @@ class Mxfp4Config(QuantizationConfig):
class
Mxfp4MoEMethod
(
FusedMoEMethodBase
):
def
__init__
(
self
,
moe
:
FusedMoEConfig
):
super
().
__init__
(
moe
)
self
.
topk_indices_dtype
=
None
self
.
moe
=
moe
self
.
mxfp4_backend
=
get_mxfp4_backend
(
moe
.
is_lora_enabled
)
self
.
max_capture_size
=
(
get_current_vllm_config
().
compilation_config
.
max_cudagraph_capture_size
...
...
@@ -815,6 +813,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"EP batched experts format"
)
else
:
layer
.
w13_weight
=
(
self
.
w13_weight_triton_tensor
if
layer
.
w13_weight
is
None
else
layer
.
w13_weight
)
layer
.
w2_weight
=
(
self
.
w2_weight_triton_tensor
if
layer
.
w2_weight
is
None
else
layer
.
w2_weight
)
assert
all
([
w
is
not
None
for
w
in
[
layer
.
w13_weight
,
layer
.
w2_weight
]])
assert
self
.
moe_quant_config
is
not
None
if
(
self
.
mxfp4_backend
==
Mxfp4Backend
.
SM100_FI_MXFP4_MXFP8_TRTLLM
...
...
@@ -838,71 +848,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
f
"Incompatible Mxfp4 backend (
{
self
.
mxfp4_backend
}
) for EP"
)
def
_route_and_experts
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
int
|
None
=
None
,
num_expert_group
:
int
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
custom_routing_function
:
Callable
|
None
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
torch
.
Tensor
|
None
=
None
,
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
assert
isinstance
(
self
.
fused_experts
,
mk
.
FusedMoEModularKernel
)
topk_weights
,
topk_ids
,
_
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
,
enable_eplb
=
enable_eplb
,
expert_map
=
expert_map
,
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
)
w13_weight
=
(
self
.
w13_weight_triton_tensor
if
layer
.
w13_weight
is
None
else
layer
.
w13_weight
)
w2_weight
=
(
self
.
w2_weight_triton_tensor
if
layer
.
w2_weight
is
None
else
layer
.
w2_weight
)
assert
all
([
w
is
not
None
for
w
in
[
w13_weight
,
w2_weight
]])
return
self
.
fused_experts
(
hidden_states
=
x
,
w1
=
w13_weight
,
w2
=
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
@
property
def
allow_inplace
(
self
)
->
bool
:
return
True
def
apply
(
self
,
...
...
@@ -930,29 +878,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB is not supported for mxfp4"
)
if
self
.
fused_experts
is
not
None
:
return
self
.
_route_and_experts
(
layer
,
x
,
router_logits
,
top_k
,
renormalize
,
use_grouped_topk
,
topk_group
,
num_expert_group
,
global_num_experts
,
expert_map
,
custom_routing_function
,
scoring_func
,
e_score_correction_bias
,
apply_router_weight_on_input
,
activation
,
enable_eplb
,
expert_load_view
,
logical_to_physical_map
,
logical_replica_count
,
)
if
self
.
mxfp4_backend
==
Mxfp4Backend
.
MARLIN
:
topk_weights
,
topk_ids
,
_
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
938772af
...
...
@@ -310,7 +310,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
# Property to determine if AITER is used
if
self
.
rocm_aiter_moe_enabled
:
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
# noqa E501
rocm_aiter_fused_experts
,
shuffle_weights
,
)
...
...
@@ -322,17 +321,11 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
self
.
rocm_aiter_fused_experts_func
=
rocm_aiter_fused_experts
elif
self
.
use_marlin
:
prepare_moe_fp8_layer_for_marlin
(
layer
,
False
)
# Activations not quantized for marlin.
del
layer
.
w13_input_scale
del
layer
.
w2_input_scale
self
.
fused_experts_func
=
None
else
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
self
.
fused_experts_func
=
fused_experts
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
...
...
@@ -369,8 +362,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet."
...
...
@@ -392,7 +383,11 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
)
if
self
.
rocm_aiter_moe_enabled
:
return
self
.
rocm_aiter_fused_experts_func
(
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
rocm_aiter_fused_experts
,
)
return
rocm_aiter_fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
...
...
@@ -403,7 +398,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
quant_config
=
self
.
moe_quant_config
,
expert_map
=
expert_map
,
)
if
self
.
use_marlin
:
el
if
self
.
use_marlin
:
assert
activation
==
"silu"
,
f
"
{
activation
}
not supported for Marlin MoE."
return
fused_marlin_moe
(
x
,
...
...
@@ -421,10 +416,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
)
else
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
assert
self
.
fused_experts_func
is
not
None
return
self
.
fused_experts_func
(
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
...
...
@@ -601,6 +596,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
block_shape
=
None
,
)
@
property
def
allow_inplace
(
self
)
->
bool
:
return
True
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -624,8 +623,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
...
...
vllm/model_executor/layers/quantization/rtn.py
View file @
938772af
...
...
@@ -377,8 +377,6 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
fused_experts
is
None
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `RTNMoEMethod` yet."
)
...
...
vllm/model_executor/warmup/deep_gemm_warmup.py
View file @
938772af
...
...
@@ -13,7 +13,7 @@ import vllm.envs as envs
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
DeepGemmExperts
from
vllm.model_executor.layers.fused_moe.deep_gemm_utils
import
compute_aligned_M
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
,
FusedMoEModularMethod
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
,
...
...
@@ -160,8 +160,8 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
):
return
False
if
not
isinstance
(
module
.
quant_method
.
fused_experts
,
FusedMoEModular
Kernel
):
#
fused_expert
s could invoke deep_gemm_moe_fp8
if
not
isinstance
(
module
.
quant_method
,
FusedMoEModular
Method
):
#
modular kernel
s could invoke deep_gemm_moe_fp8
return
True
mk
:
FusedMoEModularKernel
=
module
.
quant_method
.
fused_experts
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment