Unverified Commit 42135d68 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

parent e14467be
...@@ -123,7 +123,6 @@ def convert_to_unquantized_kernel_format( ...@@ -123,7 +123,6 @@ def convert_to_unquantized_kernel_format(
def make_unquantized_moe_kernel( def make_unquantized_moe_kernel(
layer: torch.nn.Module,
backend: UnquantizedMoeBackend, backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
...@@ -141,12 +140,8 @@ def make_unquantized_moe_kernel( ...@@ -141,12 +140,8 @@ def make_unquantized_moe_kernel(
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
FlashInferExperts( FlashInferExperts(
out_dtype=layer.params_dtype, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
tp_rank=moe_config.moe_parallel_config.tp_rank,
tp_size=moe_config.moe_parallel_config.tp_size,
ep_rank=moe_config.moe_parallel_config.ep_rank,
ep_size=moe_config.moe_parallel_config.ep_size,
), ),
) )
use_inplace = False use_inplace = False
...@@ -157,13 +152,19 @@ def make_unquantized_moe_kernel( ...@@ -157,13 +152,19 @@ def make_unquantized_moe_kernel(
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
AiterExperts(quant_config), AiterExperts(
moe_config=moe_config,
quant_config=quant_config,
),
) )
elif backend == UnquantizedMoeBackend.TRITON: elif backend == UnquantizedMoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe import TritonExperts
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
TritonExperts(quant_config), TritonExperts(
moe_config=moe_config,
quant_config=quant_config,
),
) )
return kernel, use_inplace return kernel, use_inplace
...@@ -9,11 +9,21 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -9,11 +9,21 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
class QuantMethod(IntEnum): class QuantMethod(IntEnum):
...@@ -269,17 +279,49 @@ def rocm_aiter_fused_experts( ...@@ -269,17 +279,49 @@ def rocm_aiter_fused_experts(
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config): @staticmethod
super().__init__(quant_config) def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def expects_unquantized_inputs(
fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
# AITER fused MoE kernels handle input quantization internally.
return True
@property @staticmethod
def activation_formats( def _supports_current_device() -> bool:
self, return rocm_aiter_ops.is_fused_moe_enabled()
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.Standard, def _supports_no_act_and_mul() -> bool:
mk.FusedMoEActivationFormat.Standard, return False
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# TODO(rob): AITER also supports MXFP4, which is not
# yet supported via an Oracle. Once it is, we will add
# MXFP4 to this list.
SUPPORTED_W_A = [
(None, None),
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_expert_map(self): def supports_expert_map(self):
return True return True
......
...@@ -34,6 +34,11 @@ class CustomRoutingRouter(BaseRouter): ...@@ -34,6 +34,11 @@ class CustomRoutingRouter(BaseRouter):
@property @property
def routing_method_type(self) -> RoutingMethodType: def routing_method_type(self) -> RoutingMethodType:
from vllm.model_executor.models.llama4 import Llama4MoE
# NOTE: FLASHINFER_TRTLLM support the Llama4 router.
if self.custom_routing_function == Llama4MoE.custom_routing_function:
return RoutingMethodType.Llama4
return RoutingMethodType.Custom return RoutingMethodType.Custom
def _compute_routing( def _compute_routing(
......
...@@ -261,7 +261,6 @@ class GroupedTopKRouter(BaseRouter): ...@@ -261,7 +261,6 @@ class GroupedTopKRouter(BaseRouter):
num_fused_shared_experts: int = 0, num_fused_shared_experts: int = 0,
enable_eplb: bool = False, enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None,
routing_method_type: RoutingMethodType | None = None,
): ):
super().__init__( super().__init__(
top_k=top_k, top_k=top_k,
...@@ -278,13 +277,12 @@ class GroupedTopKRouter(BaseRouter): ...@@ -278,13 +277,12 @@ class GroupedTopKRouter(BaseRouter):
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
self.num_fused_shared_experts = num_fused_shared_experts self.num_fused_shared_experts = num_fused_shared_experts
# Determine routing method type if scoring_func == "sigmoid":
if routing_method_type is not None:
self._routing_method_type = routing_method_type
elif scoring_func == "sigmoid":
self._routing_method_type = RoutingMethodType.DeepSeekV3 self._routing_method_type = RoutingMethodType.DeepSeekV3
else: else:
self._routing_method_type = RoutingMethodType.TopK # NOTE: this prohibits the FLASHINFER_TRTLLM kernels from
# being selected, since they only support DeepSeek-style.
self._routing_method_type = RoutingMethodType.Unspecified
@property @property
def routing_method_type(self) -> RoutingMethodType: def routing_method_type(self) -> RoutingMethodType:
......
...@@ -6,7 +6,6 @@ import torch ...@@ -6,7 +6,6 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import ( from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
CustomRoutingRouter, CustomRoutingRouter,
...@@ -36,7 +35,6 @@ def create_fused_moe_router( ...@@ -36,7 +35,6 @@ def create_fused_moe_router(
global_num_experts: int, global_num_experts: int,
renormalize: bool = True, renormalize: bool = True,
indices_type_getter: Callable[[], torch.dtype | None] | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None,
routing_method_type: RoutingMethodType | None = None,
# grouped topk parameters # grouped topk parameters
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: int | None = None, num_expert_group: int | None = None,
...@@ -128,7 +126,6 @@ def create_fused_moe_router( ...@@ -128,7 +126,6 @@ def create_fused_moe_router(
num_fused_shared_experts=num_fused_shared_experts, num_fused_shared_experts=num_fused_shared_experts,
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
routing_method_type=routing_method_type,
) )
router.capture = capture router.capture = capture
return router return router
......
...@@ -5,7 +5,10 @@ ...@@ -5,7 +5,10 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
...@@ -17,19 +20,22 @@ class TritonOrCutlassExperts(FallbackExperts): ...@@ -17,19 +20,22 @@ class TritonOrCutlassExperts(FallbackExperts):
def __init__( def __init__(
self, self,
e: int, moe_config: FusedMoEConfig,
n: int,
k: int,
out_dtype: torch.dtype | None,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
device: torch.dtype,
): ):
self.is_sm100 = current_platform.has_device_capability(100) self.is_sm100 = current_platform.has_device_capability(100)
super().__init__( super().__init__(
experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device), experts=CutlassExpertsFp8(moe_config, quant_config),
fallback_experts=TritonExperts(quant_config), fallback_experts=TritonExperts(moe_config, quant_config),
) )
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
]:
return (CutlassExpertsFp8, TritonExperts)
def workspace_shapes( def workspace_shapes(
self, self,
M: int, M: int,
......
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, DeepGemmExperts,
_valid_deep_gemm, _valid_deep_gemm,
...@@ -20,12 +23,19 @@ from vllm.utils.deep_gemm import ( ...@@ -20,12 +23,19 @@ from vllm.utils.deep_gemm import (
class TritonOrDeepGemmExperts(FallbackExperts): class TritonOrDeepGemmExperts(FallbackExperts):
"""DeepGemm with fallback to Triton for low latency shapes.""" """DeepGemm with fallback to Triton for low latency shapes."""
def __init__(self, quant_config: FusedMoEQuantConfig): def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
super().__init__( super().__init__(
experts=DeepGemmExperts(quant_config), experts=DeepGemmExperts(moe_config, quant_config),
fallback_experts=TritonExperts(quant_config), fallback_experts=TritonExperts(moe_config, quant_config),
) )
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
]:
return (DeepGemmExperts, TritonExperts)
def workspace_shapes( def workspace_shapes(
self, self,
M: int, M: int,
......
...@@ -6,37 +6,73 @@ import torch ...@@ -6,37 +6,73 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
moe: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
gemm1_alpha, gemm1_alpha,
gemm1_beta, gemm1_beta,
gemm1_clamp_limit, gemm1_clamp_limit,
max_capture_size, max_capture_size,
): ):
super().__init__(quant_config) super().__init__(moe_config, quant_config)
self.moe = moe
self.gemm1_alpha = gemm1_alpha self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit self.gemm1_clamp_limit = gemm1_clamp_limit
self.max_capture_size = max_capture_size self.max_capture_size = max_capture_size
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.Standard
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.Standard, def _supports_current_device() -> bool:
mk.FusedMoEActivationFormat.Standard, raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_activation(activation: str) -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
) )
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
...@@ -86,7 +122,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -86,7 +122,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk = topk_ids.size(-1) topk = topk_ids.size(-1)
local_num_experts = w1.size(0) local_num_experts = w1.size(0)
intermediate_size = w2.size(1) intermediate_size = w2.size(1)
local_expert_offset = self.moe.ep_rank * local_num_experts local_expert_offset = self.moe_config.ep_rank * local_num_experts
x_quant = hidden_states x_quant = hidden_states
x_scale = a1q_scale x_scale = a1q_scale
......
...@@ -96,13 +96,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -96,13 +96,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
): ):
logger.debug("BatchedTritonExperts %s", self.moe) logger.debug("BatchedTritonExperts %s", self.moe)
return BatchedTritonExperts( return BatchedTritonExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
max_num_tokens=self.moe.max_num_tokens, max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
) )
else: else:
logger.debug("TritonExperts %s", self.moe) logger.debug("TritonExperts %s", self.moe)
return TritonExperts(self.moe_quant_config) return TritonExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
)
def create_weights( def create_weights(
self, self,
...@@ -192,7 +196,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -192,7 +196,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
self.kernel, self.use_inplace = make_unquantized_moe_kernel( self.kernel, self.use_inplace = make_unquantized_moe_kernel(
layer=layer,
backend=self.unquantized_backend, backend=self.unquantized_backend,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
......
...@@ -739,6 +739,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -739,6 +739,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
return BatchedMarlinExperts( return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx, w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx, w2_g_idx=w2_g_idx,
...@@ -749,6 +750,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -749,6 +750,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
else: else:
# Standard Marlin experts for AWQ # Standard Marlin experts for AWQ
return MarlinExperts( return MarlinExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx, w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx, w2_g_idx=w2_g_idx,
......
...@@ -19,7 +19,6 @@ from vllm.logger import init_logger ...@@ -19,7 +19,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEActivationFormat, FusedMoEActivationFormat,
FusedMoEConfig,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
FusedMoERouter, FusedMoERouter,
...@@ -27,9 +26,9 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -27,9 +26,9 @@ from vllm.model_executor.layers.fused_moe import (
UnquantizedFusedMoEMethod, UnquantizedFusedMoEMethod,
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config, RoutingMethodType,
fp8_w8a16_moe_quant_config,
int4_w4a16_moe_quant_config, int4_w4a16_moe_quant_config,
int4_w4afp8_moe_quant_config, int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config, int8_w8a8_moe_quant_config,
...@@ -45,15 +44,17 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -45,15 +44,17 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
FLASHINFER_NVFP4_MOE_BACKENDS,
NvFp4MoeBackend, NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format, convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config, make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -62,10 +63,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress ...@@ -62,10 +63,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
WNA16_SUPPORTED_TYPES_MAP, WNA16_SUPPORTED_TYPES_MAP,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe, flashinfer_trtllm_fp4_routed_moe,
select_nvfp4_gemm_impl, )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe, process_fp8_input_tensor_strategy_moe,
...@@ -79,12 +82,18 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -79,12 +82,18 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_moe_permute_scales, marlin_moe_permute_scales,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
is_fp4_marlin_supported,
prepare_moe_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8, convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace, convert_packed_uint4b8_to_signed_int4_inplace,
kFp8Dynamic128Sym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
kNvfp4Static,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
...@@ -200,7 +209,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -200,7 +209,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
f"or None for NVFP4A16, found {input_quant}", f"or None for NVFP4A16, found {input_quant}",
) )
return CompressedTensorsW4A4Nvfp4MoEMethod( return CompressedTensorsW4A4Nvfp4MoEMethod(
layer.moe_config, layer_name, use_marlin=input_quant is None layer.moe_config, layer_name, use_a16=(input_quant is None)
) )
elif ( elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
...@@ -234,6 +243,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -234,6 +243,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
super().__init__(moe) super().__init__(moe)
self.group_size = 32 self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN self.mxfp4_backend = NvFp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts
self.kernel: mk.FusedMoEModularKernel | None = None self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights( def create_weights(
...@@ -327,9 +337,9 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -327,9 +337,9 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None: if self.moe_quant_config is not None:
self.kernel = make_nvfp4_moe_kernel( self.kernel = make_nvfp4_moe_kernel(
backend=self.mxfp4_backend, moe_quant_config=self.moe_quant_config,
quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls,
) )
def apply( def apply(
...@@ -368,34 +378,30 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -368,34 +378,30 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
self, self,
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer_name: str | None = None, layer_name: str | None = None,
use_marlin: bool = False, use_a16: bool = False,
): ):
super().__init__(moe) super().__init__(moe)
self.group_size = 16 self.group_size = 16
if use_marlin:
if is_fp4_marlin_supported():
self.nvfp4_backend = NvFp4MoeBackend.MARLIN
else:
raise ValueError(
"Marlin FP4 MoE kernel requested but not ",
"supported on current platform.",
)
else:
self.nvfp4_backend = select_nvfp4_moe_backend()
# TODO: move this type of check into the oracle. # Select experts implementation.
if not self.moe.is_act_and_mul and self.nvfp4_backend not in [ self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
NvFp4MoeBackend.FLASHINFER_CUTLASS, config=self.moe,
NvFp4MoeBackend.MARLIN, weight_key=kNvfp4Static,
]: activation_key=None if use_a16 else kNvfp4Dynamic,
raise NotImplementedError( )
"Non-gated activations are only supported by FlashInfer "
f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}." # Delay creation of the kernel until after process-weights.
) self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
...@@ -571,35 +577,40 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -571,35 +577,40 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = a13_scale layer.w13_input_scale = a13_scale
layer.w2_input_scale = a2_scale layer.w2_input_scale = a2_scale
# Initialize the kernel that will be called in apply(). # Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
use_dp = self.moe.dp_size > 1 if self.moe_quant_config and (
if self.moe_quant_config is not None and not use_dp: (not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel( self.kernel = make_nvfp4_moe_kernel(
backend=self.nvfp4_backend, moe_quant_config=self.moe_quant_config,
quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls,
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM] if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
if self.nvfp4_backend in UNSUPPORTED:
return None return None
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored. # For no-EP case, don't use the MKM framework.
if self.moe.dp_size == 1: if not self.moe.moe_parallel_config.use_all2all_kernels:
return None return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe self.moe,
use_deepseek_fp8_block_scale=False,
) )
logger.debug_once("%s", prepare_finalize.__class__.__name__) logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize return prepare_finalize
else: return super().maybe_make_prepare_finalize(routing_tables)
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
...@@ -607,14 +618,13 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -607,14 +618,13 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
"""Return the appropriate GEMM experts implementation.""" assert self.experts_cls is not None
experts = select_nvfp4_gemm_impl( return make_nvfp4_moe_kernel_for_mkm(
self.moe, moe_config=self.moe,
self.moe_quant_config, quant_config=self.moe_quant_config,
allow_flashinfer=(self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS), experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
...@@ -727,33 +737,41 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -727,33 +737,41 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"For FP8 Fused MoE layer, we require either per tensor or " "For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization." "channelwise, dynamic per token quantization."
) )
self.fp8_backend = select_fp8_moe_backend(
block_quant=self.block_quant, ct2vllm_weight = {
tp_size=moe.tp_size, QuantizationStrategy.CHANNEL: kFp8StaticChannelSym,
with_lora_support=moe.is_lora_enabled, QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
is_act_and_mul=moe.is_act_and_mul, QuantizationStrategy.BLOCK: kFp8Static128BlockSym,
# TODO(rob): enable selecting this externally. }
ct2vllm_act = {
QuantizationStrategy.TOKEN: kFp8DynamicTokenSym,
QuantizationStrategy.TENSOR: (
kFp8StaticTensorSym if self.static_input_scales else kFp8Dynamic128Sym
),
}
weight_key = ct2vllm_weight[self.weight_quant.strategy]
if weight_key == kFp8Static128BlockSym:
activation_key = kFp8Dynamic128Sym
else:
activation_key = ct2vllm_act[self.input_quant.strategy]
# Select Fp8 MoE backend
self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
config=self.moe,
weight_key=weight_key,
activation_key=activation_key,
allow_vllm_cutlass=True, allow_vllm_cutlass=True,
) )
if self.fp8_backend != Fp8MoeBackend.MARLIN:
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
)
if per_act_token != per_channel_quant:
raise NotImplementedError(
"For FP8 Fused MoE layers, per-token and per-channel must be "
"used together."
)
# TODO(rob): hook this up in a follow up PR.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
raise NotImplementedError(
"FlashInfer TRTLLM backend not supported for compressed-tensors yet."
)
self.disable_expert_map = False
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -970,140 +988,75 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -970,140 +988,75 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
replace_parameter(layer, "w13_weight_scale", w13_scale) replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale) replace_parameter(layer, "w2_weight_scale", w2_scale)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel( self.kernel, self.use_inplace = make_fp8_moe_kernel(
layer=layer,
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
) )
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]: if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None return None
else: elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
return super().maybe_make_prepare_finalize(routing_tables) # For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
if self.fp8_backend == Fp8MoeBackend.VLLM_CUTLASS: assert self.experts_cls is not None
from vllm.model_executor.layers.fused_moe import ( return make_fp8_moe_kernel_for_mkm(
CutlassBatchedExpertsFp8, moe_config=self.moe,
CutlassExpertsFp8, quant_config=self.moe_quant_config,
) experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
experts: FusedMoEPermuteExpertsUnpermute
num_dispatchers = prepare_finalize.num_dispatchers()
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
):
logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassBatchedExpertsFp8(
max_experts_per_worker=self.moe.num_local_experts,
num_dispatchers=num_dispatchers,
out_dtype=self.moe.in_dtype,
e=layer.local_num_experts,
n=layer.intermediate_size_per_partition,
k=layer.hidden_size,
device=layer.w13_weight.device,
quant_config=self.moe_quant_config,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
out_dtype=self.moe.in_dtype,
e=layer.local_num_experts,
n=layer.intermediate_size_per_partition,
k=layer.hidden_size,
device=layer.w13_weight.device,
quant_config=self.moe_quant_config,
)
# TODO(rob): investigate disable_expert_map
self.disable_expert_map = (
num_dispatchers > 1 or not experts.supports_expert_map()
)
return experts
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
) )
assert self.fp8_backend not in [Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN]
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
return BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
return BatchedTritonExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
return TritonOrDeepGemmExperts(self.moe_quant_config)
else:
logger.debug("TritonExperts(%s)", self.__class__.__name__)
return TritonExperts(self.moe_quant_config)
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
if self.fp8_backend == Fp8MoeBackend.MARLIN: w1_scale = layer.w13_weight_scale
return fp8_w8a16_moe_quant_config( w2_scale = layer.w2_weight_scale
w1_scale=layer.w13_weight_scale, a1_scale = layer.w13_input_scale
w2_scale=layer.w2_weight_scale, a2_scale = layer.w2_input_scale
block_shape=self.weight_block_size,
)
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
return fp8_w8a8_moe_quant_config( return make_fp8_moe_quant_config(
w1_scale=layer.w13_weight_scale, fp8_backend=self.fp8_backend,
w2_scale=layer.w2_weight_scale, w1_scale=w1_scale,
a1_scale=layer.w13_input_scale, w2_scale=w2_scale,
a2_scale=layer.w2_input_scale, a1_scale=a1_scale,
per_act_token_quant=per_act_token, a2_scale=a2_scale,
per_out_ch_quant=per_channel_quant, per_act_token_quant=(
block_shape=layer.weight_block_size, self.input_quant.strategy == QuantizationStrategy.TOKEN
),
per_out_ch_quant=(self.input_quant.strategy == QuantizationStrategy.TOKEN),
block_shape=self.weight_block_size,
) )
def apply( def apply(
...@@ -1113,6 +1066,56 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1113,6 +1066,56 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `FlashInfer TRTLLM FP8 MoE`."
)
assert layer.activation == "silu"
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
topk_weights, topk_ids = router.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -1130,7 +1133,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1130,7 +1133,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
# TODO(rob): investigate the disable_expert_map introduced by: # TODO(rob): investigate the disable_expert_map introduced by:
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501 # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map=None if self.disable_expert_map else layer.expert_map, expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
...@@ -1596,6 +1599,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1596,6 +1599,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
return BatchedMarlinExperts( return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
w13_g_idx=layer.w13_weight_g_idx, w13_g_idx=layer.w13_weight_g_idx,
w2_g_idx=layer.w2_weight_g_idx, w2_g_idx=layer.w2_weight_g_idx,
...@@ -1605,6 +1609,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1605,6 +1609,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
) )
else: else:
return MarlinExperts( return MarlinExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
w13_g_idx=layer.w13_weight_g_idx, w13_g_idx=layer.w13_weight_g_idx,
w2_g_idx=layer.w2_weight_g_idx, w2_g_idx=layer.w2_weight_g_idx,
...@@ -1854,7 +1859,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1854,7 +1859,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight = layer.w13_weight_packed layer.w13_weight = layer.w13_weight_packed
layer.w2_weight = layer.w2_weight_packed layer.w2_weight = layer.w2_weight_packed
return TritonWNA16Experts(quant_config=self.moe_quant_config) return TritonWNA16Experts(
moe_config=self.moe, quant_config=self.moe_quant_config
)
else: else:
raise NotImplementedError( raise NotImplementedError(
"TritonExperts requires Triton. " "TritonExperts requires Triton. "
...@@ -2467,6 +2474,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2467,6 +2474,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
c_strides2=self.a_strides1_c_strides2, c_strides2=self.a_strides1_c_strides2,
s_strides1=self.s_strides1, s_strides1=self.s_strides1,
s_strides2=self.s_strides2, s_strides2=self.s_strides2,
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
group_size=self.group_size, group_size=self.group_size,
) )
...@@ -2505,6 +2513,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -2505,6 +2513,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_packed, layer.w2_weight_packed,
topk_weights, topk_weights,
topk_ids, topk_ids,
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
activation=layer.activation, activation=layer.activation,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
......
...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.batch_invariant import (
) )
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEActivationFormat,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
...@@ -35,6 +34,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -35,6 +34,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod ...@@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize, build_flashinfer_fp8_cutlass_moe_prepare_finalize,
select_cutlass_fp8_gemm_impl,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
...@@ -79,8 +78,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( ...@@ -79,8 +78,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
is_layer_skipped, is_layer_skipped,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym, kFp8DynamicTensorSym,
kFp8DynamicTokenSym, kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...@@ -658,38 +659,36 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -658,38 +659,36 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.weight_scale_name = ( self.weight_scale_name = (
"weight_scale_inv" if self.block_quant else "weight_scale" "weight_scale_inv" if self.block_quant else "weight_scale"
) )
self.fp8_backend = select_fp8_moe_backend(
block_quant=self.block_quant,
tp_size=layer.moe_parallel_config.tp_size,
with_lora_support=self.moe.is_lora_enabled,
is_act_and_mul=self.moe.is_act_and_mul,
)
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: # Set weight key and activation key for kernel compatibility
if self.block_quant and self.weight_block_size != [128, 128]: if self.block_quant:
raise NotImplementedError( weight_key = kFp8Static128BlockSym
"FlashInfer CUTLASS FP8 MoE backend only supports block " activation_key = kFp8Dynamic128Sym
"size [128, 128]." else:
) weight_key = kFp8StaticTensorSym
if layer.activation != "silu": activation_key = (
raise NotImplementedError( kFp8StaticTensorSym
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU " if self.quant_config.activation_scheme == "static"
"activation function, but got {layer.activation}." else kFp8DynamicTensorSym
)
dynamic_per_token = (
not self.block_quant and self.quant_config.activation_scheme != "static"
)
if dynamic_per_token and self.fp8_backend in [
Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS,
]:
raise NotImplementedError(
"FlashInfer FP8 MoE backend does not support dynamic per token "
"activation quantization."
) )
# Select Fp8 MoE backend
self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
config=self.moe,
weight_key=weight_key,
activation_key=activation_key,
allow_vllm_cutlass=False,
)
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: Module, layer: Module,
...@@ -842,14 +841,21 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -842,14 +841,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
# Setup modular kernel for TP case. # Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config and (
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel( self.kernel, self.use_inplace = make_fp8_moe_kernel(
layer=layer,
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
) )
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
...@@ -904,13 +910,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -904,13 +910,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.fp8_backend in [ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
Fp8MoeBackend.AITER,
Fp8MoeBackend.MARLIN,
Fp8MoeBackend.FLASHINFER_TRTLLM,
]:
return None return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe, self.moe,
use_deepseek_fp8_block_scale=self.block_quant, use_deepseek_fp8_block_scale=self.block_quant,
...@@ -924,73 +930,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -924,73 +930,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
BatchedDeepGemmExperts,
BatchedTritonExperts,
TritonExperts,
TritonOrDeepGemmExperts,
)
if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
raise NotImplementedError(
"Marlin and ROCm AITER are not supported with all2all yet."
)
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
assert self.experts_cls is not None
if ( return make_fp8_moe_kernel_for_mkm(
prepare_finalize.activation_format moe_config=self.moe,
== FusedMoEActivationFormat.BatchedExperts quant_config=self.moe_quant_config,
): experts_cls=self.experts_cls,
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() prepare_finalize=prepare_finalize,
assert max_num_tokens_per_rank is not None )
experts_impl = (
BatchedDeepGemmExperts
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
else BatchedTritonExperts
)
logger.debug(
"%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
experts_impl.__name__,
self.__class__.__name__,
max_num_tokens_per_rank,
self.weight_block_size,
False,
)
return experts_impl(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
elif self.moe.is_lora_enabled:
return TritonExperts(quant_config=self.moe_quant_config)
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# Select GEMM experts with block-scale when weights are block-quantized
experts = select_cutlass_fp8_gemm_impl(
self.moe,
self.moe_quant_config,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
logger.debug(
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__,
self.weight_block_size,
False,
)
return TritonOrDeepGemmExperts(self.moe_quant_config)
else:
assert self.fp8_backend == Fp8MoeBackend.TRITON
logger.debug(
"TritonExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__,
self.weight_block_size,
False,
)
return TritonExperts(self.moe_quant_config)
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
...@@ -1067,7 +1014,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1067,7 +1014,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routed_scaling=layer.routed_scaling_factor, routed_scaling=layer.routed_scaling_factor,
) )
else: else:
result = apply_fi_trtllm_fp8_per_tensor_moe( return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer, layer=layer,
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
......
...@@ -875,6 +875,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -875,6 +875,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
return BatchedMarlinExperts( return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx, w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx, w2_g_idx=w2_g_idx,
...@@ -885,6 +886,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -885,6 +886,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
else: else:
# Standard Marlin experts for GPTQ # Standard Marlin experts for GPTQ
return MarlinExperts( return MarlinExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx, w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx, w2_g_idx=w2_g_idx,
......
...@@ -27,15 +27,16 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -27,15 +27,16 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
FLASHINFER_NVFP4_MOE_BACKENDS,
NvFp4MoeBackend, NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format, convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -57,12 +58,10 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( ...@@ -57,12 +58,10 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize, build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe, flashinfer_trtllm_fp4_routed_moe,
select_nvfp4_gemm_impl,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize, build_flashinfer_fp8_cutlass_moe_prepare_finalize,
select_cutlass_fp8_gemm_impl,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
...@@ -84,6 +83,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -84,6 +83,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTokenSym, kFp8DynamicTokenSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kFp8StaticTokenSym, kFp8StaticTokenSym,
kNvfp4Dynamic,
kNvfp4Static,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...@@ -728,14 +729,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -728,14 +729,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
super().__init__(moe_config) super().__init__(moe_config)
self.quant_config = quant_config self.quant_config = quant_config
assert self.quant_config.is_checkpoint_fp8_serialized assert self.quant_config.is_checkpoint_fp8_serialized
self.fp8_backend = select_fp8_moe_backend(
block_quant=False, # Select Fp8 MoE backend
tp_size=moe_config.moe_parallel_config.tp_size, self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
with_lora_support=self.moe.is_lora_enabled, config=self.moe,
is_act_and_mul=self.moe.is_act_and_mul, weight_key=kFp8StaticTensorSym,
activation_key=kFp8StaticTensorSym,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
...@@ -744,8 +754,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -744,8 +754,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored. # For no-EP case, don't use the MKM framework.
if self.moe.dp_size == 1: if not self.moe.moe_parallel_config.use_all2all_kernels:
return None return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
...@@ -762,12 +772,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -762,12 +772,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
experts = select_cutlass_fp8_gemm_impl( assert self.experts_cls is not None
self.moe, return make_fp8_moe_kernel_for_mkm(
self.moe_quant_config, moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def create_weights( def create_weights(
self, self,
...@@ -876,14 +887,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -876,14 +887,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, "w13_weight_scale", w13_scale) replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale) replace_parameter(layer, "w2_weight_scale", w2_scale)
# Setup modular kernel for TP case. # Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config:
assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel( self.kernel, self.use_inplace = make_fp8_moe_kernel(
layer=layer,
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
...@@ -1335,32 +1347,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1335,32 +1347,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
) -> None: ) -> None:
super().__init__(moe_config) super().__init__(moe_config)
self.quant_config = quant_config self.quant_config = quant_config
self.nvfp4_backend = select_nvfp4_moe_backend() # Select experts implementation.
# TODO: move this type of check into the oracle. self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
if not self.moe.is_act_and_mul and self.nvfp4_backend not in [ config=self.moe,
NvFp4MoeBackend.FLASHINFER_CUTLASS, weight_key=kNvfp4Static,
NvFp4MoeBackend.MARLIN, activation_key=kNvfp4Dynamic,
]: )
raise NotImplementedError(
"Non-gated activations are only supported by FlashInfer " # Delay creation of the kernel until after process-weights.
f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}." self.kernel: mk.FusedMoEModularKernel | None = None
)
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM] if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
if self.nvfp4_backend in UNSUPPORTED:
return None return None
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored. # For no-EP case, don't use the MKM framework.
if self.moe.dp_size == 1: if not self.moe.moe_parallel_config.use_all2all_kernels:
return None return None
# For now, fp4 moe only works with the flashinfer dispatcher. # For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
...@@ -1377,13 +1392,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1377,13 +1392,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
experts = select_nvfp4_gemm_impl( assert self.experts_cls is not None
self.moe, return make_nvfp4_moe_kernel_for_mkm(
self.moe_quant_config, moe_config=self.moe,
allow_flashinfer=self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS, quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def uses_weight_scale_2_pattern(self) -> bool: def uses_weight_scale_2_pattern(self) -> bool:
""" """
...@@ -1554,13 +1569,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1554,13 +1569,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2) replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
replace_parameter(layer, "w2_input_scale", a2_scale) replace_parameter(layer, "w2_input_scale", a2_scale)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
use_dp = self.moe.dp_size > 1 if self.moe_quant_config and (
if self.moe_quant_config is not None and not use_dp: (not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel( self.kernel = make_nvfp4_moe_kernel(
backend=self.nvfp4_backend, moe_quant_config=self.moe_quant_config,
quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls,
) )
@property @property
......
...@@ -853,6 +853,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -853,6 +853,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
moe_config=self.moe,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
...@@ -875,11 +876,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -875,11 +876,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
} }
return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
elif self.mxfp4_backend == Mxfp4Backend.MARLIN: elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
return MarlinExperts(self.moe_quant_config) return MarlinExperts(self.moe, self.moe_quant_config)
elif self.mxfp4_backend == Mxfp4Backend.TRITON: elif self.mxfp4_backend == Mxfp4Backend.TRITON:
if self.moe.is_lora_enabled: if self.moe.is_lora_enabled:
return UnfusedOAITritonExperts(self.moe_quant_config) return UnfusedOAITritonExperts(self.moe, self.moe_quant_config)
return OAITritonExperts(self.moe_quant_config) return OAITritonExperts(self.moe, self.moe_quant_config)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP" f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
......
...@@ -11,19 +11,16 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -11,19 +11,16 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEParallelConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
FlashInferCuteDSLExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize, create_flashinfer_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -47,6 +44,86 @@ __all__ = [ ...@@ -47,6 +44,86 @@ __all__ = [
"build_flashinfer_fp4_cutlass_moe_prepare_finalize", "build_flashinfer_fp4_cutlass_moe_prepare_finalize",
] ]
#
# Methods used by the oracle for kernel selection.
#
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nemotron-Nano)."""
return False
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Nvfp4 quantization."""
SUPPORTED_W_A = [
(kNvfp4Static, kNvfp4Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: str) -> bool:
"""Supports silu activation only."""
return activation in ["silu"]
def _supports_routing_method(
routing_method: RoutingMethodType,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.Llama4,
]
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Supports EP."""
return True
def is_supported_config_trtllm(
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
"""
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not _supports_current_device():
return False, _make_reason("current device")
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not _supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not _supports_quant_scheme(weight_key, activation_key):
return False, _make_reason("quantization scheme")
elif not _supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason("parallel config")
elif not _supports_routing_method(moe_config.routing_method):
return False, _make_reason("routing method")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason("activation format")
return True, None
def is_flashinfer_fp4_cutlass_moe_available() -> bool: def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
...@@ -96,37 +173,6 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize( ...@@ -96,37 +173,6 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
) )
def select_nvfp4_gemm_impl(
moe: FusedMoEConfig,
moe_quant_config: FusedMoEQuantConfig,
allow_flashinfer: bool,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
if allow_flashinfer:
if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
return FlashInferCuteDSLExperts(
out_dtype=moe.in_dtype,
quant_config=moe_quant_config,
)
elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput":
return FlashInferExperts(
out_dtype=moe.in_dtype,
quant_config=moe_quant_config,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
use_dp=moe.moe_parallel_config.dp_size > 1,
)
# native cutlass experts currently don't support DP; TP case won't call this
raise ValueError(
"CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
"Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)"
)
def prepare_static_weights_for_trtllm_fp4_moe( def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant, # args_dequant,
# args, # args,
......
...@@ -9,10 +9,6 @@ from vllm import envs ...@@ -9,10 +9,6 @@ from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize, create_flashinfer_prepare_finalize,
...@@ -203,33 +199,6 @@ def build_flashinfer_fp8_cutlass_moe_prepare_finalize( ...@@ -203,33 +199,6 @@ def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
) )
def select_cutlass_fp8_gemm_impl(
moe: FusedMoEConfig | None,
quant_config: FusedMoEQuantConfig,
out_dtype: torch.dtype | None = None,
use_deepseek_fp8_block_scale: bool = False,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for fused-MoE layers"""
if moe is not None:
return FlashInferExperts(
out_dtype=moe.in_dtype,
quant_config=quant_config,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
return FlashInferExperts(
out_dtype=out_dtype,
quant_config=quant_config,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
def get_flashinfer_moe_backend() -> FlashinferMoeBackend: def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = { backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS, "throughput": FlashinferMoeBackend.CUTLASS,
......
...@@ -48,6 +48,7 @@ class GroupShape(_GroupShape): ...@@ -48,6 +48,7 @@ class GroupShape(_GroupShape):
# Aliases for common quantization group shapes # Aliases for common quantization group shapes
PER_TENSOR: ClassVar["GroupShape"] PER_TENSOR: ClassVar["GroupShape"]
PER_TOKEN: ClassVar["GroupShape"] PER_TOKEN: ClassVar["GroupShape"]
PER_CHANNEL: ClassVar["GroupShape"]
def is_per_tensor(self) -> bool: def is_per_tensor(self) -> bool:
return self.row == -1 and self.col == -1 return self.row == -1 and self.col == -1
...@@ -55,12 +56,16 @@ class GroupShape(_GroupShape): ...@@ -55,12 +56,16 @@ class GroupShape(_GroupShape):
def is_per_token(self) -> bool: def is_per_token(self) -> bool:
return self.row == 1 and self.col == -1 return self.row == 1 and self.col == -1
def is_per_channel(self) -> bool:
return self.row == -1 and self.col == 1
def is_per_group(self) -> bool: def is_per_group(self) -> bool:
return self.row == 1 and self.col >= 1 return self.row == 1 and self.col >= 1
GroupShape.PER_TENSOR = GroupShape(-1, -1) GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1)
GroupShape.PER_CHANNEL = GroupShape(-1, 1)
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -77,16 +82,12 @@ class ScaleDesc: ...@@ -77,16 +82,12 @@ class ScaleDesc:
group_shape: GroupShape group_shape: GroupShape
def __str__(self): def __str__(self):
group_shape = ( d = {
"per_tensor" GroupShape.PER_TENSOR: "per_tensor",
if self.group_shape == GroupShape.PER_TENSOR GroupShape.PER_TOKEN: "per_token",
else ( GroupShape.PER_CHANNEL: "per_channel",
"per_token" }
if self.group_shape == GroupShape.PER_TOKEN group_shape = d.get(self.group_shape, str(self.group_shape))
else str(self.group_shape)
)
)
return ( return (
f"{fx.graph.dtype_abbrs[self.dtype]}," f"{fx.graph.dtype_abbrs[self.dtype]},"
f"{'static' if self.static else 'dynamic'},{group_shape}" f"{'static' if self.static else 'dynamic'},{group_shape}"
...@@ -126,15 +127,28 @@ kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) ...@@ -126,15 +127,28 @@ kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN) kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN)
kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True) kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True)
kStaticChannelScale = ScaleDesc(torch.float32, True, GroupShape.PER_CHANNEL)
kFp8StaticChannelSym = QuantKey(FP8_DTYPE, kStaticChannelScale, symmetric=True)
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) kNvfp4DynamicGroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale) kNvfp4Dynamic = QuantKey(
FP4_DTYPE, scale=kNvfp4DynamicGroupScale, scale2=kStaticTensorScale
)
kNvfp4StaticGroupScale = ScaleDesc(FP8_DTYPE, True, GroupShape(1, 16))
kNvfp4Static = QuantKey(
FP4_DTYPE, scale=kNvfp4StaticGroupScale, scale2=kStaticTensorScale
)
kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128)) kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True) kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)
kStatic128BlockScale = ScaleDesc(torch.float32, True, GroupShape(128, 128))
kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True)
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64)) kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True) kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
......
...@@ -43,7 +43,6 @@ from vllm.distributed import ( ...@@ -43,7 +43,6 @@ from vllm.distributed import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -172,7 +171,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -172,7 +171,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel, is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
) )
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(
......
...@@ -34,7 +34,6 @@ from vllm.model_executor.layers.fla.ops import ( ...@@ -34,7 +34,6 @@ from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule, fused_recurrent_gated_delta_rule,
) )
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm, GemmaRMSNorm as Qwen3NextRMSNorm,
) )
...@@ -181,7 +180,6 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -181,7 +180,6 @@ class Qwen3NextSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel, is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment