Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
42135d68
Unverified
Commit
42135d68
authored
Jan 21, 2026
by
Robert Shaw
Committed by
GitHub
Jan 21, 2026
Browse files
[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)
parent
e14467be
Changes
82
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
578 additions
and
472 deletions
+578
-472
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
+9
-8
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+52
-10
vllm/model_executor/layers/fused_moe/router/custom_routing_router.py
...executor/layers/fused_moe/router/custom_routing_router.py
+5
-0
vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
...l_executor/layers/fused_moe/router/grouped_topk_router.py
+4
-6
vllm/model_executor/layers/fused_moe/router/router_factory.py
.../model_executor/layers/fused_moe/router/router_factory.py
+0
-3
vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
+14
-8
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+14
-4
vllm/model_executor/layers/fused_moe/trtllm_moe.py
vllm/model_executor/layers/fused_moe/trtllm_moe.py
+47
-11
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
...executor/layers/fused_moe/unquantized_fused_moe_method.py
+6
-3
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+194
-185
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+51
-104
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+2
-0
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+64
-42
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+4
-3
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
..._executor/layers/quantization/utils/flashinfer_fp4_moe.py
+84
-38
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
...el_executor/layers/quantization/utils/flashinfer_utils.py
+0
-31
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+26
-12
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+0
-2
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+0
-2
No files found.
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
View file @
42135d68
...
...
@@ -123,7 +123,6 @@ def convert_to_unquantized_kernel_format(
def
make_unquantized_moe_kernel
(
layer
:
torch
.
nn
.
Module
,
backend
:
UnquantizedMoeBackend
,
quant_config
:
FusedMoEQuantConfig
,
moe_config
:
FusedMoEConfig
,
...
...
@@ -141,12 +140,8 @@ def make_unquantized_moe_kernel(
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
FlashInferExperts
(
out_dtype
=
layer
.
params_dtype
,
moe_config
=
moe_config
,
quant_config
=
quant_config
,
tp_rank
=
moe_config
.
moe_parallel_config
.
tp_rank
,
tp_size
=
moe_config
.
moe_parallel_config
.
tp_size
,
ep_rank
=
moe_config
.
moe_parallel_config
.
ep_rank
,
ep_size
=
moe_config
.
moe_parallel_config
.
ep_size
,
),
)
use_inplace
=
False
...
...
@@ -157,13 +152,19 @@ def make_unquantized_moe_kernel(
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
AiterExperts
(
quant_config
),
AiterExperts
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
),
)
elif
backend
==
UnquantizedMoeBackend
.
TRITON
:
from
vllm.model_executor.layers.fused_moe
import
TritonExperts
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
TritonExperts
(
quant_config
),
TritonExperts
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
),
)
return
kernel
,
use_inplace
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
42135d68
...
...
@@ -9,11 +9,21 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8Dynamic128Sym
,
kFp8DynamicTensorSym
,
kFp8DynamicTokenSym
,
kFp8Static128BlockSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
)
class
QuantMethod
(
IntEnum
):
...
...
@@ -269,17 +279,49 @@ def rocm_aiter_fused_experts(
class
AiterExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
quant_config
):
super
().
__init__
(
quant_config
)
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
@
staticmethod
def
expects_unquantized_inputs
(
fused_moe_config
:
mk
.
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
)
->
bool
:
# AITER fused MoE kernels handle input quantization internally.
return
True
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
,
)
@
staticmethod
def
_supports_current_device
()
->
bool
:
return
rocm_aiter_ops
.
is_fused_moe_enabled
()
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
return
False
@
staticmethod
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
# TODO(rob): AITER also supports MXFP4, which is not
# yet supported via an Oracle. Once it is, we will add
# MXFP4 to this list.
SUPPORTED_W_A
=
[
(
None
,
None
),
(
kFp8Static128BlockSym
,
kFp8Dynamic128Sym
),
(
kFp8StaticTensorSym
,
kFp8StaticTensorSym
),
(
kFp8StaticTensorSym
,
kFp8DynamicTensorSym
),
(
kFp8StaticChannelSym
,
kFp8DynamicTokenSym
),
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
,
"gelu"
]
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
return
True
def
supports_expert_map
(
self
):
return
True
...
...
vllm/model_executor/layers/fused_moe/router/custom_routing_router.py
View file @
42135d68
...
...
@@ -34,6 +34,11 @@ class CustomRoutingRouter(BaseRouter):
@
property
def
routing_method_type
(
self
)
->
RoutingMethodType
:
from
vllm.model_executor.models.llama4
import
Llama4MoE
# NOTE: FLASHINFER_TRTLLM support the Llama4 router.
if
self
.
custom_routing_function
==
Llama4MoE
.
custom_routing_function
:
return
RoutingMethodType
.
Llama4
return
RoutingMethodType
.
Custom
def
_compute_routing
(
...
...
vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
View file @
42135d68
...
...
@@ -261,7 +261,6 @@ class GroupedTopKRouter(BaseRouter):
num_fused_shared_experts
:
int
=
0
,
enable_eplb
:
bool
=
False
,
indices_type_getter
:
Callable
[[],
torch
.
dtype
|
None
]
|
None
=
None
,
routing_method_type
:
RoutingMethodType
|
None
=
None
,
):
super
().
__init__
(
top_k
=
top_k
,
...
...
@@ -278,13 +277,12 @@ class GroupedTopKRouter(BaseRouter):
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
num_fused_shared_experts
=
num_fused_shared_experts
# Determine routing method type
if
routing_method_type
is
not
None
:
self
.
_routing_method_type
=
routing_method_type
elif
scoring_func
==
"sigmoid"
:
if
scoring_func
==
"sigmoid"
:
self
.
_routing_method_type
=
RoutingMethodType
.
DeepSeekV3
else
:
self
.
_routing_method_type
=
RoutingMethodType
.
TopK
# NOTE: this prohibits the FLASHINFER_TRTLLM kernels from
# being selected, since they only support DeepSeek-style.
self
.
_routing_method_type
=
RoutingMethodType
.
Unspecified
@
property
def
routing_method_type
(
self
)
->
RoutingMethodType
:
...
...
vllm/model_executor/layers/fused_moe/router/router_factory.py
View file @
42135d68
...
...
@@ -6,7 +6,6 @@ import torch
import
vllm.envs
as
envs
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
from
vllm.model_executor.layers.fused_moe.config
import
RoutingMethodType
from
vllm.model_executor.layers.fused_moe.router.base_router
import
BaseRouter
from
vllm.model_executor.layers.fused_moe.router.custom_routing_router
import
(
CustomRoutingRouter
,
...
...
@@ -36,7 +35,6 @@ def create_fused_moe_router(
global_num_experts
:
int
,
renormalize
:
bool
=
True
,
indices_type_getter
:
Callable
[[],
torch
.
dtype
|
None
]
|
None
=
None
,
routing_method_type
:
RoutingMethodType
|
None
=
None
,
# grouped topk parameters
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
int
|
None
=
None
,
...
...
@@ -128,7 +126,6 @@ def create_fused_moe_router(
num_fused_shared_experts
=
num_fused_shared_experts
,
enable_eplb
=
enable_eplb
,
indices_type_getter
=
indices_type_getter
,
routing_method_type
=
routing_method_type
,
)
router
.
capture
=
capture
return
router
...
...
vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
View file @
42135d68
...
...
@@ -5,7 +5,10 @@
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassExpertsFp8
from
vllm.model_executor.layers.fused_moe.fallback
import
FallbackExperts
from
vllm.model_executor.layers.fused_moe.fused_moe
import
TritonExperts
...
...
@@ -17,19 +20,22 @@ class TritonOrCutlassExperts(FallbackExperts):
def
__init__
(
self
,
e
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
|
None
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
device
:
torch
.
dtype
,
):
self
.
is_sm100
=
current_platform
.
has_device_capability
(
100
)
super
().
__init__
(
experts
=
CutlassExpertsFp8
(
e
,
n
,
k
,
out_dtype
,
quant_config
,
device
),
fallback_experts
=
TritonExperts
(
quant_config
),
experts
=
CutlassExpertsFp8
(
moe_config
,
quant_config
),
fallback_experts
=
TritonExperts
(
moe_config
,
quant_config
),
)
@
staticmethod
def
get_clses
()
->
tuple
[
type
[
mk
.
FusedMoEPermuteExpertsUnpermute
],
type
[
mk
.
FusedMoEPermuteExpertsUnpermute
],
]:
return
(
CutlassExpertsFp8
,
TritonExperts
)
def
workspace_shapes
(
self
,
M
:
int
,
...
...
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
View file @
42135d68
...
...
@@ -4,7 +4,10 @@
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
DeepGemmExperts
,
_valid_deep_gemm
,
...
...
@@ -20,12 +23,19 @@ from vllm.utils.deep_gemm import (
class
TritonOrDeepGemmExperts
(
FallbackExperts
):
"""DeepGemm with fallback to Triton for low latency shapes."""
def
__init__
(
self
,
quant_config
:
FusedMoEQuantConfig
):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
):
super
().
__init__
(
experts
=
DeepGemmExperts
(
quant_config
),
fallback_experts
=
TritonExperts
(
quant_config
),
experts
=
DeepGemmExperts
(
moe_config
,
quant_config
),
fallback_experts
=
TritonExperts
(
moe_config
,
quant_config
),
)
@
staticmethod
def
get_clses
()
->
tuple
[
type
[
mk
.
FusedMoEPermuteExpertsUnpermute
],
type
[
mk
.
FusedMoEPermuteExpertsUnpermute
],
]:
return
(
DeepGemmExperts
,
TritonExperts
)
def
workspace_shapes
(
self
,
M
:
int
,
...
...
vllm/model_executor/layers/fused_moe/trtllm_moe.py
View file @
42135d68
...
...
@@ -6,37 +6,73 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
)
class
TrtLlmGenExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
moe
:
FusedMoEConfig
,
moe
_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
gemm1_alpha
,
gemm1_beta
,
gemm1_clamp_limit
,
max_capture_size
,
):
super
().
__init__
(
quant_config
)
self
.
moe
=
moe
super
().
__init__
(
moe_config
,
quant_config
)
self
.
gemm1_alpha
=
gemm1_alpha
self
.
gemm1_beta
=
gemm1_beta
self
.
gemm1_clamp_limit
=
gemm1_clamp_limit
self
.
max_capture_size
=
max_capture_size
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
,
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
@
staticmethod
def
_supports_current_device
()
->
bool
:
raise
NotImplementedError
(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
raise
NotImplementedError
(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@
staticmethod
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
raise
NotImplementedError
(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
raise
NotImplementedError
(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
raise
NotImplementedError
(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
def
supports_chunking
(
self
)
->
bool
:
...
...
@@ -86,7 +122,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk
=
topk_ids
.
size
(
-
1
)
local_num_experts
=
w1
.
size
(
0
)
intermediate_size
=
w2
.
size
(
1
)
local_expert_offset
=
self
.
moe
.
ep_rank
*
local_num_experts
local_expert_offset
=
self
.
moe
_config
.
ep_rank
*
local_num_experts
x_quant
=
hidden_states
x_scale
=
a1q_scale
...
...
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
View file @
42135d68
...
...
@@ -96,13 +96,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
):
logger
.
debug
(
"BatchedTritonExperts %s"
,
self
.
moe
)
return
BatchedTritonExperts
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
max_num_tokens
=
self
.
moe
.
max_num_tokens
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
quant_config
=
self
.
moe_quant_config
,
)
else
:
logger
.
debug
(
"TritonExperts %s"
,
self
.
moe
)
return
TritonExperts
(
self
.
moe_quant_config
)
return
TritonExperts
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
)
def
create_weights
(
self
,
...
...
@@ -192,7 +196,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
assert
self
.
moe_quant_config
is
not
None
self
.
kernel
,
self
.
use_inplace
=
make_unquantized_moe_kernel
(
layer
=
layer
,
backend
=
self
.
unquantized_backend
,
quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
42135d68
...
...
@@ -739,6 +739,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
return
BatchedMarlinExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
w13_g_idx
,
w2_g_idx
=
w2_g_idx
,
...
...
@@ -749,6 +750,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
else
:
# Standard Marlin experts for AWQ
return
MarlinExperts
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
w13_g_idx
,
w2_g_idx
=
w2_g_idx
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
42135d68
...
...
@@ -19,7 +19,6 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEActivationFormat
,
FusedMoEConfig
,
FusedMoEMethodBase
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoERouter
,
...
...
@@ -27,9 +26,9 @@ from vllm.model_executor.layers.fused_moe import (
UnquantizedFusedMoEMethod
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
fp8_w8a16_moe_quant_config
,
RoutingMethodType
,
int4_w4a16_moe_quant_config
,
int4_w4afp8_moe_quant_config
,
int8_w8a8_moe_quant_config
,
...
...
@@ -45,15 +44,17 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend
,
convert_to_fp8_moe_kernel_format
,
make_fp8_moe_kernel
,
make_fp8_moe_kernel_for_mkm
,
make_fp8_moe_quant_config
,
select_fp8_moe_backend
,
)
from
vllm.model_executor.layers.fused_moe.oracle.nvfp4
import
(
FLASHINFER_NVFP4_MOE_BACKENDS
,
NvFp4MoeBackend
,
convert_to_nvfp4_moe_kernel_format
,
is_global_sf_supported_for_nvfp4_backend
,
make_mxfp4_moe_quant_config
,
make_nvfp4_moe_kernel
,
make_nvfp4_moe_kernel_for_mkm
,
make_nvfp4_moe_quant_config
,
select_nvfp4_moe_backend
,
)
...
...
@@ -62,10 +63,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
WNA16_SUPPORTED_TYPES_MAP
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
build_flashinfer_fp4_cutlass_moe_prepare_finalize
,
flashinfer_trtllm_fp4_moe
,
flashinfer_trtllm_fp4_routed_moe
,
select_nvfp4_gemm_impl
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
apply_fi_trtllm_fp8_per_tensor_moe
,
build_flashinfer_fp8_cutlass_moe_prepare_finalize
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
process_fp8_input_tensor_strategy_moe
,
...
...
@@ -79,12 +82,18 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_moe_permute_scales
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
is_fp4_marlin_supported
,
prepare_moe_fp4_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
convert_bf16_scales_to_fp8
,
convert_packed_uint4b8_to_signed_int4_inplace
,
kFp8Dynamic128Sym
,
kFp8DynamicTokenSym
,
kFp8Static128BlockSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
kNvfp4Dynamic
,
kNvfp4Static
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
normalize_e4m3fn_to_e4m3fnuz
,
...
...
@@ -200,7 +209,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
f
"or None for NVFP4A16, found
{
input_quant
}
"
,
)
return
CompressedTensorsW4A4Nvfp4MoEMethod
(
layer
.
moe_config
,
layer_name
,
use_
marlin
=
input_quant
is
None
layer
.
moe_config
,
layer_name
,
use_
a16
=
(
input_quant
is
None
)
)
elif
(
quant_config
.
_is_fp8_w8a8_sm90
(
weight_quant
,
input_quant
)
...
...
@@ -234,6 +243,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
super
().
__init__
(
moe
)
self
.
group_size
=
32
self
.
mxfp4_backend
=
NvFp4MoeBackend
.
MARLIN
self
.
experts_cls
=
MarlinExperts
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
def
create_weights
(
...
...
@@ -327,9 +337,9 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
if
self
.
moe_quant_config
is
not
None
:
self
.
kernel
=
make_nvfp4_moe_kernel
(
backend
=
self
.
mxfp4_backend
,
quant_config
=
self
.
moe_quant_config
,
moe_quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
experts_cls
=
self
.
experts_cls
,
)
def
apply
(
...
...
@@ -368,34 +378,30 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
self
,
moe
:
FusedMoEConfig
,
layer_name
:
str
|
None
=
None
,
use_
marlin
:
bool
=
False
,
use_
a16
:
bool
=
False
,
):
super
().
__init__
(
moe
)
self
.
group_size
=
16
if
use_marlin
:
if
is_fp4_marlin_supported
():
self
.
nvfp4_backend
=
NvFp4MoeBackend
.
MARLIN
else
:
raise
ValueError
(
"Marlin FP4 MoE kernel requested but not "
,
"supported on current platform."
,
)
else
:
self
.
nvfp4_backend
=
select_nvfp4_moe_backend
()
# TODO: move this type of check into the oracle.
if
not
self
.
moe
.
is_act_and_mul
and
self
.
nvfp4_backend
not
in
[
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
NvFp4MoeBackend
.
MARLIN
,
]:
raise
NotImplementedError
(
"Non-gated activations are only supported by FlashInfer "
f
"CUTLASS and Marlin NvFP4 MoE backends, not
{
self
.
nvfp4_backend
}
."
)
# Select experts implementation.
self
.
nvfp4_backend
,
self
.
experts_cls
=
select_nvfp4_moe_backend
(
config
=
self
.
moe
,
weight_key
=
kNvfp4Static
,
activation_key
=
None
if
use_a16
else
kNvfp4Dynamic
,
)
# Delay creation of the kernel until after process-weights.
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
self
.
use_global_sf
=
is_global_sf_supported_for_nvfp4_backend
(
self
.
nvfp4_backend
)
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
@
property
def
topk_indices_dtype
(
self
)
->
torch
.
dtype
|
None
:
if
self
.
kernel
is
not
None
:
return
self
.
kernel
.
prepare_finalize
.
topk_indices_dtype
()
return
None
def
create_weights
(
self
,
...
...
@@ -571,35 +577,40 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer
.
w13_input_scale
=
a13_scale
layer
.
w2_input_scale
=
a2_scale
# Initialize the kernel that will be called in apply().
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
use_dp
=
self
.
moe
.
dp_size
>
1
if
self
.
moe_quant_config
is
not
None
and
not
use_dp
:
if
self
.
moe_quant_config
and
(
(
not
self
.
moe
.
moe_parallel_config
.
use_all2all_kernels
)
or
self
.
moe
.
moe_parallel_config
.
use_naive_all2all_kernels
):
assert
self
.
experts_cls
is
not
None
self
.
kernel
=
make_nvfp4_moe_kernel
(
backend
=
self
.
nvfp4_backend
,
quant_config
=
self
.
moe_quant_config
,
moe_quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
experts_cls
=
self
.
experts_cls
,
)
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
UNSUPPORTED
=
[
NvFp4MoeBackend
.
MARLIN
,
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
]
if
self
.
nvfp4_backend
in
UNSUPPORTED
:
if
self
.
nvfp4_backend
==
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
:
return
None
elif
self
.
nvfp4_backend
==
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
:
#
TP case: avoid convert to ModularKernelMethod - to be refactored
.
if
self
.
moe
.
dp_size
==
1
:
#
For no-EP case, don't use the MKM framework
.
if
not
self
.
moe
.
moe_parallel_config
.
use_all2all_kernels
:
return
None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize
=
build_flashinfer_fp4_cutlass_moe_prepare_finalize
(
self
.
moe
prepare_finalize
=
build_flashinfer_fp8_cutlass_moe_prepare_finalize
(
self
.
moe
,
use_deepseek_fp8_block_scale
=
False
,
)
logger
.
debug_once
(
"%s"
,
prepare_finalize
.
__class__
.
__name__
)
return
prepare_finalize
else
:
return
super
().
maybe_make_prepare_finalize
(
routing_tables
)
return
super
().
maybe_make_prepare_finalize
(
routing_tables
)
def
select_gemm_impl
(
self
,
...
...
@@ -607,14 +618,13 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer
:
torch
.
nn
.
Module
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
assert
self
.
moe_quant_config
is
not
None
"""Return the appropriate GEMM experts implementation."""
experts
=
select_nvfp4_gemm_impl
(
self
.
moe
,
self
.
moe_quant_config
,
allow_flashinfer
=
(
self
.
nvfp4_backend
in
FLASHINFER_NVFP4_MOE_BACKENDS
),
assert
self
.
experts_cls
is
not
None
return
make_nvfp4_moe_kernel_for_mkm
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
experts_cls
=
self
.
experts_cls
,
prepare_finalize
=
prepare_finalize
,
)
logger
.
debug_once
(
"Using %s"
,
experts
.
__class__
.
__name__
)
return
experts
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
...
...
@@ -727,33 +737,41 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)
self
.
fp8_backend
=
select_fp8_moe_backend
(
block_quant
=
self
.
block_quant
,
tp_size
=
moe
.
tp_size
,
with_lora_support
=
moe
.
is_lora_enabled
,
is_act_and_mul
=
moe
.
is_act_and_mul
,
# TODO(rob): enable selecting this externally.
ct2vllm_weight
=
{
QuantizationStrategy
.
CHANNEL
:
kFp8StaticChannelSym
,
QuantizationStrategy
.
TENSOR
:
kFp8StaticTensorSym
,
QuantizationStrategy
.
BLOCK
:
kFp8Static128BlockSym
,
}
ct2vllm_act
=
{
QuantizationStrategy
.
TOKEN
:
kFp8DynamicTokenSym
,
QuantizationStrategy
.
TENSOR
:
(
kFp8StaticTensorSym
if
self
.
static_input_scales
else
kFp8Dynamic128Sym
),
}
weight_key
=
ct2vllm_weight
[
self
.
weight_quant
.
strategy
]
if
weight_key
==
kFp8Static128BlockSym
:
activation_key
=
kFp8Dynamic128Sym
else
:
activation_key
=
ct2vllm_act
[
self
.
input_quant
.
strategy
]
# Select Fp8 MoE backend
self
.
fp8_backend
,
self
.
experts_cls
=
select_fp8_moe_backend
(
config
=
self
.
moe
,
weight_key
=
weight_key
,
activation_key
=
activation_key
,
allow_vllm_cutlass
=
True
,
)
if
self
.
fp8_backend
!=
Fp8MoeBackend
.
MARLIN
:
per_act_token
=
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
per_channel_quant
=
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
)
if
per_act_token
!=
per_channel_quant
:
raise
NotImplementedError
(
"For FP8 Fused MoE layers, per-token and per-channel must be "
"used together."
)
# TODO(rob): hook this up in a follow up PR.
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
raise
NotImplementedError
(
"FlashInfer TRTLLM backend not supported for compressed-tensors yet."
)
self
.
disable_expert_map
=
False
# Delay creation of the kernel until after process-weights.
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
@
property
def
topk_indices_dtype
(
self
)
->
torch
.
dtype
|
None
:
if
self
.
kernel
is
not
None
:
return
self
.
kernel
.
prepare_finalize
.
topk_indices_dtype
()
return
None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -970,140 +988,75 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
replace_parameter
(
layer
,
"w13_weight_scale"
,
w13_scale
)
replace_parameter
(
layer
,
"w2_weight_scale"
,
w2_scale
)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
if
self
.
moe_quant_config
:
if
self
.
moe_quant_config
and
(
(
not
self
.
moe
.
moe_parallel_config
.
use_all2all_kernels
)
or
self
.
moe
.
moe_parallel_config
.
use_naive_all2all_kernels
):
assert
self
.
experts_cls
is
not
None
self
.
kernel
,
self
.
use_inplace
=
make_fp8_moe_kernel
(
layer
=
layer
,
moe_quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
fp8_backend
=
self
.
fp8_backend
,
experts_cls
=
self
.
experts_cls
,
)
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
if
self
.
fp8_backend
in
[
Fp8MoeBackend
.
MARLIN
,
Fp8MoeBackend
.
AITER
]
:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
return
None
else
:
return
super
().
maybe_make_prepare_finalize
(
routing_tables
)
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
# For no-EP case, don't use the MKM framework.
if
not
self
.
moe
.
moe_parallel_config
.
use_all2all_kernels
:
return
None
prepare_finalize
=
build_flashinfer_fp8_cutlass_moe_prepare_finalize
(
self
.
moe
,
use_deepseek_fp8_block_scale
=
self
.
block_quant
,
)
logger
.
debug_once
(
"%s"
,
prepare_finalize
.
__class__
.
__name__
)
return
prepare_finalize
return
super
().
maybe_make_prepare_finalize
(
routing_tables
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
,
layer
:
torch
.
nn
.
Module
,
)
->
FusedMoEPermuteExpertsUnpermute
:
# cutlass path
assert
self
.
moe_quant_config
is
not
None
if
self
.
fp8_backend
==
Fp8MoeBackend
.
VLLM_CUTLASS
:
from
vllm.model_executor.layers.fused_moe
import
(
CutlassBatchedExpertsFp8
,
CutlassExpertsFp8
,
)
experts
:
FusedMoEPermuteExpertsUnpermute
num_dispatchers
=
prepare_finalize
.
num_dispatchers
()
if
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
BatchedExperts
):
logger
.
debug
(
"CutlassBatchedExpertsFp8(%s)"
,
self
.
__class__
.
__name__
)
experts
=
CutlassBatchedExpertsFp8
(
max_experts_per_worker
=
self
.
moe
.
num_local_experts
,
num_dispatchers
=
num_dispatchers
,
out_dtype
=
self
.
moe
.
in_dtype
,
e
=
layer
.
local_num_experts
,
n
=
layer
.
intermediate_size_per_partition
,
k
=
layer
.
hidden_size
,
device
=
layer
.
w13_weight
.
device
,
quant_config
=
self
.
moe_quant_config
,
)
else
:
logger
.
debug
(
"CutlassExpertsFp8(%s)"
,
self
.
__class__
.
__name__
)
experts
=
CutlassExpertsFp8
(
out_dtype
=
self
.
moe
.
in_dtype
,
e
=
layer
.
local_num_experts
,
n
=
layer
.
intermediate_size_per_partition
,
k
=
layer
.
hidden_size
,
device
=
layer
.
w13_weight
.
device
,
quant_config
=
self
.
moe_quant_config
,
)
# TODO(rob): investigate disable_expert_map
self
.
disable_expert_map
=
(
num_dispatchers
>
1
or
not
experts
.
supports_expert_map
()
)
return
experts
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
,
)
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
BatchedTritonExperts
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
TritonExperts
,
)
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
,
assert
self
.
experts_cls
is
not
None
return
make_fp8_moe_kernel_for_mkm
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
experts_cls
=
self
.
experts_cls
,
prepare_finalize
=
prepare_finalize
,
)
assert
self
.
fp8_backend
not
in
[
Fp8MoeBackend
.
AITER
,
Fp8MoeBackend
.
MARLIN
]
if
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
BatchedExperts
):
max_num_tokens_per_rank
=
prepare_finalize
.
max_num_tokens_per_rank
()
assert
max_num_tokens_per_rank
is
not
None
if
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
:
logger
.
debug
(
"BatchedDeepGemmExperts(%s)"
,
self
.
__class__
.
__name__
)
return
BatchedDeepGemmExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
quant_config
=
self
.
moe_quant_config
,
)
else
:
logger
.
debug
(
"BatchedTritonExperts(%s)"
,
self
.
__class__
.
__name__
)
return
BatchedTritonExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
quant_config
=
self
.
moe_quant_config
,
)
else
:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
:
logger
.
debug
(
"TritonOrDeepGemmExperts(%s)"
,
self
.
__class__
.
__name__
)
return
TritonOrDeepGemmExperts
(
self
.
moe_quant_config
)
else
:
logger
.
debug
(
"TritonExperts(%s)"
,
self
.
__class__
.
__name__
)
return
TritonExperts
(
self
.
moe_quant_config
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
return
fp8_w8a16_moe_quant_config
(
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
block_shape
=
self
.
weight_block_size
,
)
per_act_token
=
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
per_channel_quant
=
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
w1_scale
=
layer
.
w13_weight_scale
w2_scale
=
layer
.
w2_weight_scale
a1_scale
=
layer
.
w13_input_scale
a2_scale
=
layer
.
w2_input_scale
return
fp8_w8a8_moe_quant_config
(
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
per_act_token_quant
=
per_act_token
,
per_out_ch_quant
=
per_channel_quant
,
block_shape
=
layer
.
weight_block_size
,
return
make_fp8_moe_quant_config
(
fp8_backend
=
self
.
fp8_backend
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
per_act_token_quant
=
(
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
),
per_out_ch_quant
=
(
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
),
block_shape
=
self
.
weight_block_size
,
)
def
apply
(
...
...
@@ -1113,6 +1066,56 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
if
layer
.
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `FlashInfer TRTLLM FP8 MoE`."
)
assert
layer
.
activation
==
"silu"
if
self
.
block_quant
:
import
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
# noqa: E501, F401
e_score_correction_bias
=
(
layer
.
e_score_correction_bias
.
to
(
x
.
dtype
)
if
layer
.
e_score_correction_bias
is
not
None
else
None
)
routing_method_type
=
layer
.
routing_method_type
return
torch
.
ops
.
vllm
.
flashinfer_fused_moe_blockscale_fp8
(
routing_logits
=
router_logits
.
to
(
torch
.
float32
)
if
routing_method_type
==
RoutingMethodType
.
DeepSeekV3
else
router_logits
,
routing_bias
=
e_score_correction_bias
,
x
=
x
,
w13_weight
=
layer
.
w13_weight
,
w13_weight_scale_inv
=
layer
.
w13_weight_scale
,
w2_weight
=
layer
.
w2_weight
,
w2_weight_scale_inv
=
layer
.
w2_weight_scale
,
global_num_experts
=
layer
.
global_num_experts
,
top_k
=
layer
.
top_k
,
num_expert_group
=
layer
.
num_expert_group
,
topk_group
=
layer
.
topk_group
,
intermediate_size
=
layer
.
intermediate_size_per_partition
,
expert_offset
=
layer
.
ep_rank
*
layer
.
local_num_experts
,
local_num_experts
=
layer
.
local_num_experts
,
block_shape
=
self
.
weight_block_size
,
routing_method_type
=
routing_method_type
,
routed_scaling
=
layer
.
routed_scaling_factor
,
)
else
:
return
apply_fi_trtllm_fp8_per_tensor_moe
(
layer
=
layer
,
hidden_states
=
x
,
router_logits
=
router_logits
,
routing_bias
=
layer
.
e_score_correction_bias
,
global_num_experts
=
layer
.
global_num_experts
,
top_k
=
layer
.
top_k
,
num_expert_group
=
layer
.
num_expert_group
,
topk_group
=
layer
.
topk_group
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
)
topk_weights
,
topk_ids
=
router
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
@@ -1130,7 +1133,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
global_num_experts
=
layer
.
global_num_experts
,
# TODO(rob): investigate the disable_expert_map introduced by:
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map
=
None
if
self
.
disable_expert_map
else
layer
.
expert_map
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
)
...
...
@@ -1596,6 +1599,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
return
BatchedMarlinExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
layer
.
w13_weight_g_idx
,
w2_g_idx
=
layer
.
w2_weight_g_idx
,
...
...
@@ -1605,6 +1609,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
)
else
:
return
MarlinExperts
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
layer
.
w13_weight_g_idx
,
w2_g_idx
=
layer
.
w2_weight_g_idx
,
...
...
@@ -1854,7 +1859,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer
.
w13_weight
=
layer
.
w13_weight_packed
layer
.
w2_weight
=
layer
.
w2_weight_packed
return
TritonWNA16Experts
(
quant_config
=
self
.
moe_quant_config
)
return
TritonWNA16Experts
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
)
else
:
raise
NotImplementedError
(
"TritonExperts requires Triton. "
...
...
@@ -2467,6 +2474,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
c_strides2
=
self
.
a_strides1_c_strides2
,
s_strides1
=
self
.
s_strides1
,
s_strides2
=
self
.
s_strides2
,
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
group_size
=
self
.
group_size
,
)
...
...
@@ -2505,6 +2513,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer
.
w2_weight_packed
,
topk_weights
,
topk_ids
,
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
activation
=
layer
.
activation
,
global_num_experts
=
layer
.
global_num_experts
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
42135d68
...
...
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.batch_invariant import (
)
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEActivationFormat
,
FusedMoEMethodBase
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
...
...
@@ -35,6 +34,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend
,
convert_to_fp8_moe_kernel_format
,
make_fp8_moe_kernel
,
make_fp8_moe_kernel_for_mkm
,
make_fp8_moe_quant_config
,
select_fp8_moe_backend
,
)
...
...
@@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
apply_fi_trtllm_fp8_per_tensor_moe
,
build_flashinfer_fp8_cutlass_moe_prepare_finalize
,
select_cutlass_fp8_gemm_impl
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
W8A8BlockFp8LinearOp
,
...
...
@@ -79,8 +78,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
is_layer_skipped
,
kFp8Dynamic128Sym
,
kFp8DynamicTensorSym
,
kFp8DynamicTokenSym
,
kFp8Static128BlockSym
,
kFp8StaticTensorSym
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
...
...
@@ -658,38 +659,36 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self
.
weight_scale_name
=
(
"weight_scale_inv"
if
self
.
block_quant
else
"weight_scale"
)
self
.
fp8_backend
=
select_fp8_moe_backend
(
block_quant
=
self
.
block_quant
,
tp_size
=
layer
.
moe_parallel_config
.
tp_size
,
with_lora_support
=
self
.
moe
.
is_lora_enabled
,
is_act_and_mul
=
self
.
moe
.
is_act_and_mul
,
)
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
if
self
.
block_quant
and
self
.
weight_block_size
!=
[
128
,
128
]:
raise
NotImplementedError
(
"FlashInfer CUTLASS FP8 MoE backend only supports block "
"size [128, 128]."
)
if
layer
.
activation
!=
"silu"
:
raise
NotImplementedError
(
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
"activation function, but got {layer.activation}."
)
dynamic_per_token
=
(
not
self
.
block_quant
and
self
.
quant_config
.
activation_scheme
!=
"static"
)
if
dynamic_per_token
and
self
.
fp8_backend
in
[
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
Fp8MoeBackend
.
FLASHINFER_CUTLASS
,
]:
raise
NotImplementedError
(
"FlashInfer FP8 MoE backend does not support dynamic per token "
"activation quantization."
# Set weight key and activation key for kernel compatibility
if
self
.
block_quant
:
weight_key
=
kFp8Static128BlockSym
activation_key
=
kFp8Dynamic128Sym
else
:
weight_key
=
kFp8StaticTensorSym
activation_key
=
(
kFp8StaticTensorSym
if
self
.
quant_config
.
activation_scheme
==
"static"
else
kFp8DynamicTensorSym
)
# Select Fp8 MoE backend
self
.
fp8_backend
,
self
.
experts_cls
=
select_fp8_moe_backend
(
config
=
self
.
moe
,
weight_key
=
weight_key
,
activation_key
=
activation_key
,
allow_vllm_cutlass
=
False
,
)
# Delay creation of the kernel until after process-weights.
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
@
property
def
topk_indices_dtype
(
self
)
->
torch
.
dtype
|
None
:
if
self
.
kernel
is
not
None
:
return
self
.
kernel
.
prepare_finalize
.
topk_indices_dtype
()
return
None
def
create_weights
(
self
,
layer
:
Module
,
...
...
@@ -842,14 +841,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter
(
layer
,
f
"w13_
{
self
.
weight_scale_name
}
"
,
w13_scale
)
replace_parameter
(
layer
,
f
"w2_
{
self
.
weight_scale_name
}
"
,
w2_scale
)
# Setup modular kernel for TP case.
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
if
self
.
moe_quant_config
:
if
self
.
moe_quant_config
and
(
(
not
self
.
moe
.
moe_parallel_config
.
use_all2all_kernels
)
or
self
.
moe
.
moe_parallel_config
.
use_naive_all2all_kernels
):
assert
self
.
experts_cls
is
not
None
self
.
kernel
,
self
.
use_inplace
=
make_fp8_moe_kernel
(
layer
=
layer
,
moe_quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
fp8_backend
=
self
.
fp8_backend
,
experts_cls
=
self
.
experts_cls
,
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
...
...
@@ -904,13 +910,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
if
self
.
fp8_backend
in
[
Fp8MoeBackend
.
AITER
,
Fp8MoeBackend
.
MARLIN
,
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
]:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
return
None
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
# For no-EP case, don't use the MKM framework.
if
not
self
.
moe
.
moe_parallel_config
.
use_all2all_kernels
:
return
None
prepare_finalize
=
build_flashinfer_fp8_cutlass_moe_prepare_finalize
(
self
.
moe
,
use_deepseek_fp8_block_scale
=
self
.
block_quant
,
...
...
@@ -924,73 +930,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
layer
:
torch
.
nn
.
Module
,
)
->
FusedMoEPermuteExpertsUnpermute
:
from
vllm.model_executor.layers.fused_moe
import
(
BatchedDeepGemmExperts
,
BatchedTritonExperts
,
TritonExperts
,
TritonOrDeepGemmExperts
,
)
if
self
.
fp8_backend
in
[
Fp8MoeBackend
.
MARLIN
,
Fp8MoeBackend
.
AITER
]:
raise
NotImplementedError
(
"Marlin and ROCm AITER are not supported with all2all yet."
)
assert
self
.
moe_quant_config
is
not
None
if
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
BatchedExperts
):
max_num_tokens_per_rank
=
prepare_finalize
.
max_num_tokens_per_rank
()
assert
max_num_tokens_per_rank
is
not
None
experts_impl
=
(
BatchedDeepGemmExperts
if
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
else
BatchedTritonExperts
)
logger
.
debug
(
"%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s"
,
experts_impl
.
__name__
,
self
.
__class__
.
__name__
,
max_num_tokens_per_rank
,
self
.
weight_block_size
,
False
,
)
return
experts_impl
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
quant_config
=
self
.
moe_quant_config
,
)
elif
self
.
moe
.
is_lora_enabled
:
return
TritonExperts
(
quant_config
=
self
.
moe_quant_config
)
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
# Select GEMM experts with block-scale when weights are block-quantized
experts
=
select_cutlass_fp8_gemm_impl
(
self
.
moe
,
self
.
moe_quant_config
,
use_deepseek_fp8_block_scale
=
self
.
block_quant
,
)
logger
.
debug_once
(
"Using %s"
,
experts
.
__class__
.
__name__
)
return
experts
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
:
logger
.
debug
(
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
self
.
weight_block_size
,
False
,
)
return
TritonOrDeepGemmExperts
(
self
.
moe_quant_config
)
else
:
assert
self
.
fp8_backend
==
Fp8MoeBackend
.
TRITON
logger
.
debug
(
"TritonExperts(%s): block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
self
.
weight_block_size
,
False
,
)
return
TritonExperts
(
self
.
moe_quant_config
)
assert
self
.
experts_cls
is
not
None
return
make_fp8_moe_kernel_for_mkm
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
experts_cls
=
self
.
experts_cls
,
prepare_finalize
=
prepare_finalize
,
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
...
...
@@ -1067,7 +1014,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routed_scaling
=
layer
.
routed_scaling_factor
,
)
else
:
re
sult
=
apply_fi_trtllm_fp8_per_tensor_moe
(
re
turn
apply_fi_trtllm_fp8_per_tensor_moe
(
layer
=
layer
,
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
42135d68
...
...
@@ -875,6 +875,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
return
BatchedMarlinExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
w13_g_idx
,
w2_g_idx
=
w2_g_idx
,
...
...
@@ -885,6 +886,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
else
:
# Standard Marlin experts for GPTQ
return
MarlinExperts
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
w13_g_idx
,
w2_g_idx
=
w2_g_idx
,
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
42135d68
...
...
@@ -27,15 +27,16 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend
,
convert_to_fp8_moe_kernel_format
,
make_fp8_moe_kernel
,
make_fp8_moe_kernel_for_mkm
,
make_fp8_moe_quant_config
,
select_fp8_moe_backend
,
)
from
vllm.model_executor.layers.fused_moe.oracle.nvfp4
import
(
FLASHINFER_NVFP4_MOE_BACKENDS
,
NvFp4MoeBackend
,
convert_to_nvfp4_moe_kernel_format
,
is_global_sf_supported_for_nvfp4_backend
,
make_nvfp4_moe_kernel
,
make_nvfp4_moe_kernel_for_mkm
,
make_nvfp4_moe_quant_config
,
select_nvfp4_moe_backend
,
)
...
...
@@ -57,12 +58,10 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize
,
flashinfer_trtllm_fp4_moe
,
flashinfer_trtllm_fp4_routed_moe
,
select_nvfp4_gemm_impl
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
apply_fi_trtllm_fp8_per_tensor_moe
,
build_flashinfer_fp8_cutlass_moe_prepare_finalize
,
select_cutlass_fp8_gemm_impl
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
W8A8BlockFp8LinearOp
,
...
...
@@ -84,6 +83,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
kFp8StaticTokenSym
,
kNvfp4Dynamic
,
kNvfp4Static
,
swizzle_blockscale
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
...
...
@@ -728,14 +729,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
super
().
__init__
(
moe_config
)
self
.
quant_config
=
quant_config
assert
self
.
quant_config
.
is_checkpoint_fp8_serialized
self
.
fp8_backend
=
select_fp8_moe_backend
(
block_quant
=
False
,
tp_size
=
moe_config
.
moe_parallel_config
.
tp_size
,
with_lora_support
=
self
.
moe
.
is_lora_enabled
,
is_act_and_mul
=
self
.
moe
.
is_act_and_mul
,
# Select Fp8 MoE backend
self
.
fp8_backend
,
self
.
experts_cls
=
select_fp8_moe_backend
(
config
=
self
.
moe
,
weight_key
=
kFp8StaticTensorSym
,
activation_key
=
kFp8StaticTensorSym
,
)
# Delay creation of the kernel until after process-weights.
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
@
property
def
topk_indices_dtype
(
self
)
->
torch
.
dtype
|
None
:
if
self
.
kernel
is
not
None
:
return
self
.
kernel
.
prepare_finalize
.
topk_indices_dtype
()
return
None
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
...
...
@@ -744,8 +754,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
return
None
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
#
TP case: avoid convert to ModularKernelMethod - to be refactored
.
if
self
.
moe
.
dp_size
==
1
:
#
For no-EP case, don't use the MKM framework
.
if
not
self
.
moe
.
moe_parallel_config
.
use_all2all_kernels
:
return
None
prepare_finalize
=
build_flashinfer_fp8_cutlass_moe_prepare_finalize
(
...
...
@@ -762,12 +772,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer
:
torch
.
nn
.
Module
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
assert
self
.
moe_quant_config
is
not
None
experts
=
select_cutlass_fp8_gemm_impl
(
self
.
moe
,
self
.
moe_quant_config
,
assert
self
.
experts_cls
is
not
None
return
make_fp8_moe_kernel_for_mkm
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
experts_cls
=
self
.
experts_cls
,
prepare_finalize
=
prepare_finalize
,
)
logger
.
debug_once
(
"Using %s"
,
experts
.
__class__
.
__name__
)
return
experts
def
create_weights
(
self
,
...
...
@@ -876,14 +887,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
replace_parameter
(
layer
,
"w13_weight_scale"
,
w13_scale
)
replace_parameter
(
layer
,
"w2_weight_scale"
,
w2_scale
)
# Setup modular kernel
for TP case
.
# Setup modular kernel.
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
if
self
.
moe_quant_config
:
assert
self
.
experts_cls
is
not
None
self
.
kernel
,
self
.
use_inplace
=
make_fp8_moe_kernel
(
layer
=
layer
,
moe_quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
fp8_backend
=
self
.
fp8_backend
,
experts_cls
=
self
.
experts_cls
,
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
...
...
@@ -1335,32 +1347,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
->
None
:
super
().
__init__
(
moe_config
)
self
.
quant_config
=
quant_config
self
.
nvfp4_backend
=
select_nvfp4_moe_backend
()
# TODO: move this type of check into the oracle.
if
not
self
.
moe
.
is_act_and_mul
and
self
.
nvfp4_backend
not
in
[
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
NvFp4MoeBackend
.
MARLIN
,
]:
raise
NotImplementedError
(
"Non-gated activations are only supported by FlashInfer "
f
"CUTLASS and Marlin NvFP4 MoE backends, not
{
self
.
nvfp4_backend
}
."
)
# Select experts implementation.
self
.
nvfp4_backend
,
self
.
experts_cls
=
select_nvfp4_moe_backend
(
config
=
self
.
moe
,
weight_key
=
kNvfp4Static
,
activation_key
=
kNvfp4Dynamic
,
)
# Delay creation of the kernel until after process-weights.
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
self
.
use_global_sf
=
is_global_sf_supported_for_nvfp4_backend
(
self
.
nvfp4_backend
)
self
.
kernel
:
mk
.
FusedMoEModularKernel
|
None
=
None
@
property
def
topk_indices_dtype
(
self
)
->
torch
.
dtype
|
None
:
if
self
.
kernel
is
not
None
:
return
self
.
kernel
.
prepare_finalize
.
topk_indices_dtype
()
return
None
def
maybe_make_prepare_finalize
(
self
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
UNSUPPORTED
=
[
NvFp4MoeBackend
.
MARLIN
,
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
]
if
self
.
nvfp4_backend
in
UNSUPPORTED
:
if
self
.
nvfp4_backend
==
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
:
return
None
elif
self
.
nvfp4_backend
==
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
:
#
TP case: avoid convert to ModularKernelMethod - to be refactored
.
if
self
.
moe
.
dp_size
==
1
:
#
For no-EP case, don't use the MKM framework
.
if
not
self
.
moe
.
moe_parallel_config
.
use_all2all_kernels
:
return
None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize
=
build_flashinfer_fp4_cutlass_moe_prepare_finalize
(
...
...
@@ -1377,13 +1392,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer
:
torch
.
nn
.
Module
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
assert
self
.
moe_quant_config
is
not
None
experts
=
select_nvfp4_gemm_impl
(
self
.
moe
,
self
.
moe_quant_config
,
allow_flashinfer
=
self
.
nvfp4_backend
in
FLASHINFER_NVFP4_MOE_BACKENDS
,
assert
self
.
experts_cls
is
not
None
return
make_nvfp4_moe_kernel_for_mkm
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
experts_cls
=
self
.
experts_cls
,
prepare_finalize
=
prepare_finalize
,
)
logger
.
debug_once
(
"Using %s"
,
experts
.
__class__
.
__name__
)
return
experts
def
uses_weight_scale_2_pattern
(
self
)
->
bool
:
"""
...
...
@@ -1554,13 +1569,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
replace_parameter
(
layer
,
"w2_weight_scale_2"
,
w2_scale_2
)
replace_parameter
(
layer
,
"w2_input_scale"
,
a2_scale
)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
use_dp
=
self
.
moe
.
dp_size
>
1
if
self
.
moe_quant_config
is
not
None
and
not
use_dp
:
if
self
.
moe_quant_config
and
(
(
not
self
.
moe
.
moe_parallel_config
.
use_all2all_kernels
)
or
self
.
moe
.
moe_parallel_config
.
use_naive_all2all_kernels
):
assert
self
.
experts_cls
is
not
None
self
.
kernel
=
make_nvfp4_moe_kernel
(
backend
=
self
.
nvfp4_backend
,
quant_config
=
self
.
moe_quant_config
,
moe_quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
experts_cls
=
self
.
experts_cls
,
)
@
property
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
42135d68
...
...
@@ -853,6 +853,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
quant_config
=
self
.
moe_quant_config
,
moe_config
=
self
.
moe
,
)
else
:
raise
NotImplementedError
(
...
...
@@ -875,11 +876,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
}
return
TrtLlmGenExperts
(
self
.
moe
,
self
.
moe_quant_config
,
**
kwargs
)
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
MARLIN
:
return
MarlinExperts
(
self
.
moe_quant_config
)
return
MarlinExperts
(
self
.
moe
,
self
.
moe_quant_config
)
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
TRITON
:
if
self
.
moe
.
is_lora_enabled
:
return
UnfusedOAITritonExperts
(
self
.
moe_quant_config
)
return
OAITritonExperts
(
self
.
moe_quant_config
)
return
UnfusedOAITritonExperts
(
self
.
moe
,
self
.
moe_quant_config
)
return
OAITritonExperts
(
self
.
moe
,
self
.
moe_quant_config
)
else
:
raise
NotImplementedError
(
f
"Incompatible Mxfp4 backend (
{
self
.
mxfp4_backend
}
) for EP"
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
View file @
42135d68
...
...
@@ -11,19 +11,16 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoE
Quant
Config
,
FusedMoE
Parallel
Config
,
RoutingMethodType
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe
import
(
FlashInferCuteDSLExperts
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
FlashInferExperts
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize
import
(
# noqa: E501
create_flashinfer_prepare_finalize
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kNvfp4Dynamic
,
kNvfp4Static
,
swizzle_blockscale
,
)
from
vllm.platforms
import
current_platform
...
...
@@ -47,6 +44,86 @@ __all__ = [
"build_flashinfer_fp4_cutlass_moe_prepare_finalize"
,
]
#
# Methods used by the oracle for kernel selection.
#
def
_supports_current_device
()
->
bool
:
"""Supports only Blackwell-family GPUs."""
p
=
current_platform
return
p
.
is_cuda
()
and
p
.
is_device_capability_family
(
100
)
def
_supports_no_act_and_mul
()
->
bool
:
"""Does not support non-gated MoE (i.e. Nemotron-Nano)."""
return
False
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
"""Supports Nvfp4 quantization."""
SUPPORTED_W_A
=
[
(
kNvfp4Static
,
kNvfp4Dynamic
),
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
def
_supports_activation
(
activation
:
str
)
->
bool
:
"""Supports silu activation only."""
return
activation
in
[
"silu"
]
def
_supports_routing_method
(
routing_method
:
RoutingMethodType
,
)
->
bool
:
"""Monolithic kernels need to express router support."""
# NOTE(rob): potentially allow others here. This is a conservative list.
return
routing_method
in
[
RoutingMethodType
.
DeepSeekV3
,
RoutingMethodType
.
Renormalize
,
RoutingMethodType
.
RenormalizeNaive
,
RoutingMethodType
.
Llama4
,
]
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
"""Supports EP."""
return
True
def
is_supported_config_trtllm
(
moe_config
:
FusedMoEConfig
,
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
activation_format
:
mk
.
FusedMoEActivationFormat
,
)
->
tuple
[
bool
,
str
|
None
]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
"""
def
_make_reason
(
reason
:
str
)
->
str
:
return
f
"kernel does not support
{
reason
}
"
if
not
_supports_current_device
():
return
False
,
_make_reason
(
"current device"
)
elif
not
(
moe_config
.
is_act_and_mul
or
_supports_no_act_and_mul
()):
return
False
,
_make_reason
(
"no act_and_mul MLP layer"
)
elif
not
_supports_activation
(
moe_config
.
activation
):
return
False
,
_make_reason
(
f
"
{
moe_config
.
activation
}
activation"
)
elif
not
_supports_quant_scheme
(
weight_key
,
activation_key
):
return
False
,
_make_reason
(
"quantization scheme"
)
elif
not
_supports_parallel_config
(
moe_config
.
moe_parallel_config
):
return
False
,
_make_reason
(
"parallel config"
)
elif
not
_supports_routing_method
(
moe_config
.
routing_method
):
return
False
,
_make_reason
(
"routing method"
)
elif
activation_format
!=
mk
.
FusedMoEActivationFormat
.
Standard
:
return
False
,
_make_reason
(
"activation format"
)
return
True
,
None
def
is_flashinfer_fp4_cutlass_moe_available
()
->
bool
:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
...
...
@@ -96,37 +173,6 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
)
def
select_nvfp4_gemm_impl
(
moe
:
FusedMoEConfig
,
moe_quant_config
:
FusedMoEQuantConfig
,
allow_flashinfer
:
bool
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
if
allow_flashinfer
:
if
envs
.
VLLM_FLASHINFER_MOE_BACKEND
==
"masked_gemm"
:
return
FlashInferCuteDSLExperts
(
out_dtype
=
moe
.
in_dtype
,
quant_config
=
moe_quant_config
,
)
elif
envs
.
VLLM_FLASHINFER_MOE_BACKEND
==
"throughput"
:
return
FlashInferExperts
(
out_dtype
=
moe
.
in_dtype
,
quant_config
=
moe_quant_config
,
ep_rank
=
moe
.
moe_parallel_config
.
ep_rank
,
ep_size
=
moe
.
moe_parallel_config
.
ep_size
,
tp_rank
=
moe
.
moe_parallel_config
.
tp_rank
,
tp_size
=
moe
.
moe_parallel_config
.
tp_size
,
use_dp
=
moe
.
moe_parallel_config
.
dp_size
>
1
,
)
# native cutlass experts currently don't support DP; TP case won't call this
raise
ValueError
(
"CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
"Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)"
)
def
prepare_static_weights_for_trtllm_fp4_moe
(
# args_dequant,
# args,
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
View file @
42135d68
...
...
@@ -9,10 +9,6 @@ from vllm import envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
FlashInferExperts
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize
import
(
# noqa: E501
create_flashinfer_prepare_finalize
,
...
...
@@ -203,33 +199,6 @@ def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
)
def
select_cutlass_fp8_gemm_impl
(
moe
:
FusedMoEConfig
|
None
,
quant_config
:
FusedMoEQuantConfig
,
out_dtype
:
torch
.
dtype
|
None
=
None
,
use_deepseek_fp8_block_scale
:
bool
=
False
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
"""Return a GEMM *experts* implementation for fused-MoE layers"""
if
moe
is
not
None
:
return
FlashInferExperts
(
out_dtype
=
moe
.
in_dtype
,
quant_config
=
quant_config
,
ep_rank
=
moe
.
moe_parallel_config
.
ep_rank
,
ep_size
=
moe
.
moe_parallel_config
.
ep_size
,
tp_rank
=
moe
.
moe_parallel_config
.
tp_rank
,
tp_size
=
moe
.
moe_parallel_config
.
tp_size
,
use_deepseek_fp8_block_scale
=
use_deepseek_fp8_block_scale
,
)
assert
out_dtype
is
not
None
,
"If moe config is None, out_dtype must be passed"
return
FlashInferExperts
(
out_dtype
=
out_dtype
,
quant_config
=
quant_config
,
use_deepseek_fp8_block_scale
=
use_deepseek_fp8_block_scale
,
)
def
get_flashinfer_moe_backend
()
->
FlashinferMoeBackend
:
backend_map
=
{
"throughput"
:
FlashinferMoeBackend
.
CUTLASS
,
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
42135d68
...
...
@@ -48,6 +48,7 @@ class GroupShape(_GroupShape):
# Aliases for common quantization group shapes
PER_TENSOR
:
ClassVar
[
"GroupShape"
]
PER_TOKEN
:
ClassVar
[
"GroupShape"
]
PER_CHANNEL
:
ClassVar
[
"GroupShape"
]
def
is_per_tensor
(
self
)
->
bool
:
return
self
.
row
==
-
1
and
self
.
col
==
-
1
...
...
@@ -55,12 +56,16 @@ class GroupShape(_GroupShape):
def
is_per_token
(
self
)
->
bool
:
return
self
.
row
==
1
and
self
.
col
==
-
1
def
is_per_channel
(
self
)
->
bool
:
return
self
.
row
==
-
1
and
self
.
col
==
1
def
is_per_group
(
self
)
->
bool
:
return
self
.
row
==
1
and
self
.
col
>=
1
GroupShape
.
PER_TENSOR
=
GroupShape
(
-
1
,
-
1
)
GroupShape
.
PER_TOKEN
=
GroupShape
(
1
,
-
1
)
GroupShape
.
PER_CHANNEL
=
GroupShape
(
-
1
,
1
)
@
dataclass
(
frozen
=
True
)
...
...
@@ -77,16 +82,12 @@ class ScaleDesc:
group_shape
:
GroupShape
def
__str__
(
self
):
group_shape
=
(
"per_tensor"
if
self
.
group_shape
==
GroupShape
.
PER_TENSOR
else
(
"per_token"
if
self
.
group_shape
==
GroupShape
.
PER_TOKEN
else
str
(
self
.
group_shape
)
)
)
d
=
{
GroupShape
.
PER_TENSOR
:
"per_tensor"
,
GroupShape
.
PER_TOKEN
:
"per_token"
,
GroupShape
.
PER_CHANNEL
:
"per_channel"
,
}
group_shape
=
d
.
get
(
self
.
group_shape
,
str
(
self
.
group_shape
))
return
(
f
"
{
fx
.
graph
.
dtype_abbrs
[
self
.
dtype
]
}
,"
f
"
{
'static'
if
self
.
static
else
'dynamic'
}
,
{
group_shape
}
"
...
...
@@ -126,15 +127,28 @@ kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
kStaticTokenScale
=
ScaleDesc
(
torch
.
float32
,
True
,
GroupShape
.
PER_TOKEN
)
kFp8StaticTokenSym
=
QuantKey
(
FP8_DTYPE
,
kStaticTokenScale
,
symmetric
=
True
)
kStaticChannelScale
=
ScaleDesc
(
torch
.
float32
,
True
,
GroupShape
.
PER_CHANNEL
)
kFp8StaticChannelSym
=
QuantKey
(
FP8_DTYPE
,
kStaticChannelScale
,
symmetric
=
True
)
kDynamicTokenScale
=
ScaleDesc
(
torch
.
float32
,
False
,
GroupShape
.
PER_TOKEN
)
kFp8DynamicTokenSym
=
QuantKey
(
FP8_DTYPE
,
kDynamicTokenScale
,
symmetric
=
True
)
kNvfp4GroupScale
=
ScaleDesc
(
FP8_DTYPE
,
False
,
GroupShape
(
1
,
16
))
kNvfp4Quant
=
QuantKey
(
FP4_DTYPE
,
scale
=
kNvfp4GroupScale
,
scale2
=
kStaticTensorScale
)
kNvfp4DynamicGroupScale
=
ScaleDesc
(
FP8_DTYPE
,
False
,
GroupShape
(
1
,
16
))
kNvfp4Dynamic
=
QuantKey
(
FP4_DTYPE
,
scale
=
kNvfp4DynamicGroupScale
,
scale2
=
kStaticTensorScale
)
kNvfp4StaticGroupScale
=
ScaleDesc
(
FP8_DTYPE
,
True
,
GroupShape
(
1
,
16
))
kNvfp4Static
=
QuantKey
(
FP4_DTYPE
,
scale
=
kNvfp4StaticGroupScale
,
scale2
=
kStaticTensorScale
)
kDynamic128Scale
=
ScaleDesc
(
torch
.
float32
,
False
,
GroupShape
(
1
,
128
))
kFp8Dynamic128Sym
=
QuantKey
(
FP8_DTYPE
,
kDynamic128Scale
,
symmetric
=
True
)
kStatic128BlockScale
=
ScaleDesc
(
torch
.
float32
,
True
,
GroupShape
(
128
,
128
))
kFp8Static128BlockSym
=
QuantKey
(
FP8_DTYPE
,
kStatic128BlockScale
,
symmetric
=
True
)
kDynamic64Scale
=
ScaleDesc
(
torch
.
float32
,
False
,
GroupShape
(
1
,
64
))
kFp8Dynamic64Sym
=
QuantKey
(
FP8_DTYPE
,
kDynamic64Scale
,
symmetric
=
True
)
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
42135d68
...
...
@@ -43,7 +43,6 @@ from vllm.distributed import (
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.config
import
RoutingMethodType
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -172,7 +171,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
,
routing_method_type
=
RoutingMethodType
.
Renormalize
,
)
self
.
gate
=
ReplicatedLinear
(
...
...
vllm/model_executor/models/qwen3_next.py
View file @
42135d68
...
...
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule
,
)
from
vllm.model_executor.layers.fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.fused_moe.config
import
RoutingMethodType
from
vllm.model_executor.layers.layernorm
import
(
GemmaRMSNorm
as
Qwen3NextRMSNorm
,
)
...
...
@@ -181,7 +180,6 @@ class Qwen3NextSparseMoeBlock(nn.Module):
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
,
routing_method_type
=
RoutingMethodType
.
Renormalize
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment