Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
42135d68
Unverified
Commit
42135d68
authored
Jan 21, 2026
by
Robert Shaw
Committed by
GitHub
Jan 21, 2026
Browse files
[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)
parent
e14467be
Changes
82
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1739 additions
and
857 deletions
+1739
-857
vllm/compilation/activation_quant_fusion.py
vllm/compilation/activation_quant_fusion.py
+3
-3
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+2
-2
vllm/compilation/fusion_attn.py
vllm/compilation/fusion_attn.py
+2
-2
vllm/compilation/matcher_utils.py
vllm/compilation/matcher_utils.py
+2
-2
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+48
-13
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+34
-1
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+164
-91
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+39
-93
vllm/model_executor/layers/fused_moe/fallback.py
vllm/model_executor/layers/fused_moe/fallback.py
+70
-8
vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
...model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
+50
-41
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
...model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+93
-106
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
.../model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+105
-1
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+121
-26
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+80
-46
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+120
-77
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+47
-28
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+36
-34
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+115
-6
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+350
-186
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+258
-91
No files found.
vllm/compilation/activation_quant_fusion.py
View file @
42135d68
...
...
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8StaticTensorSym
,
kNvfp4
Quant
,
kNvfp4
Dynamic
,
)
from
vllm.platforms
import
current_platform
...
...
@@ -41,7 +41,7 @@ silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
torch
.
ops
.
_C
,
"silu_and_mul_nvfp4_quant"
)
if
silu_and_mul_nvfp4_quant_supported
:
FUSED_OPS
[
kNvfp4
Quant
]
=
torch
.
ops
.
_C
.
silu_and_mul_nvfp4_quant
.
default
# noqa: E501
FUSED_OPS
[
kNvfp4
Dynamic
]
=
torch
.
ops
.
_C
.
silu_and_mul_nvfp4_quant
.
default
# noqa: E501
class
ActivationQuantPattern
(
ABC
):
...
...
@@ -129,7 +129,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
"""
def
__init__
(
self
)
->
None
:
super
().
__init__
(
kNvfp4
Quant
)
super
().
__init__
(
kNvfp4
Dynamic
)
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
result
=
self
.
empty_quant
(
5
,
32
)
...
...
vllm/compilation/fusion.py
View file @
42135d68
...
...
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTensorSym
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
kNvfp4
Quant
,
kNvfp4
Dynamic
,
kStaticTensorScale
,
)
from
vllm.platforms
import
current_platform
...
...
@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8DynamicTokenSym
:
torch
.
ops
.
_C
.
dynamic_per_token_scaled_fp8_quant
.
default
,
# noqa: E501
}
if
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
QUANT_OPS
[
kNvfp4
Quant
]
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
default
QUANT_OPS
[
kNvfp4
Dynamic
]
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
default
if
current_platform
.
is_cuda
():
QUANT_OPS
[
kFp8Dynamic128Sym
]
=
torch
.
ops
.
_C
.
per_token_group_fp8_quant
.
default
# noqa: E501
QUANT_OPS
[
kFp8Dynamic64Sym
]
=
torch
.
ops
.
_C
.
per_token_group_fp8_quant
.
default
# noqa: E501
...
...
vllm/compilation/fusion_attn.py
View file @
42135d68
...
...
@@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kNvfp4
Quant
,
kNvfp4
Dynamic
,
kStaticTensorScale
,
)
from
vllm.platforms
import
current_platform
...
...
@@ -217,7 +217,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
def
__init__
(
self
,
layer
:
Attention
,
dtype
:
torch
.
dtype
)
->
None
:
super
().
__init__
(
layer
,
kNvfp4
Quant
,
dtype
)
super
().
__init__
(
layer
,
kNvfp4
Dynamic
,
dtype
)
def
_register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
def
pattern
(
...
...
vllm/compilation/matcher_utils.py
View file @
42135d68
...
...
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTensorSym
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
kNvfp4
Quant
,
kNvfp4
Dynamic
,
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
...
...
@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
}
if
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
QUANT_OPS
[
kNvfp4
Quant
]
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
default
# noqa: E501
QUANT_OPS
[
kNvfp4
Dynamic
]
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
default
# noqa: E501
if
current_platform
.
is_cuda
():
QUANT_OPS
[
kFp8Dynamic128Sym
]
=
torch
.
ops
.
_C
.
per_token_group_fp8_quant
.
default
# noqa: E501
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
42135d68
...
...
@@ -7,11 +7,20 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.forward_context
import
get_forward_context
,
is_forward_context_available
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8Dynamic128Sym
,
kFp8Static128BlockSym
,
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.deep_gemm
import
(
...
...
@@ -19,6 +28,7 @@ from vllm.utils.deep_gemm import (
fp8_m_grouped_gemm_nt_masked
,
get_mk_alignment_for_contiguous_layout
,
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
,
)
from
vllm.utils.math_utils
import
cdiv
,
round_up
...
...
@@ -253,29 +263,52 @@ def persistent_masked_m_silu_mul_quant(
class
BatchedDeepGemmExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
max_num_tokens
:
int
,
num_dispatchers
:
int
,
quant_config
:
FusedMoEQuantConfig
,
):
"""
max_num_tokens: Maximum number of tokens from a DP Rank
num_dispatchers: The number of DP dispatchers.
quant_config: Quantization configuration
"""
super
().
__init__
(
quant_config
)
super
().
__init__
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
)
assert
self
.
block_shape
==
get_mk_alignment_for_contiguous_layout
()
assert
self
.
quant_config
.
use_fp8_w8a8
self
.
max_num_tokens
=
max_num_tokens
self
.
num_dispatchers
=
num_dispatchers
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
BatchedExperts
,
mk
.
FusedMoEActivationFormat
.
BatchedExperts
,
)
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
BatchedExperts
@
staticmethod
def
_supports_current_device
()
->
bool
:
return
is_deep_gemm_supported
()
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
return
False
@
staticmethod
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
SUPPORTED_W_A
=
[(
kFp8Static128BlockSym
,
kFp8Dynamic128Sym
)]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
]
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
return
True
def
supports_chunking
(
self
)
->
bool
:
return
False
...
...
@@ -310,6 +343,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
assert
self
.
num_dispatchers
is
not
None
assert
self
.
max_num_tokens
is
not
None
num_dispatchers
=
self
.
num_dispatchers
num_experts
=
local_num_experts
max_num_tokens
=
M
if
self
.
max_num_tokens
is
None
else
self
.
max_num_tokens
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
42135d68
...
...
@@ -862,6 +862,7 @@ class FusedMoEParallelConfig:
use_ep
:
bool
# whether to use EP or not
all2all_backend
:
str
# all2all backend for MoE communication
enable_eplb
:
bool
# whether to enable expert load balancing
@
property
def
use_all2all_kernels
(
self
):
...
...
@@ -882,6 +883,16 @@ class FusedMoEParallelConfig:
def
use_deepep_ll_kernels
(
self
):
return
self
.
use_all2all_kernels
and
self
.
all2all_backend
==
"deepep_low_latency"
@
property
def
use_batched_activation_format
(
self
):
return
self
.
use_deepep_ll_kernels
or
self
.
use_pplx_kernels
@
property
def
use_naive_all2all_kernels
(
self
):
return
self
.
use_all2all_kernels
and
(
self
.
all2all_backend
in
[
"naive"
,
"allgather_reducescatter"
]
)
@
staticmethod
def
flatten_tp_across_dp_and_pcp
(
tp_size
:
int
,
dp_size
:
int
,
dp_rank
:
int
,
pcp_size
:
int
,
pcp_rank
:
int
...
...
@@ -999,6 +1010,7 @@ class FusedMoEParallelConfig:
ep_rank
=
0
,
use_ep
=
False
,
all2all_backend
=
vllm_parallel_config
.
all2all_backend
,
enable_eplb
=
vllm_parallel_config
.
enable_eplb
,
)
# DP + EP / TP + EP / DP + TP + EP
assert
use_ep
...
...
@@ -1017,6 +1029,24 @@ class FusedMoEParallelConfig:
ep_rank
=
ep_rank
,
use_ep
=
True
,
all2all_backend
=
vllm_parallel_config
.
all2all_backend
,
enable_eplb
=
vllm_parallel_config
.
enable_eplb
,
)
@
classmethod
def
make_no_parallel
(
cls
)
->
"FusedMoEParallelConfig"
:
"""For usage in CI/CD and testing."""
return
FusedMoEParallelConfig
(
tp_size
=
1
,
tp_rank
=
0
,
pcp_size
=
1
,
pcp_rank
=
0
,
dp_size
=
1
,
dp_rank
=
0
,
ep_size
=
1
,
ep_rank
=
0
,
use_ep
=
False
,
all2all_backend
=
"naive"
,
enable_eplb
=
False
,
)
...
...
@@ -1026,8 +1056,11 @@ class FusedMoEConfig:
num_experts
:
int
experts_per_token
:
int
hidden_dim
:
int
intermediate_size_per_partition
:
int
num_local_experts
:
int
activation
:
str
device
:
torch
.
device
|
str
routing_method
:
RoutingMethodType
moe_parallel_config
:
FusedMoEParallelConfig
# The activation type.
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
42135d68
...
...
@@ -7,7 +7,11 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
moe_permute
,
moe_unpermute
,
...
...
@@ -23,6 +27,19 @@ from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache
,
apply_moe_activation
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8DynamicTensorSym
,
kFp8DynamicTokenSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
kNvfp4Dynamic
,
kNvfp4Static
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
cutlass_group_gemm_supported
,
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
...
...
@@ -238,29 +255,57 @@ def run_cutlass_moe_fp8(
class
CutlassExpertsFp8Base
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
e
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
|
None
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
device
:
torch
.
dtype
,
max_num_tokens
:
int
|
None
=
None
,
num_dispatchers
:
int
|
None
=
None
,
):
super
().
__init__
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
)
assert
quant_config
.
use_fp8_w8a8
super
().
__init__
(
quant_config
)
# E: num_experts
# N: intermediate size per partition
# K: hidden dim
e
=
moe_config
.
num_local_experts
n
=
moe_config
.
intermediate_size_per_partition
k
=
moe_config
.
hidden_dim
device
=
moe_config
.
device
ab_strides1_c_strides2
=
torch
.
full
((
e
,),
k
,
device
=
device
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
e
,),
n
,
device
=
device
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
e
,),
2
*
n
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
out_dtype
=
out
_dtype
self
.
out_dtype
=
moe_config
.
in
_dtype
self
.
ab_strides1
=
ab_strides1_c_strides2
self
.
ab_strides2
=
ab_strides2
self
.
c_strides1
=
c_strides1
self
.
c_strides2
=
ab_strides1_c_strides2
@
staticmethod
def
_supports_current_device
()
->
bool
:
return
cutlass_group_gemm_supported
()
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
return
False
@
staticmethod
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
SUPPORTED_W_A
=
[
(
kFp8StaticChannelSym
,
kFp8DynamicTokenSym
),
(
kFp8StaticTensorSym
,
kFp8DynamicTensorSym
),
(
kFp8StaticTensorSym
,
kFp8StaticTensorSym
),
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
,
"gelu"
,
"swigluoai"
]
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
# Let PrepareAndFinalize::finalize() decide the impl.
return
TopKWeightAndReduceDelegate
()
...
...
@@ -291,7 +336,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens
=
expert_tokens_meta
.
expert_num_tokens
use_batched_format
=
(
self
.
activation_format
s
[
0
]
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
self
.
activation_format
()
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
)
in_dtype
=
hidden_states
.
dtype
...
...
@@ -324,20 +369,23 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
class
CutlassExpertsFp8
(
CutlassExpertsFp8Base
):
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
,
)
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
# CutlassExpertsFp8 does not support expert map, which is
# needed for STANDARD activation format kernels in DP/EP mode.
# Note that the BATCHED activation format does not use
# the expert map for identifying experts.
return
not
moe_parallel_config
.
use_all2all_kernels
def
supports_chunking
(
self
)
->
bool
:
return
True
def
supports_expert_map
(
self
)
->
bool
:
return
Tru
e
return
Fals
e
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
# topk weights and reduction are fused in moe_unpermute cuda kernel
...
...
@@ -365,26 +413,16 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
class
CutlassBatchedExpertsFp8
(
CutlassExpertsFp8Base
):
def
__init__
(
self
,
max_experts_per_worker
:
int
,
num_dispatchers
:
int
,
*
args
,
**
kwargs
,
):
super
().
__init__
(
*
args
,
**
kwargs
)
assert
max_experts_per_worker
>
0
self
.
max_experts_per_worker
=
max_experts_per_worker
self
.
num_dispatchers
=
num_dispatchers
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
# BATCHED activation format works with EP because
# expert_map is not used to identify experts (the
# info is encoded/managed by the P/F logic).
return
True
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
BatchedExperts
,
mk
.
FusedMoEActivationFormat
.
BatchedExperts
,
)
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
BatchedExperts
def
supports_chunking
(
self
)
->
bool
:
return
False
...
...
@@ -408,14 +446,15 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
num_dp
=
self
.
num_dispatchers
assert
num_dp
is
not
None
experts_per_worker
=
self
.
moe_config
.
num_local_experts
activation_out_dim
=
self
.
adjust_N_for_activation
(
N
,
activation
)
workspace1
=
(
self
.
max_
experts_per_worker
,
M
*
num_dp
,
max
(
N
,
K
))
workspace1
=
(
experts_per_worker
,
M
*
num_dp
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_
experts_per_worker
,
experts_per_worker
,
M
*
num_dp
,
max
(
activation_out_dim
,
K
),
)
output
=
(
self
.
max_
experts_per_worker
,
M
,
K
)
output
=
(
experts_per_worker
,
M
,
K
)
return
(
workspace1
,
workspace2
,
output
)
...
...
@@ -601,34 +640,41 @@ def run_cutlass_moe_fp4(
return
# Split into batched and non-batched
class
CutlassExpertsFp4
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
max_experts_per_worker
:
int
,
out_dtype
:
torch
.
dtype
,
quant_config
:
FusedMoEQuantConfig
,
use_batched_format
:
bool
=
False
,
):
super
().
__init__
(
quant_config
)
self
.
max_experts_per_worker
=
max_experts_per_worker
self
.
out_dtype
=
out_dtype
self
.
use_batched_format
=
use_batched_format
@
staticmethod
def
expects_unquantized_inputs
(
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
)
->
bool
:
return
True
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
if
self
.
use_batched_format
:
return
(
mk
.
FusedMoEActivationFormat
.
BatchedExperts
,
mk
.
FusedMoEActivationFormat
.
BatchedExperts
,
)
else
:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
,
)
@
staticmethod
def
_supports_current_device
()
->
bool
:
return
current_platform
.
has_device_capability
((
10
,
0
))
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
return
False
@
staticmethod
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
return
(
weight_key
,
activation_key
)
==
(
kNvfp4Static
,
kNvfp4Dynamic
)
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
,
"gelu"
,
"swigluoai"
]
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
# CutlassExpertsFp4 does not support expert map, which is
# needed for STANDARD activation format kernels in EP mode.
return
moe_parallel_config
.
ep_size
==
1
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
def
supports_expert_map
(
self
)
->
bool
:
return
False
...
...
@@ -640,7 +686,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return
TopKWeightAndReduceNoOP
()
def
workspace_dtype
(
self
,
act_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
return
self
.
out_dtype
if
self
.
out_dtype
is
not
None
else
act_dtype
return
act_dtype
def
workspace_shapes
(
self
,
...
...
@@ -653,18 +699,9 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
activation
:
str
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
activation_out_dim
=
self
.
adjust_N_for_activation
(
N
,
activation
)
workspace1
:
tuple
[
int
,
...]
=
()
workspace2
:
tuple
[
int
,
...]
=
()
output
:
tuple
[
int
,
...]
=
()
if
self
.
use_batched_format
:
workspace1
=
(
self
.
max_experts_per_worker
,
M
,
max
(
N
,
K
))
workspace2
=
(
self
.
max_experts_per_worker
,
M
,
activation_out_dim
)
output
=
(
self
.
max_experts_per_worker
,
M
,
K
)
else
:
workspace1
=
(
M
*
topk
,
max
(
2
*
N
,
K
))
workspace2
=
(
M
*
topk
,
N
)
output
=
(
M
,
K
)
workspace1
=
(
M
*
topk
,
max
(
2
*
N
,
K
))
workspace2
=
(
M
*
topk
,
N
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
)
def
apply
(
...
...
@@ -869,10 +906,11 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
c_strides2
:
torch
.
Tensor
,
s_strides1
:
torch
.
Tensor
,
s_strides2
:
torch
.
Tensor
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
group_size
:
int
,
):
super
().
__init__
(
quant_config
)
super
().
__init__
(
moe_config
=
moe_config
,
quant_config
=
quant_config
)
self
.
out_dtype
=
out_dtype
self
.
a_strides1
=
a_strides1
self
.
a_strides2
=
a_strides2
...
...
@@ -884,13 +922,46 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
self
.
s_strides2
=
s_strides2
self
.
group_size
=
group_size
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
,
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
@
staticmethod
def
_supports_current_device
()
->
bool
:
raise
NotImplementedError
(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
raise
NotImplementedError
(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@
staticmethod
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
raise
NotImplementedError
(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
raise
NotImplementedError
(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
raise
NotImplementedError
(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
def
supports_chunking
(
self
)
->
bool
:
...
...
@@ -947,7 +1018,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens
=
None
use_batched_format
=
(
self
.
activation_format
s
[
0
]
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
self
.
activation_format
()
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
)
assert
not
use_batched_format
,
"batched format not supported"
...
...
@@ -1003,6 +1074,7 @@ def cutlass_moe_w4a8_fp8(
s_strides1
:
torch
.
Tensor
,
s_strides2
:
torch
.
Tensor
,
quant_config
:
FusedMoEQuantConfig
,
moe_config
:
FusedMoEConfig
,
activation
:
str
=
"silu"
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
...
...
@@ -1076,6 +1148,7 @@ def cutlass_moe_w4a8_fp8(
c_strides2
=
c_strides2
,
s_strides1
=
s_strides1
,
s_strides2
=
s_strides2
,
moe_config
=
moe_config
,
quant_config
=
quant_config
,
group_size
=
group_size
,
),
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
42135d68
...
...
@@ -6,17 +6,15 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_utils
import
(
compute_aligned_M
,
deepgemm_moe_permute
,
deepgemm_unpermute_and_reduce
,
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
...
...
@@ -26,9 +24,15 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8_packed_for_deepgemm
,
silu_mul_per_token_group_quant_fp8_colmajor
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8Dynamic128Sym
,
kFp8Static128BlockSym
,
)
from
vllm.utils.deep_gemm
import
(
DeepGemmQuantScaleFMT
,
get_mk_alignment_for_contiguous_layout
,
is_deep_gemm_supported
,
m_grouped_fp8_gemm_nt_contiguous
,
)
from
vllm.utils.import_utils
import
has_deep_gemm
...
...
@@ -109,21 +113,42 @@ def _valid_deep_gemm(
class
DeepGemmExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
quant_config
:
FusedMoEQuantConfig
):
super
().
__init__
(
quant_config
)
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
):
super
().
__init__
(
moe_config
=
moe_config
,
quant_config
=
quant_config
)
assert
quant_config
.
block_shape
==
get_mk_alignment_for_contiguous_layout
()
assert
quant_config
.
quant_dtype
==
torch
.
float8_e4m3fn
assert
not
quant_config
.
per_act_token_quant
assert
not
quant_config
.
per_out_ch_quant
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
,
)
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
@
staticmethod
def
_supports_current_device
()
->
bool
:
return
is_deep_gemm_supported
()
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
return
False
@
staticmethod
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
SUPPORTED_W_A
=
[
(
kFp8Static128BlockSym
,
kFp8Dynamic128Sym
),
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
]
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
return
True
def
supports_chunking
(
self
)
->
bool
:
return
True
...
...
@@ -283,82 +308,3 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_map
=
expert_map
,
output
=
output
,
)
def
deep_gemm_moe_fp8
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
a1_scale
:
torch
.
Tensor
|
None
=
None
,
a2_scale
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
)
->
torch
.
Tensor
:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with DeepGemm
grouped gemm.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1 (torch.Tensor): The first set of fp8 quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2 (torch.Tensor): The second set of fp8 quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
quant_config
=
fp8_w8a8_moe_quant_config
(
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
get_mk_alignment_for_contiguous_layout
(),
)
fn
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
DeepGemmExperts
(
quant_config
),
)
return
fn
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
inplace
=
inplace
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
vllm/model_executor/layers/fused_moe/fallback.py
View file @
42135d68
...
...
@@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEParallelConfig
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
QuantKey
class
FallbackExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
,
ABC
):
...
...
@@ -16,18 +18,78 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
experts
:
mk
.
FusedMoEPermuteExpertsUnpermute
,
fallback_experts
:
mk
.
FusedMoEPermuteExpertsUnpermute
,
):
super
().
__init__
(
experts
.
quant_config
)
super
().
__init__
(
moe_config
=
experts
.
moe_config
,
quant_config
=
experts
.
quant_config
)
self
.
fallback_experts
=
fallback_experts
self
.
experts
=
experts
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
assert
(
self
.
fallback_experts
.
activation_formats
==
self
.
experts
.
activation_formats
@
staticmethod
def
get_clses
()
->
tuple
[
type
[
mk
.
FusedMoEPermuteExpertsUnpermute
],
type
[
mk
.
FusedMoEPermuteExpertsUnpermute
],
]:
"""
Get the cls for the experts and fallback experts.
Subclasses should implement this method, so that
we have a consistent way to call the _supports_*
class methods below.
"""
raise
NotImplementedError
(
"Subclasses must return the cls for the experts and fallback experts."
)
@
classmethod
def
activation_format
(
cls
:
type
[
"FallbackExperts"
],
)
->
mk
.
FusedMoEActivationFormat
:
experts_cls
,
fallback_cls
=
cls
.
get_clses
()
assert
experts_cls
.
activation_format
()
==
fallback_cls
.
activation_format
()
return
experts_cls
.
activation_format
()
@classmethod
def _supports_current_device(cls) -> bool:
    """Supported only when both the experts and fallback classes agree."""
    primary_cls, backup_cls = cls.get_clses()
    primary_ok = primary_cls._supports_current_device()
    # Short-circuit: the fallback class is only queried if the primary passes.
    return primary_ok and backup_cls._supports_current_device()
@classmethod
def _supports_no_act_and_mul(cls) -> bool:
    """Supported only when both the experts and fallback classes agree."""
    primary_cls, backup_cls = cls.get_clses()
    primary_ok = primary_cls._supports_no_act_and_mul()
    # Short-circuit: the fallback class is only queried if the primary passes.
    return primary_ok and backup_cls._supports_no_act_and_mul()
return
self
.
fallback_experts
.
activation_formats
@classmethod
def _supports_quant_scheme(
    cls,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Both wrapped classes must accept the (weight, activation) scheme."""
    primary_cls, backup_cls = cls.get_clses()
    primary_ok = primary_cls._supports_quant_scheme(weight_key, activation_key)
    # Short-circuit: the fallback class is only queried if the primary passes.
    return primary_ok and backup_cls._supports_quant_scheme(
        weight_key, activation_key
    )
@classmethod
def _supports_activation(cls, activation: str) -> bool:
    """Both wrapped classes must accept the named activation."""
    primary_cls, backup_cls = cls.get_clses()
    primary_ok = primary_cls._supports_activation(activation)
    # Short-circuit: the fallback class is only queried if the primary passes.
    return primary_ok and backup_cls._supports_activation(activation)
@classmethod
def _supports_parallel_config(
    cls, moe_parallel_config: FusedMoEParallelConfig
) -> bool:
    """Both wrapped classes must accept the parallel configuration."""
    primary_cls, backup_cls = cls.get_clses()
    primary_ok = primary_cls._supports_parallel_config(moe_parallel_config)
    # Short-circuit: the fallback class is only queried if the primary passes.
    return primary_ok and backup_cls._supports_parallel_config(moe_parallel_config)
def
supports_chunking
(
self
)
->
bool
:
assert
(
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
View file @
42135d68
...
...
@@ -6,13 +6,22 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kNvfp4Dynamic
,
kNvfp4Static
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
(
flashinfer_cutedsl_grouped_gemm_nt_masked
,
has_flashinfer_cutedsl_grouped_gemm_nt_masked
,
scaled_fp4_grouped_quantize
,
silu_and_mul_scaled_nvfp4_experts_quantize
,
)
...
...
@@ -20,54 +29,54 @@ from vllm.utils.flashinfer import (
logger
=
init_logger
(
__name__
)
def is_valid_flashinfer_cutedsl_fused_moe(
    hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
) -> bool:
    """
    Decide whether the FlashInfer CuteDSL MoE kernel can be used.

    Returns True only when the CuteDSL grouped GEMM is available, both
    weight tensors are ``torch.uint8``, and ``hidden_states`` is float32,
    float16 or bfloat16. Each rejection is reported once at debug level.
    """
    if not has_flashinfer_cutedsl_grouped_gemm_nt_masked():
        logger.debug_once(
            "FlashInferCuteDSLExperts disabled: "
            "flashinfer_cutedsl_fused_moe not available."
        )
        return False

    # Data type checks
    weights_are_uint8 = w1.dtype == torch.uint8 and w2.dtype == torch.uint8
    activation_dtypes = [torch.float32, torch.float16, torch.bfloat16]
    if weights_are_uint8 and hidden_states.dtype in activation_dtypes:
        return True

    logger.debug_once(
        "FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 "
        f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
        f"float32, float16, or bfloat16 (got {hidden_states.dtype})."
    )
    return False
class
FlashInferCuteDSLExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
out_dtype
:
torch
.
dtype
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
max_num_tokens
:
int
,
num_dispatchers
:
int
,
):
super
().
__init__
(
quant_config
)
super
().
__init__
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
)
assert
quant_config
.
quant_dtype
==
"nvfp4"
,
(
"Only nvfp4 quantization are currently supported."
)
self
.
out_dtype
=
out
_dtype
self
.
out_dtype
=
moe_config
.
in
_dtype
@property
def activation_formats(
    self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
    """Return a pair of activation formats; both entries are BatchedExperts."""
    return (
        mk.FusedMoEActivationFormat.BatchedExperts,
        mk.FusedMoEActivationFormat.BatchedExperts,
    )
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Return the activation format reported by this class: BatchedExperts."""
    return mk.FusedMoEActivationFormat.BatchedExperts
@staticmethod
def _supports_current_device() -> bool:
    """Return whether the current platform is in device capability family 100."""
    return current_platform.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
    """Always False: this class does not report no-act-and-mul support."""
    return False
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """
    Accept only the (kNvfp4Static, kNvfp4Dynamic) weight/activation pair.
    """
    scheme = (weight_key, activation_key)
    return scheme in [(kNvfp4Static, kNvfp4Dynamic)]
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Always True: the argument is not inspected."""
    return True
def supports_expert_map(self) -> bool:
    """Always False: this class does not report expert-map support."""
    return False
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
View file @
42135d68
...
...
@@ -5,13 +5,22 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize
import
(
# noqa: E501
create_flashinfer_prepare_finalize
,
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8Dynamic128Sym
,
kFp8Static128BlockSym
,
kFp8StaticTensorSym
,
kNvfp4Dynamic
,
kNvfp4Static
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
(
flashinfer_cutlass_fused_moe
,
has_flashinfer_cutlass_fused_moe
,
...
...
@@ -50,40 +59,100 @@ def is_valid_flashinfer_cutlass_fused_moe(
class
FlashInferExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
out_dtype
:
torch
.
dtype
,
moe_config
:
mk
.
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
ep_rank
:
int
=
0
,
ep_size
:
int
=
1
,
tp_rank
:
int
=
0
,
tp_size
:
int
=
1
,
use_dp
:
bool
=
False
,
use_deepseek_fp8_block_scale
:
bool
=
False
,
):
super
().
__init__
(
quant_config
)
super
().
__init__
(
moe_config
,
quant_config
)
assert
quant_config
.
quant_dtype
in
(
"nvfp4"
,
torch
.
float8_e4m3fn
,
None
),
(
"Only nvfp4, fp8, bfloat16 and"
" float16 quantization are currently supported."
)
self
.
ep_rank
=
ep_rank
self
.
ep_size
=
ep_size
self
.
tp_rank
=
tp_rank
self
.
tp_size
=
tp_size
self
.
out_dtype
=
out
_dtype
self
.
use_dp
=
use_dp
self
.
ep_rank
=
moe_config
.
moe_parallel_config
.
ep_rank
self
.
ep_size
=
moe_config
.
moe_parallel_config
.
ep_size
self
.
tp_rank
=
moe_config
.
moe_parallel_config
.
tp_rank
self
.
tp_size
=
moe_config
.
moe_parallel_config
.
tp_size
self
.
out_dtype
=
moe_config
.
in
_dtype
self
.
use_dp
=
moe_config
.
moe_parallel_config
.
dp_size
>
1
# Enables DeepSeek-style FP8 block-scale path:
# - pass per-block weight scales to the kernel
# - skip input activation quantization (kernel applies scaling)
self
.
use_deepseek_fp8_block_scale
=
use_deepseek_fp8_block_scale
self
.
use_deepseek_fp8_block_scale
=
quant_config
.
is_block_quantized
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
@
staticmethod
def
expects_unquantized_inputs
(
moe_config
:
mk
.
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
)
->
bool
:
# NVFP4 TP kernels and FP8 block-quantized kernels apply
# input quantization inside FusedMoEPermuteExpertsUnpermute.
return
(
quant_config
.
use_nvfp4_w4a4
and
not
moe_config
.
moe_parallel_config
.
use_all2all_kernels
)
or
(
quant_config
.
use_fp8_w8a8
and
quant_config
.
is_block_quantized
)
@
staticmethod
def
_supports_current_device
()
->
bool
:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
,
current_platform
.
is_cuda
()
and
(
current_platform
.
is_device_capability
((
9
,
0
))
or
current_platform
.
is_device_capability_family
(
100
)
)
and
has_flashinfer_cutlass_fused_moe
()
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
    """Always True: this class reports no-act-and-mul support."""
    return True
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """
    Report whether the (weight, activation) quantization pair is supported.

    The following are supported by FlashInferExperts:
      * unquantized
      * fp8 static per-tensor on 9.0+
      * fp8 block on 9.0
      * nvfp4 on 10.0+
    """
    scheme = (weight_key, activation_key)

    # Schemes accepted here without any device check.
    if scheme in [(None, None), (kFp8StaticTensorSym, kFp8StaticTensorSym)]:
        return True

    # FP8 block scheme: requires device capability exactly (9, 0).
    if scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
        return current_platform.is_device_capability((9, 0))

    # NVFP4 scheme: requires device capability family 100.
    if scheme == (kNvfp4Static, kNvfp4Dynamic):
        return current_platform.is_device_capability_family(100)

    return False
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
,
"relu2_no_mul"
]
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
# FLASHINFER_CUTLASS currently uses its down P/F, which does not
# work with SP. This will be removed in follow up after we get
# rid of the FlashInfer specific P/F function.
return
(
moe_parallel_config
.
dp_size
==
1
or
moe_parallel_config
.
dp_size
==
moe_parallel_config
.
ep_size
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Return the activation format reported by this class: Standard."""
    return mk.FusedMoEActivationFormat.Standard
def supports_expert_map(self) -> bool:
    """Always False: this class does not report expert-map support."""
    return False
...
...
@@ -231,85 +300,3 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# No support for LoRA in flashinfer_cutlass_fused_moe.
# See TODOs in flashinfer functions runMoe and runMoeMinLantency.
raise
NotImplementedError
(
"LoRA is not supported for flashinfer_cutlass_moe"
)
def flashinfer_cutlass_moe_fp4(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """
    Run a fused MoE through a freshly built FlashInfer modular kernel (NVFP4).

    A prepare/finalize object (no DP, NVFP4 enabled, no all-to-all-v) and a
    ``FlashInferExperts`` instance are combined into a
    ``FusedMoEModularKernel``; all remaining arguments are forwarded to it
    unchanged and its result is returned.
    """
    prepare_finalize = create_flashinfer_prepare_finalize(
        use_dp=False, use_nvfp4=True, enable_alltoallv=False
    )
    experts = FlashInferExperts(
        out_dtype=hidden_states.dtype,
        quant_config=quant_config,
        use_dp=False,
    )
    kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)

    return kernel(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
def flashinfer_cutlass_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    tp_rank: int = 0,
    tp_size: int = 1,
    ep_rank: int = 0,
    ep_size: int = 1,
    use_dp: bool = False,
) -> torch.Tensor:
    """
    Run a fused MoE through a freshly built FlashInfer modular kernel.

    The TP/EP rank and size arguments configure ``FlashInferExperts``;
    ``use_dp`` configures the prepare/finalize object. All remaining
    arguments are forwarded to the kernel call and its result is returned.
    """
    prepare_finalize = create_flashinfer_prepare_finalize(use_dp=use_dp)
    # NOTE(review): use_dp is passed to prepare/finalize only, not to
    # FlashInferExperts -- confirm this asymmetry is intentional.
    experts = FlashInferExperts(
        out_dtype=hidden_states.dtype,
        quant_config=quant_config,
        tp_rank=tp_rank,
        tp_size=tp_size,
        ep_rank=ep_rank,
        ep_size=ep_size,
    )
    kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)

    return kernel(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
View file @
42135d68
...
...
@@ -3,7 +3,12 @@
import
torch
from
vllm.model_executor.layers.fused_moe.config
import
RoutingMethodType
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
RoutingMethodType
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
moe_kernel_quantize_input
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
calculate_tile_tokens_dim
,
...
...
@@ -11,8 +16,107 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8Dynamic128Sym
,
kFp8Static128BlockSym
,
kFp8StaticTensorSym
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
#
# Methods used by the oracle for kernel selection.
#
def _supports_current_device() -> bool:
    """Supports only Blackwell-family GPUs."""
    # TODO: also check that flashinfer trtllm is available.
    platform = current_platform
    on_cuda = platform.is_cuda()
    # Capability family is only queried when the platform is CUDA.
    return on_cuda and platform.is_device_capability_family(100)
def _supports_no_act_and_mul() -> bool:
    """Does not support non-gated MoE (e.g. Nanotron-Mini); always returns False."""
    return False
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Supports Fp8 per-tensor and Fp8 block."""
    scheme = (weight_key, activation_key)
    return scheme in [
        (kFp8Static128BlockSym, kFp8Dynamic128Sym),
        (kFp8StaticTensorSym, kFp8StaticTensorSym),
    ]
def _supports_activation(activation: str) -> bool:
    """Supports silu activation only."""
    accepted = ("silu",)
    return activation in accepted
def _supports_routing_method(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
    routing_method: RoutingMethodType,
) -> bool:
    """
    Monolithic kernels need to express router support.

    Raises:
        ValueError: if the (weight, activation) pair is neither the FP8
            block scheme nor the FP8 per-tensor scheme.
    """
    scheme = (weight_key, activation_key)

    if scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
        # NOTE(rob): potentially allow others here. This is a conservative list.
        return routing_method in [
            RoutingMethodType.DeepSeekV3,
            RoutingMethodType.Renormalize,
            RoutingMethodType.RenormalizeNaive,
        ]

    if scheme == (kFp8StaticTensorSym, kFp8StaticTensorSym):
        # NOTE(rob): kernel requires Llama4.
        return routing_method == RoutingMethodType.Llama4

    raise ValueError("Unsupported quantization scheme.")
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """The TRTLLM kernel does not support EPLB.

    Returns True unless ``moe_parallel_config.enable_eplb`` is truthy.
    """
    return not moe_parallel_config.enable_eplb
def is_supported_config_trtllm(
    moe_config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
    activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
    """
    This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config

    Checks run in a fixed order and stop at the first failure.

    Returns:
        ``(True, None)`` when every check passes, otherwise ``(False, reason)``
        where ``reason`` names the first unsupported aspect.
    """

    def _rejected(what: str) -> tuple[bool, str]:
        return False, f"kernel does not support {what}"

    if not _supports_current_device():
        return _rejected("current device")

    if not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
        return _rejected("no act_and_mul MLP layer")

    if not _supports_activation(moe_config.activation):
        return _rejected(f"{moe_config.activation} activation")

    if not _supports_quant_scheme(weight_key, activation_key):
        return _rejected("quantization scheme")

    if not _supports_parallel_config(moe_config.moe_parallel_config):
        return _rejected("parallel config")

    # Must come after the quant-scheme check: the routing check raises
    # ValueError for schemes it does not recognise.
    if not _supports_routing_method(
        weight_key, activation_key, moe_config.routing_method
    ):
        return _rejected("routing method")

    if activation_format != mk.FusedMoEActivationFormat.Standard:
        return _rejected("activation format")

    return True, None
def
flashinfer_fused_moe_blockscale_fp8
(
routing_logits
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
42135d68
...
...
@@ -5,7 +5,11 @@
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
try_get_optimal_moe_config
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
,
...
...
@@ -17,7 +21,17 @@ from vllm.model_executor.layers.fused_moe.utils import (
normalize_batched_scales_shape
,
normalize_scales_shape
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
group_broadcast
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
group_broadcast
,
kFp8Dynamic128Sym
,
kFp8DynamicTensorSym
,
kFp8DynamicTokenSym
,
kFp8Static128BlockSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
...
...
@@ -633,25 +647,62 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
max_num_tokens
:
int
,
num_dispatchers
:
int
,
quant_config
:
FusedMoEQuantConfig
,
):
super
().
__init__
(
quant_config
)
super
().
__init__
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
)
assert
not
self
.
quant_config
.
use_int8_w8a8
,
"NYI"
assert
not
self
.
quant_config
.
use_int8_w8a16
,
"NYI"
assert
not
self
.
quant_config
.
use_int4_w4a16
,
"NYI"
assert
self
.
quant_config
.
ocp_mx_scheme
is
None
,
"NYI"
self
.
max_num_tokens
=
max_num_tokens
self
.
num_dispatchers
=
num_dispatchers
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
BatchedExperts
,
mk
.
FusedMoEActivationFormat
.
BatchedExperts
,
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Return the activation format reported by this class: BatchedExperts."""
    return mk.FusedMoEActivationFormat.BatchedExperts
@staticmethod
def _supports_current_device() -> bool:
    """Placeholder: unconditionally raises NotImplementedError."""
    raise NotImplementedError(
        "NaiveBatchedExperts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_no_act_and_mul() -> bool:
    """Placeholder: unconditionally raises NotImplementedError."""
    raise NotImplementedError(
        "NaiveBatchedExperts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Placeholder: unconditionally raises NotImplementedError; arguments are unused."""
    raise NotImplementedError(
        "NaiveBatchedExperts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_activation(activation: str) -> bool:
    """Placeholder: unconditionally raises NotImplementedError; argument is unused."""
    raise NotImplementedError(
        "NaiveBatchedExperts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Placeholder: unconditionally raises NotImplementedError; argument is unused."""
    raise NotImplementedError(
        "NaiveBatchedExperts is not yet used by an Oracle. "
        "This method should not be called."
    )
def
supports_chunking
(
self
)
->
bool
:
...
...
@@ -675,6 +726,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
activation
:
str
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
assert
self
.
num_dispatchers
is
not
None
assert
self
.
max_num_tokens
is
not
None
num_dp
=
self
.
num_dispatchers
num_experts
=
local_num_experts
workspace13
=
(
num_experts
,
self
.
max_num_tokens
*
num_dp
,
K
)
...
...
@@ -826,29 +879,69 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
max_num_tokens
:
int
,
num_dispatchers
:
int
,
quant_config
:
FusedMoEQuantConfig
,
):
super
().
__init__
(
quant_config
)
super
().
__init__
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
)
assert
not
self
.
quant_config
.
use_int8_w8a8
,
"NYI"
assert
not
self
.
quant_config
.
use_int8_w8a16
,
"NYI"
assert
not
self
.
quant_config
.
use_int4_w4a16
,
"NYI"
assert
self
.
quant_config
.
ocp_mx_scheme
is
None
,
"NYI"
assert
max_num_tokens
>
0
assert
num_dispatchers
>
0
self
.
max_num_tokens
=
max_num_tokens
self
.
num_dispatchers
=
num_dispatchers
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
BatchedExperts
,
mk
.
FusedMoEActivationFormat
.
BatchedExperts
,
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Return the activation format reported by this class: BatchedExperts."""
    return mk.FusedMoEActivationFormat.BatchedExperts
@staticmethod
def _supports_current_device() -> bool:
    """Return whatever ``current_platform.is_cuda_alike()`` reports."""
    return current_platform.is_cuda_alike()
@staticmethod
def _supports_no_act_and_mul() -> bool:
    """Always False: this class does not report no-act-and-mul support."""
    return False
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """
    Report whether the (weight, activation) quantization pair is supported.

    The unquantized pair ``(None, None)`` is always accepted. The listed FP8
    pairs are accepted only when the platform check below passes (ROCm on
    gfx9, or CUDA with device capability at least (8, 9)).
    """
    platform = current_platform
    # Evaluated up front, before the scheme is looked at.
    fp8_capable = (platform.is_rocm() and platform.rocm.on_gfx9()) or (
        platform.is_cuda() and platform.has_device_capability((8, 9))
    )

    fp8_schemes = [
        (kFp8Static128BlockSym, kFp8Dynamic128Sym),
        (kFp8StaticChannelSym, kFp8DynamicTokenSym),
        (kFp8StaticTensorSym, kFp8StaticTensorSym),
        (kFp8StaticTensorSym, kFp8DynamicTensorSym),
    ]

    scheme = (weight_key, activation_key)
    if scheme == (None, None):
        return True
    return fp8_capable and scheme in fp8_schemes
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
,
"gelu"
,
"swigluoai"
,
"silu_no_mul"
,
"gelu_no_mul"
,
"relu2_no_mul"
,
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Always True: the argument is not inspected."""
    return True
def supports_chunking(self) -> bool:
    """Always False: this class does not report chunking support."""
    return False
...
...
@@ -870,6 +963,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
activation
:
str
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
assert
self
.
num_dispatchers
is
not
None
assert
self
.
max_num_tokens
is
not
None
num_dp
=
self
.
num_dispatchers
num_experts
=
local_num_experts
max_num_tokens
=
self
.
max_num_tokens
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
42135d68
...
...
@@ -8,7 +8,11 @@ import torch
import
vllm._custom_ops
as
ops
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
batched_moe_align_block_size
,
moe_align_block_size
,
...
...
@@ -27,6 +31,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_moe_intermediate_size
,
marlin_quant_input
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8Static128BlockSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
kNvfp4Static
,
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
...
...
@@ -522,7 +533,10 @@ def batched_fused_marlin_moe(
class
MarlinExpertsBase
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
max_num_tokens
:
int
|
None
=
None
,
num_dispatchers
:
int
|
None
=
None
,
w13_g_idx
:
torch
.
Tensor
|
None
=
None
,
w2_g_idx
:
torch
.
Tensor
|
None
=
None
,
w13_g_idx_sort_indices
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -541,7 +555,51 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
self
.
w13_g_idx_sort_indices
=
w13_g_idx_sort_indices
self
.
w2_g_idx_sort_indices
=
w2_g_idx_sort_indices
self
.
is_k_full
=
is_k_full
super
().
__init__
(
quant_config
)
super
().
__init__
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
)
@staticmethod
def _supports_current_device() -> bool:
    """Require a CUDA platform with device capability of at least (7, 5)."""
    platform = current_platform
    on_cuda = platform.is_cuda()
    # Capability is only queried when the platform is CUDA.
    return on_cuda and platform.has_device_capability((7, 5))
@staticmethod
def _supports_no_act_and_mul() -> bool:
    """Always False: this class does not report no-act-and-mul support."""
    return False
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """
    Decide support from the weight key alone; ``activation_key`` is not read.
    """
    # TODO(rob): add int4, mxfp4, int8 as integrations
    # are migrated to use the oracle one-by-one.
    supported_weight_keys = [
        kFp8Static128BlockSym,
        kFp8StaticChannelSym,
        kFp8StaticTensorSym,
        kNvfp4Static,
    ]
    is_supported = weight_key in supported_weight_keys
    return is_supported
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
,
"gelu"
,
"swigluoai"
,
"silu_no_mul"
,
"gelu_no_mul"
,
"relu2_no_mul"
,
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Always True: the argument is not inspected."""
    return True
@
property
def
quant_type_id
(
self
)
->
int
:
...
...
@@ -587,38 +645,15 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
class
MarlinExperts
(
MarlinExpertsBase
):
def __init__(
    self,
    quant_config: FusedMoEQuantConfig,
    w13_g_idx: torch.Tensor | None = None,
    w2_g_idx: torch.Tensor | None = None,
    w13_g_idx_sort_indices: torch.Tensor | None = None,
    w2_g_idx_sort_indices: torch.Tensor | None = None,
    is_k_full: bool = True,
):
    # Pure pass-through: every argument is forwarded positionally, in
    # declaration order, to the base-class constructor.
    super().__init__(
        quant_config,
        w13_g_idx,
        w2_g_idx,
        w13_g_idx_sort_indices,
        w2_g_idx_sort_indices,
        is_k_full,
    )
def supports_expert_map(self) -> bool:
    """Always True: this class reports expert-map support."""
    return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    """Return a new ``TopKWeightAndReduceNoOP`` instance on each call."""
    return TopKWeightAndReduceNoOP()
@property
def activation_formats(
    self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
    """Return a pair of activation formats; both entries are Standard."""
    return (
        mk.FusedMoEActivationFormat.Standard,
        mk.FusedMoEActivationFormat.Standard,
    )
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Return the activation format reported by this class: Standard."""
    return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
    """Always True: this class reports chunking support."""
    return True
...
...
@@ -714,9 +749,10 @@ class MarlinExperts(MarlinExpertsBase):
class
BatchedMarlinExperts
(
MarlinExpertsBase
):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
max_num_tokens
:
int
,
num_dispatchers
:
int
,
quant_config
:
FusedMoEQuantConfig
,
w13_g_idx
:
torch
.
Tensor
|
None
=
None
,
w2_g_idx
:
torch
.
Tensor
|
None
=
None
,
w13_g_idx_sort_indices
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -724,15 +760,16 @@ class BatchedMarlinExperts(MarlinExpertsBase):
is_k_full
:
bool
=
True
,
):
super
().
__init__
(
quant_config
,
w13_g_idx
,
w2_g_idx
,
w13_g_idx_sort_indices
,
w2_g_idx_sort_indices
,
is_k_full
,
moe_config
=
moe_config
,
quant_config
=
quant_config
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
w13_g_idx
=
w13_g_idx
,
w2_g_idx
=
w2_g_idx
,
w13_g_idx_sort_indices
=
w13_g_idx_sort_indices
,
w2_g_idx_sort_indices
=
w2_g_idx_sort_indices
,
is_k_full
=
is_k_full
,
)
self
.
max_num_tokens
=
max_num_tokens
self
.
num_dispatchers
=
num_dispatchers
def supports_expert_map(self) -> bool:
    """Always True: this class reports expert-map support."""
    return True
...
...
@@ -740,14 +777,9 @@ class BatchedMarlinExperts(MarlinExpertsBase):
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    """Return a new ``TopKWeightAndReduceDelegate`` instance on each call."""
    return TopKWeightAndReduceDelegate()
@property
def activation_formats(
    self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
    """Return a pair of activation formats; both entries are BatchedExperts."""
    return (
        mk.FusedMoEActivationFormat.BatchedExperts,
        mk.FusedMoEActivationFormat.BatchedExperts,
    )
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Return the activation format reported by this class: BatchedExperts."""
    return mk.FusedMoEActivationFormat.BatchedExperts
def supports_chunking(self) -> bool:
    """Always False: this class does not report chunking support."""
    return False
...
...
@@ -763,9 +795,11 @@ class BatchedMarlinExperts(MarlinExpertsBase):
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
activation
:
str
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
assert
self
.
num_dispatchers
is
not
None
assert
self
.
max_num_tokens
is
not
None
num_dispatchers
=
self
.
num_dispatchers
num_experts
=
local_num_experts
max_num_tokens
=
M
if
self
.
max_num_tokens
is
None
else
self
.
max_num_tokens
max_num_tokens
=
self
.
max_num_tokens
workspace13
=
(
num_experts
*
max_num_tokens
*
num_dispatchers
,
max
(
K
,
N
*
2
))
workspace2
=
(
num_experts
*
max_num_tokens
*
num_dispatchers
,
N
)
output
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
K
)
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
42135d68
...
...
@@ -19,13 +19,11 @@ from vllm.model_executor.layers.batch_invariant import (
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
_get_config_dtype_str
,
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_valid_deep_gemm
,
deep_gemm_moe_fp8
,
)
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
,
)
...
...
@@ -44,9 +42,16 @@ from vllm.model_executor.layers.fused_moe.utils import (
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
dequant_mxfp4
from
vllm.model_executor.layers.quantization.utils.mxfp6_utils
import
dequant_mxfp6
from
vllm.model_executor.layers.quantization.utils.ocp_mx_utils
import
OCP_MX_Scheme
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8Dynamic128Sym
,
kFp8DynamicTokenSym
,
kFp8Static128BlockSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
from
vllm.utils.torch_utils
import
direct_register_custom_op
,
is_torch_equal_or_newer
logger
=
init_logger
(
__name__
)
...
...
@@ -1534,66 +1539,36 @@ def fused_experts(
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
quant_config
:
FusedMoEQuantConfig
|
None
=
None
,
allow_deep_gemm
:
bool
=
False
,
)
->
torch
.
Tensor
:
if
quant_config
is
None
:
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# However, on B200, we use DeepGemm for all cases because they only support
# E8M0 scale, which means we requantize the weight and input to the specific
# scale. Fallen back to cutlass or triton for some cases would cause
# accuracy issue.
if
(
allow_deep_gemm
and
quant_config
.
use_fp8_w8a8
and
(
is_deep_gemm_e8m0_used
()
or
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
))
):
assert
quant_config
is
not
None
return
deep_gemm_moe_fp8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
inplace
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
quant_config
.
w1_scale
,
w2_scale
=
quant_config
.
w2_scale
,
a1_scale
=
quant_config
.
a1_scale
,
a2_scale
=
quant_config
.
a2_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
else
:
return
dispatch_fused_experts_func
(
inplace
)(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
quant_config
.
use_fp8_w8a8
,
use_int8_w8a8
=
quant_config
.
use_int8_w8a8
,
use_int8_w8a16
=
quant_config
.
use_int8_w8a16
,
use_int4_w4a16
=
quant_config
.
use_int4_w4a16
,
ocp_mx_scheme
=
quant_config
.
ocp_mx_scheme
,
per_channel_quant
=
quant_config
.
per_act_token_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
quant_config
.
w1_scale
,
w2_scale
=
quant_config
.
w2_scale
,
w1_zp
=
quant_config
.
w1_zp
,
w2_zp
=
quant_config
.
w2_zp
,
a1_scale
=
quant_config
.
a1_scale
,
a2_scale
=
quant_config
.
a2_scale
,
block_shape
=
quant_config
.
block_shape
,
w1_bias
=
quant_config
.
w1_bias
,
w2_bias
=
quant_config
.
w2_bias
,
)
return
dispatch_fused_experts_func
(
inplace
)(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
quant_config
.
use_fp8_w8a8
,
use_int8_w8a8
=
quant_config
.
use_int8_w8a8
,
use_int8_w8a16
=
quant_config
.
use_int8_w8a16
,
use_int4_w4a16
=
quant_config
.
use_int4_w4a16
,
ocp_mx_scheme
=
quant_config
.
ocp_mx_scheme
,
per_channel_quant
=
quant_config
.
per_act_token_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
quant_config
.
w1_scale
,
w2_scale
=
quant_config
.
w2_scale
,
w1_zp
=
quant_config
.
w1_zp
,
w2_zp
=
quant_config
.
w2_zp
,
a1_scale
=
quant_config
.
a1_scale
,
a2_scale
=
quant_config
.
a2_scale
,
block_shape
=
quant_config
.
block_shape
,
w1_bias
=
quant_config
.
w1_bias
,
w2_bias
=
quant_config
.
w2_bias
,
)
def
_get_config_quant_dtype
(
...
...
@@ -1924,19 +1899,53 @@ def fused_experts_impl(
class
TritonExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
):
super
().
__init__
(
quant_config
)
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
,
super
().
__init__
(
moe_config
,
quant_config
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
    """Return the activation format of this experts class.

    Returns:
        ``mk.FusedMoEActivationFormat.Standard``.
    """
    standard_format = mk.FusedMoEActivationFormat.Standard
    return standard_format
@staticmethod
def _supports_current_device() -> bool:
    """Report device support by asking the current platform.

    Returns:
        The result of ``current_platform.is_cuda_alike()``.
    """
    is_supported = current_platform.is_cuda_alike()
    return is_supported
@staticmethod
def _supports_no_act_and_mul() -> bool:
    """Return False: this class reports no support for ``act_and_mul=False``."""
    return False
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Check whether a (weight, activation) quantization pair is accepted.

    Args:
        weight_key: Quantization key of the weights, or None.
        activation_key: Quantization key of the activations, or None.

    Returns:
        On devices that fail the FP8 check below, True only for the
        ``(None, None)`` pair. Otherwise True when the pair is one of the
        listed combinations.
    """
    scheme = (weight_key, activation_key)
    platform = current_platform

    # FP8 check: ROCm on gfx9, or CUDA with device capability (8, 9).
    if platform.is_rocm() and platform.rocm.on_gfx9():
        fp8_capable = True
    else:
        fp8_capable = platform.is_cuda() and platform.has_device_capability((8, 9))

    if not fp8_capable:
        return scheme == (None, None)

    # Built only after the early return, as in the original ordering.
    supported_pairs = (
        (None, None),
        (kFp8Static128BlockSym, kFp8Dynamic128Sym),
        (kFp8StaticChannelSym, kFp8DynamicTokenSym),
        (kFp8StaticTensorSym, kFp8DynamicTokenSym),
        (kFp8StaticTensorSym, kFp8StaticTensorSym),
    )
    return scheme in supported_pairs
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
,
"gelu"
,
"swigluoai"
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Return True unconditionally; ``moe_parallel_config`` is not inspected."""
    return True
def supports_chunking(self) -> bool:
    """Return True: this experts class reports chunking support."""
    return True
...
...
@@ -2111,11 +2120,43 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
class
TritonWNA16Experts
(
TritonExperts
):
def __init__(
    self,
    quant_config: FusedMoEQuantConfig,
):
    """Initialize by forwarding ``quant_config`` to the parent initializer.

    Args:
        quant_config: Quantization configuration passed on unchanged.
    """
    super().__init__(quant_config)
@staticmethod
def _supports_current_device() -> bool:
    """Placeholder that always raises.

    Raises:
        NotImplementedError: Unconditionally; the message states that
            TritonWNA16Experts is not yet used by an Oracle.
    """
    raise NotImplementedError(
        "TritonWNA16Experts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_no_act_and_mul() -> bool:
    """Placeholder that always raises.

    Raises:
        NotImplementedError: Unconditionally; the message states that
            TritonWNA16Experts is not yet used by an Oracle.
    """
    raise NotImplementedError(
        "TritonWNA16Experts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Placeholder that always raises; both arguments are ignored.

    Raises:
        NotImplementedError: Unconditionally; the message states that
            TritonWNA16Experts is not yet used by an Oracle.
    """
    raise NotImplementedError(
        "TritonWNA16Experts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_activation(activation: str) -> bool:
    """Placeholder that always raises; ``activation`` is ignored.

    Raises:
        NotImplementedError: Unconditionally; the message states that
            TritonWNA16Experts is not yet used by an Oracle.
    """
    raise NotImplementedError(
        "TritonWNA16Experts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Placeholder that always raises; ``moe_parallel_config`` is ignored.

    Raises:
        NotImplementedError: Unconditionally; the message states that
            TritonWNA16Experts is not yet used by an Oracle.
    """
    raise NotImplementedError(
        "TritonWNA16Experts is not yet used by an Oracle. "
        "This method should not be called."
    )
def
apply
(
self
,
...
...
@@ -2254,10 +2295,12 @@ class TritonWNA16Experts(TritonExperts):
def
modular_triton_fused_moe
(
quant_config
:
FusedMoEQuantConfig
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
)
->
mk
.
FusedMoEModularKernel
:
return
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
TritonExperts
(
quant_config
),
TritonExperts
(
moe_config
,
quant_config
),
shared_experts
,
)
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
42135d68
...
...
@@ -9,12 +9,16 @@ from vllm import _custom_ops as ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
)
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.import_utils
import
has_triton_kernels
...
...
@@ -241,8 +245,43 @@ def make_routing_data(
class
BaseOAITritonExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def __init__(self, quant_config: FusedMoEQuantConfig):
    """Initialize by forwarding ``quant_config`` to the parent initializer.

    Args:
        quant_config: Quantization configuration passed on unchanged.
    """
    super().__init__(quant_config)
@staticmethod
def _supports_current_device() -> bool:
    """Placeholder that always raises.

    Raises:
        NotImplementedError: Unconditionally; the message states that
            OAITritonExperts is not yet used by an Oracle.
    """
    raise NotImplementedError(
        "OAITritonExperts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_no_act_and_mul() -> bool:
    """Placeholder that always raises.

    Raises:
        NotImplementedError: Unconditionally; the message states that
            OAITritonExperts is not yet used by an Oracle.
    """
    raise NotImplementedError(
        "OAITritonExperts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Placeholder that always raises; both arguments are ignored.

    Raises:
        NotImplementedError: Unconditionally; the message states that
            OAITritonExperts is not yet used by an Oracle.
    """
    raise NotImplementedError(
        "OAITritonExperts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_activation(activation: str) -> bool:
    """Placeholder that always raises; ``activation`` is ignored.

    Raises:
        NotImplementedError: Unconditionally; the message states that
            OAITritonExperts is not yet used by an Oracle.
    """
    raise NotImplementedError(
        "OAITritonExperts is not yet used by an Oracle. "
        "This method should not be called."
    )
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Placeholder that always raises; ``moe_parallel_config`` is ignored.

    Raises:
        NotImplementedError: Unconditionally; the message states that
            OAITritonExperts is not yet used by an Oracle.
    """
    raise NotImplementedError(
        "OAITritonExperts is not yet used by an Oracle. "
        "This method should not be called."
    )
def supports_expert_map(self) -> bool:
    """Return True: this experts class reports expert-map support."""
    return True
...
...
@@ -297,19 +336,9 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
class
OAITritonExperts
(
BaseOAITritonExperts
):
def
__init__
(
self
,
quant_config
:
FusedMoEQuantConfig
):
# TODO (varun) : Enable activation quantization
assert
quant_config
.
use_mxfp4_w4a16
,
"Supports only mxfp4_w4a16"
super
().
__init__
(
quant_config
)
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
,
)
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
def
supports_chunking
(
self
)
->
bool
:
return
True
...
...
@@ -391,19 +420,9 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
One use case for it is to inject LoRA modules on the activation and moe_sum.
"""
def
__init__
(
self
,
quant_config
:
FusedMoEQuantConfig
):
# TODO (varun) : Enable activation quantization
assert
quant_config
.
use_mxfp4_w4a16
,
"Supports only mxfp4_w4a16"
super
().
__init__
(
quant_config
)
@
property
def
activation_formats
(
self
,
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
,
)
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
def
supports_chunking
(
self
)
->
bool
:
return
True
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
42135d68
...
...
@@ -330,7 +330,6 @@ class FusedMoE(CustomOp):
is_sequence_parallel
=
False
,
expert_mapping
:
list
[
tuple
[
str
,
str
,
int
,
str
]]
|
None
=
None
,
n_shared_experts
:
int
|
None
=
None
,
routing_method_type
:
RoutingMethodType
|
None
=
None
,
router_logits_dtype
:
torch
.
dtype
|
None
=
None
,
):
super
().
__init__
()
...
...
@@ -519,10 +518,43 @@ class FusedMoE(CustomOp):
self
.
apply_router_weight_on_input
=
apply_router_weight_on_input
self
.
activation
=
activation
# TODO(bnell): in next PR move capture back to layer
capture
:
Callable
[[
torch
.
Tensor
],
None
]
|
None
=
None
if
(
self
.
vllm_config
.
model_config
is
not
None
and
self
.
vllm_config
.
model_config
.
enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer
=
RoutedExpertsCapturer
.
get_instance
()
if
capturer
is
not
None
:
capture
=
lambda
topk_ids
:
capturer
.
capture
(
self
.
layer_id
,
topk_ids
)
self
.
router
=
create_fused_moe_router
(
top_k
=
top_k
,
global_num_experts
=
self
.
global_num_experts
,
eplb_state
=
self
.
eplb_state
,
renormalize
=
renormalize
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
num_fused_shared_experts
=
self
.
num_fused_shared_experts
,
enable_eplb
=
enable_eplb
,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter
=
lambda
:
self
.
quant_method
.
topk_indices_dtype
,
capture
=
capture
,
)
self
.
routing_method_type
:
RoutingMethodType
=
self
.
router
.
routing_method_type
self
.
moe_config
:
FusedMoEConfig
=
FusedMoEConfig
(
num_experts
=
self
.
global_num_experts
,
experts_per_token
=
top_k
,
hidden_dim
=
hidden_size
,
intermediate_size_per_partition
=
self
.
intermediate_size_per_partition
,
num_local_experts
=
self
.
local_num_experts
,
moe_parallel_config
=
self
.
moe_parallel_config
,
in_dtype
=
moe_in_dtype
,
...
...
@@ -531,6 +563,9 @@ class FusedMoE(CustomOp):
has_bias
=
has_bias
,
is_act_and_mul
=
is_act_and_mul
,
is_lora_enabled
=
vllm_config
.
lora_config
is
not
None
,
activation
=
activation
,
device
=
vllm_config
.
device_config
.
device
,
routing_method
=
self
.
routing_method_type
,
)
self
.
moe_config_use_flashinfer_cutlass_kernels
=
(
self
.
moe_config
.
use_flashinfer_cutlass_kernels
...
...
@@ -594,39 +629,6 @@ class FusedMoE(CustomOp):
self
.
batched_hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
batched_router_logits
:
torch
.
Tensor
|
None
=
None
# TODO(bnell): in next PR move capture back to layer
capture
:
Callable
[[
torch
.
Tensor
],
None
]
|
None
=
None
if
(
self
.
vllm_config
.
model_config
is
not
None
and
self
.
vllm_config
.
model_config
.
enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer
=
RoutedExpertsCapturer
.
get_instance
()
if
capturer
is
not
None
:
capture
=
lambda
topk_ids
:
capturer
.
capture
(
self
.
layer_id
,
topk_ids
)
self
.
router
=
create_fused_moe_router
(
top_k
=
top_k
,
global_num_experts
=
self
.
global_num_experts
,
eplb_state
=
self
.
eplb_state
,
renormalize
=
renormalize
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
num_fused_shared_experts
=
self
.
num_fused_shared_experts
,
enable_eplb
=
enable_eplb
,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter
=
lambda
:
self
.
quant_method
.
topk_indices_dtype
,
routing_method_type
=
routing_method_type
,
capture
=
capture
,
)
self
.
routing_method_type
:
RoutingMethodType
=
self
.
router
.
routing_method_type
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
42135d68
...
...
@@ -13,6 +13,7 @@ import vllm.envs as envs
from
vllm.forward_context
import
get_forward_context
,
is_forward_context_available
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
...
...
@@ -22,6 +23,9 @@ from vllm.model_executor.layers.fused_moe.utils import (
count_expert_num_tokens
,
disable_inplace
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
)
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.worker.ubatching
import
(
dbo_enabled
,
...
...
@@ -374,18 +378,51 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def __init__(
    self,
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    max_num_tokens: int | None = None,
    num_dispatchers: int | None = None,
):
    """Validate the batching arguments and store the configuration.

    Args:
        moe_config: MoE layer configuration.
        quant_config: Quantization parameters for this experts instance.
        max_num_tokens: Must be None for the Standard activation format and
            set for the BatchedExperts format.
        num_dispatchers: Same constraint as ``max_num_tokens``.

    Raises:
        ValueError: If either batching argument is set for the Standard
            format, or either is missing for the BatchedExperts format.
    """
    batching_args = (max_num_tokens, num_dispatchers)
    any_given = any(arg is not None for arg in batching_args)
    any_missing = any(arg is None for arg in batching_args)

    if self.activation_format() == FusedMoEActivationFormat.Standard and any_given:
        raise ValueError(
            "max_num_tokens and num_dispatchers should only be set for "
            "BatchedExperts activation format."
        )
    elif (
        self.activation_format() == FusedMoEActivationFormat.BatchedExperts
        and any_missing
    ):
        raise ValueError(
            "max_num_tokens and num_dispatchers must be set for "
            "BatchedExperts activation format."
        )

    self.moe_config = moe_config
    self.quant_config = quant_config
    self.max_num_tokens = max_num_tokens
    self.num_dispatchers = num_dispatchers
@
property
@staticmethod
def expects_unquantized_inputs(
    moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
    """
    Whether or not the PrepareFinalize should defer input quantization
    in the prepare step. If True, then the Experts kernel will
    execute the input quantization itself.

    Sample subclasses that override are AITER and FlashInfer CUTLASS.

    This base implementation returns False and inspects neither argument.
    """
    return False
@
staticmethod
@
abstractmethod
def
activation_formats
(
self
,
)
->
tuple
[
FusedMoEActivationFormat
,
FusedMoEActivationFormat
]:
def
activation_format
()
->
FusedMoEActivationFormat
:
"""
A property which is a tuple of the input and output activation formats
for the 'apply' method.
...
...
@@ -435,6 +472,78 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
return
E
,
M
,
N
,
K
,
topk
#
# Various helpers for registering support for various features.
# Used by the oracle to select a particular kernel for a deployment.
#
@staticmethod
def is_supported_config(
    cls: type["FusedMoEPermuteExpertsUnpermute"],
    moe_config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
    activation_format: FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
    """Check whether the experts class ``cls`` supports a deployment.

    This is a staticmethod that takes the class explicitly as its first
    argument. Checks run in a fixed order and stop at the first failure.

    Args:
        cls: Experts class whose ``_supports_*`` hooks are queried.
        moe_config: MoE layer configuration.
        weight_key: Quantization key of the weights, or None.
        activation_key: Quantization key of the activations, or None.
        activation_format: Activation format the caller intends to use.

    Returns:
        ``(True, None)`` when every check passes, otherwise ``(False, reason)``
        where ``reason`` names the first unsupported feature.
    """

    def _unsupported(feature: str) -> tuple[bool, str | None]:
        return False, f"kernel does not support {feature}"

    if not cls._supports_current_device():
        return _unsupported("current device")

    # Non-gated layers are only acceptable if the kernel opts in.
    if not moe_config.is_act_and_mul and not cls._supports_no_act_and_mul():
        return _unsupported("no act_and_mul MLP layer")

    if not cls._supports_activation(moe_config.activation):
        return _unsupported(f"{moe_config.activation} activation")

    if not cls._supports_quant_scheme(weight_key, activation_key):
        return _unsupported("quantization scheme")

    if not cls._supports_parallel_config(moe_config.moe_parallel_config):
        return _unsupported("parallel config")

    if activation_format != cls.activation_format():
        return _unsupported(f"{activation_format.value} activation format")

    return True, None
@staticmethod
@abstractmethod
def _supports_current_device() -> bool:
    """
    Whether the kernel supports the current device type
    (compute capability and current platform).
    """
    raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_no_act_and_mul() -> bool:
    """
    Whether the kernel supports act_and_mul=False, i.e.
    non-gated MoE models like Nemotron-Nano.

    Abstract: this base implementation raises NotImplementedError.
    """
    raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """
    Whether the kernel supports the given pair of weight and activation
    quantization keys (either may be None).

    Abstract: this base implementation raises NotImplementedError.
    """
    raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_activation(activation: str) -> bool:
    """
    Whether the kernel supports a particular act function.

    Abstract: this base implementation raises NotImplementedError.
    """
    raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """
    Whether the kernel supports deployment in expert parallel.

    Abstract: this base implementation raises NotImplementedError.
    """
    raise NotImplementedError
#
# Various helpers for accessing quantization parameters from the
# quant_config.
...
...
@@ -715,12 +824,12 @@ class FusedMoEModularKernel(torch.nn.Module):
self
.
_post_init_setup
()
assert
(
prepare_finalize
.
activation_format
==
fused_experts
.
activation_format
s
[
0
]
prepare_finalize
.
activation_format
==
fused_experts
.
activation_format
()
),
(
f
"
{
prepare_finalize
.
__class__
.
__name__
}
."
f
"
{
prepare_finalize
.
activation_format
}
== "
f
"
{
fused_experts
.
__class__
.
__name__
}
."
f
"
{
fused_experts
.
activation_format
s
[
0
]
}
"
f
"
{
fused_experts
.
activation_format
()
}
"
)
def
_post_init_setup
(
self
):
...
...
vllm/model_executor/layers/fused_moe/oracle/fp8.py
View file @
42135d68
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
View file @
42135d68
...
...
@@ -14,21 +14,11 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config
,
nvfp4_w4a16_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
CutlassExpertsFp4
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
FlashInferExperts
,
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
MarlinExperts
,
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
is_flashinfer_fp4_cutedsl_moe_available
,
is_flashinfer_fp4_cutlass_moe_available
,
is_supported_config_trtllm
,
prepare_nvfp4_moe_layer_for_fi_or_cutlass
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
...
...
@@ -36,27 +26,26 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
get_flashinfer_moe_backend
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
is_fp4_marlin_supported
,
prepare_nvfp4_moe_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
cutlass_fp4_supported
,
QuantKey
,
)
logger
=
init_logger
(
__name__
)
class
NvFp4MoeBackend
(
Enum
):
FLASHINFER_
CUTLASS
=
"FlashInfer CUTLASS
"
FLASHINFER_
TRTLLM
=
"FlashInfer TRTLLM
"
FLASHINFER_CUTEDSL
=
"F
lashInfer
CUTEDSL"
VLLM_CUTLASS
=
"
v
LLM
CUTASS"
MARLIN
=
"
vLLM
MARLIN"
FLASHINFER_
TRTLLM
=
"FLASHINFER_TRTLLM
"
FLASHINFER_
CUTLASS
=
"FLASHINFER_CUTLASS
"
FLASHINFER_CUTEDSL
=
"F
LASHINFER_
CUTEDSL"
VLLM_CUTLASS
=
"
V
LLM
_
CUT
L
ASS"
MARLIN
=
"MARLIN"
FLASHINFER_NVFP4_MOE_BACKENDS
=
[
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
,
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
,
]
...
...
@@ -72,44 +61,208 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
# of all experts in Expert Parallel Mode when all experts are not
# on the same rank.
return
backend
in
[
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
return
backend
in
FLASHINFER_NVFP4_MOE_BACKENDS
def backend_to_kernel_cls(
    backend: NvFp4MoeBackend,
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
    """Map an NvFP4 MoE backend to its experts kernel class.

    Each kernel module is imported only inside the branch that needs it.

    Args:
        backend: Backend to resolve.

    Returns:
        The experts class for ``backend``.

    Raises:
        NotImplementedError: For FLASHINFER_TRTLLM, which has no modular
            kernel class.
        ValueError: If ``backend`` matches none of the handled members.
    """
    if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
        raise NotImplementedError(
            "FLASHINFER_TRTLLM doesn't support Modular Kernel Interface"
        )

    if backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )

        return FlashInferExperts

    if backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
            FlashInferCuteDSLExperts,
        )

        return FlashInferCuteDSLExperts

    if backend == NvFp4MoeBackend.VLLM_CUTLASS:
        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
            CutlassExpertsFp4,
        )

        return CutlassExpertsFp4

    if backend == NvFp4MoeBackend.MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )

        return MarlinExperts

    raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
def
select_nvfp4_moe_backend
(
config
:
FusedMoEConfig
,
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
tuple
[
NvFp4MoeBackend
,
type
[
mk
.
FusedMoEPermuteExpertsUnpermute
]
|
None
]:
"""
Select the primary NvFP4 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# NOTE: the kernels are selected in the following order.
AVAILABLE_BACKENDS
=
[
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
,
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
,
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
,
NvFp4MoeBackend
.
VLLM_CUTLASS
,
NvFp4MoeBackend
.
MARLIN
,
]
# NOTE(rob): this is kind of a hack. We need to peek into
# the prepare-finalize selection to determine if we are using
# the batched or standard expert format.
use_batched
=
(
config
.
moe_parallel_config
.
use_deepep_ll_kernels
or
config
.
moe_parallel_config
.
use_pplx_kernels
)
activation_format
=
(
mk
.
FusedMoEActivationFormat
.
BatchedExperts
if
use_batched
else
mk
.
FusedMoEActivationFormat
.
Standard
)
def
select_nvfp4_moe_backend
()
->
NvFp4MoeBackend
:
def
_make_log_backend
(
backend
:
NvFp4MoeBackend
):
return
f
"Using
{
backend
.
value
}
backend for NvFp4 MoE"
available_backend_strs
=
[
b
.
value
for
b
in
AVAILABLE_BACKENDS
]
return
(
f
"Using '
{
backend
.
value
}
' NvFp4 MoE backend out "
f
"of potential backends:
{
available_backend_strs
}
."
)
if
cutlass_fp4_supported
()
and
not
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
:
allow_flashinfer
=
(
is_flashinfer_fp4_cutlass_moe_available
()
or
is_flashinfer_fp4_cutedsl_moe_available
()
def
_make_log_unsupported
(
backend
:
NvFp4MoeBackend
,
reason
:
str
|
None
)
->
str
:
if
reason
:
return
(
f
"NvFp4 MoE backend '
{
backend
.
value
}
' does not support the "
f
"deployment configuration since
{
reason
}
."
)
else
:
return
(
f
"NvFp4 MoE backend '
{
backend
.
value
}
' does not support the "
"deployment configuration."
)
def
_return_or_raise
(
backend
:
NvFp4MoeBackend
,
config
:
FusedMoEConfig
,
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
activation_format
:
mk
.
FusedMoEActivationFormat
,
)
->
tuple
[
NvFp4MoeBackend
,
type
[
mk
.
FusedMoEPermuteExpertsUnpermute
]]:
k_cls
=
backend_to_kernel_cls
(
backend
)
supported
,
reason
=
k_cls
.
is_supported_config
(
k_cls
,
config
,
weight_key
,
activation_key
,
activation_format
)
if
allow_flashinfer
and
envs
.
VLLM_USE_FLASHINFER_MOE_FP4
:
backend
=
fi_2_vllm_backend_map
[
get_flashinfer_moe_backend
()]
if
supported
:
logger
.
info_once
(
_make_log_backend
(
backend
))
return
backend
,
k_cls
raise
ValueError
(
_make_log_unsupported
(
backend
,
reason
))
if
envs
.
is_set
(
"VLLM_USE_FLASHINFER_MOE_FP4"
):
if
not
envs
.
VLLM_USE_FLASHINFER_MOE_FP4
:
# If the user rejects FlashInfer remove those backends.
for
b
in
FLASHINFER_NVFP4_MOE_BACKENDS
:
AVAILABLE_BACKENDS
.
remove
(
b
)
elif
envs
.
is_set
(
"VLLM_FLASHINFER_MOE_BACKEND"
):
# If user is explicit about backend, validate it.
fi_backend
=
get_flashinfer_moe_backend
()
if
fi_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
backend
=
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
supported
,
reason
=
is_supported_config_trtllm
(
config
,
weight_key
,
activation_key
,
activation_format
)
if
supported
:
logger
.
info_once
(
_make_log_backend
(
backend
))
return
backend
,
None
else
:
raise
ValueError
(
_make_log_unsupported
(
backend
,
reason
))
else
:
backend
=
fi_2_vllm_backend_map
[
fi_backend
]
return
_return_or_raise
(
backend
,
config
,
weight_key
,
activation_key
,
activation_format
)
else
:
backend
=
NvFp4MoeBackend
.
VLLM_CUTLASS
elif
is_fp4_marlin_supported
():
# If the user is not explicit about the backend, try each.
for
backend
in
FLASHINFER_NVFP4_MOE_BACKENDS
:
if
backend
==
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
:
k_cls
=
None
supported
,
reason
=
is_supported_config_trtllm
(
config
,
weight_key
,
activation_key
,
activation_format
,
)
else
:
k_cls
=
backend_to_kernel_cls
(
backend
)
supported
,
reason
=
k_cls
.
is_supported_config
(
k_cls
,
config
,
weight_key
,
activation_key
,
activation_format
,
)
if
supported
:
logger
.
info_once
(
_make_log_backend
(
backend
),
scope
=
"local"
)
return
backend
,
None
else
:
logger
.
debug_once
(
_make_log_unsupported
(
backend
,
reason
),
scope
=
"local"
)
raise
NotImplementedError
(
"Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
"FlashInfer NVFP4 MoE backend supports the configuration."
)
if
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
:
backend
=
NvFp4MoeBackend
.
MARLIN
else
:
raise
ValueError
(
"No NvFp4 kernel backend available for NvFp4 MoE."
)
# Log warning if FI backend requested but not available.
if
(
backend
not
in
FLASHINFER_NVFP4_MOE_BACKENDS
and
envs
.
VLLM_USE_FLASHINFER_MOE_FP4
):
logger
.
warning_once
(
"Requested FlashInfer backend for NvFp4 MoE, but it's not available. "
"Falling back to %s for NvFp4 MoE"
,
backend
.
value
,
scope
=
"local"
,
return
_return_or_raise
(
backend
,
config
,
weight_key
,
activation_key
,
activation_format
)
else
:
logger
.
info_once
(
_make_log_backend
(
backend
),
scope
=
"local"
)
return
backend
# Select kernels in order of backend.
for
backend
in
AVAILABLE_BACKENDS
:
if
backend
==
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
:
k_cls
=
None
# type: ignore[assignment]
supported
,
reason
=
is_supported_config_trtllm
(
config
,
weight_key
,
activation_key
,
activation_format
,
)
else
:
k_cls
=
backend_to_kernel_cls
(
backend
)
supported
,
reason
=
k_cls
.
is_supported_config
(
k_cls
,
config
,
weight_key
,
activation_key
,
activation_format
,
)
if
supported
:
logger
.
info_once
(
_make_log_backend
(
backend
),
scope
=
"local"
)
return
backend
,
k_cls
else
:
logger
.
debug_once
(
_make_log_unsupported
(
backend
,
reason
),
scope
=
"local"
)
raise
NotImplementedError
(
"No NvFp4 MoE backend supports the deployment configuration."
)
def
convert_to_nvfp4_moe_kernel_format
(
...
...
@@ -238,55 +391,69 @@ def make_nvfp4_moe_quant_config(
)
def
make_nvfp4_moe_kernel
(
backend
:
NvFp4MoeBackend
,
quant_config
:
FusedMoEQuantConfig
,
def
make_nvfp4_moe_kernel_for_mkm
(
moe_config
:
FusedMoEConfig
,
)
->
mk
.
FusedMoEModularKernel
|
None
:
assert
moe_config
.
dp_size
==
1
quant_config
:
FusedMoEQuantConfig
,
experts_cls
:
type
[
mk
.
FusedMoEPermuteExpertsUnpermute
],
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
if
prepare_finalize
.
activation_format
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
:
max_num_tokens_per_rank
=
prepare_finalize
.
max_num_tokens_per_rank
()
assert
max_num_tokens_per_rank
is
not
None
experts
=
experts_cls
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
)
else
:
experts
=
experts_cls
(
moe_config
=
moe_config
,
quant_config
=
quant_config
,
)
UNSUPPORTED_BACKENDS
=
[
# TRTLLM does not use the modular kernel abstraction.
NvFp4MoeBackend
.
FLASHINFER_TRTLLM
,
# CUTEDSL is used with BATCHED (masked) format only.
# TODO: add here once we support dp/ep via the oracle.
NvFp4MoeBackend
.
FLASHINFER_CUTEDSL
,
]
logger
.
debug_once
(
"Using %s"
,
experts
.
__class__
.
__name__
)
return
experts
if
backend
in
UNSUPPORTED_BACKENDS
:
return
None
elif
backend
==
NvFp4MoeBackend
.
FLASHINFER_CUTLASS
:
return
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
True
),
FlashInferExperts
(
out_dtype
=
moe_config
.
in_dtype
,
quant_config
=
quant_config
,
ep_rank
=
moe_config
.
ep_rank
,
ep_size
=
moe_config
.
ep_size
,
tp_rank
=
moe_config
.
tp_rank
,
tp_size
=
moe_config
.
tp_size
,
use_dp
=
False
,
use_deepseek_fp8_block_scale
=
False
,
),
def
make_nvfp4_moe_kernel
(
moe_quant_config
:
FusedMoEQuantConfig
,
moe_config
:
FusedMoEConfig
,
experts_cls
:
type
[
mk
.
FusedMoEPermuteExpertsUnpermute
],
)
->
mk
.
FusedMoEModularKernel
:
# TODO(rob): unify after we merge tp and dp/ep.
if
(
moe_config
.
moe_parallel_config
.
use_all2all_kernels
and
moe_config
.
moe_parallel_config
.
all2all_backend
not
in
[
"allgather_reducescatter"
,
"naive"
]
):
raise
ValueError
(
"NvFP4 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
elif
backend
==
NvFp4MoeBackend
.
VLLM_CUTLASS
:
return
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
True
),
CutlassExpertsFp4
(
out_dtype
=
moe_config
.
in_dtype
,
# TODO(rob): see what impact this has on expert map?
max_experts_per_worker
=
moe_config
.
num_experts
,
quant_config
=
quant_config
,
),
)
# Create Prepare/Finalize.
prepare_finalize
=
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
experts_cls
.
expects_unquantized_inputs
(
moe_config
,
moe_quant_config
),
)
elif
backend
==
NvFp4MoeBackend
.
MARLIN
:
return
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
()
,
MarlinExperts
(
quant_config
=
quant_config
)
,
)
# Create Experts.
experts
=
experts_cls
(
moe_config
=
moe_config
,
quant_config
=
moe_
quant_config
,
)
else
:
raise
ValueError
(
f
"Unknown NvFp4 MoE backend:
{
backend
}
"
)
# NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class.
kernel
=
mk
.
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
shared_experts
=
None
,
moe_parallel_config
=
moe_config
.
moe_parallel_config
,
)
# TODO(rob): update inplace logic to be part of the kernel.
return
kernel
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment