Unverified Commit 42135d68 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

parent e14467be
...@@ -18,7 +18,7 @@ from vllm.logger import init_logger ...@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Dynamic,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -41,7 +41,7 @@ silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr( ...@@ -41,7 +41,7 @@ silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
torch.ops._C, "silu_and_mul_nvfp4_quant" torch.ops._C, "silu_and_mul_nvfp4_quant"
) )
if silu_and_mul_nvfp4_quant_supported: if silu_and_mul_nvfp4_quant_supported:
FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
class ActivationQuantPattern(ABC): class ActivationQuantPattern(ABC):
...@@ -129,7 +129,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): ...@@ -129,7 +129,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
""" """
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(kNvfp4Quant) super().__init__(kNvfp4Dynamic)
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list[torch.Tensor]:
result = self.empty_quant(5, 32) result = self.empty_quant(5, 32)
......
...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTensorSym, kFp8DynamicTensorSym,
kFp8DynamicTokenSym, kFp8DynamicTokenSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Dynamic,
kStaticTensorScale, kStaticTensorScale,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { ...@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
} }
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
if current_platform.is_cuda(): if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
......
...@@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config ...@@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kNvfp4Quant, kNvfp4Dynamic,
kStaticTensorScale, kStaticTensorScale,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -217,7 +217,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -217,7 +217,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
""" """
def __init__(self, layer: Attention, dtype: torch.dtype) -> None: def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
super().__init__(layer, kNvfp4Quant, dtype) super().__init__(layer, kNvfp4Dynamic, dtype)
def _register(self, pm_pass: PatternMatcherPass) -> None: def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
......
...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTensorSym, kFp8DynamicTensorSym,
kFp8DynamicTokenSym, kFp8DynamicTokenSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Dynamic,
) )
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { ...@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
} }
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
if current_platform.is_cuda(): if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
......
...@@ -7,11 +7,20 @@ import torch ...@@ -7,11 +7,20 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
) )
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
...@@ -19,6 +28,7 @@ from vllm.utils.deep_gemm import ( ...@@ -19,6 +28,7 @@ from vllm.utils.deep_gemm import (
fp8_m_grouped_gemm_nt_masked, fp8_m_grouped_gemm_nt_masked,
get_mk_alignment_for_contiguous_layout, get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
) )
from vllm.utils.math_utils import cdiv, round_up from vllm.utils.math_utils import cdiv, round_up
...@@ -253,29 +263,52 @@ def persistent_masked_m_silu_mul_quant( ...@@ -253,29 +263,52 @@ def persistent_masked_m_silu_mul_quant(
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int, max_num_tokens: int,
num_dispatchers: int, num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
): ):
""" """
max_num_tokens: Maximum number of tokens from a DP Rank max_num_tokens: Maximum number of tokens from a DP Rank
num_dispatchers: The number of DP dispatchers. num_dispatchers: The number of DP dispatchers.
quant_config: Quantization configuration quant_config: Quantization configuration
""" """
super().__init__(quant_config) super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
assert self.block_shape == get_mk_alignment_for_contiguous_layout() assert self.block_shape == get_mk_alignment_for_contiguous_layout()
assert self.quant_config.use_fp8_w8a8 assert self.quant_config.use_fp8_w8a8
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.BatchedExperts
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.BatchedExperts, def _supports_current_device() -> bool:
mk.FusedMoEActivationFormat.BatchedExperts, return is_deep_gemm_supported()
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [(kFp8Static128BlockSym, kFp8Dynamic128Sym)]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return False return False
...@@ -310,6 +343,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -310,6 +343,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FIXME (varun): We should be able to dispatch only from the leader # FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks # DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed. # end up sending their tokens. This needs to be fixed.
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
num_dispatchers = self.num_dispatchers num_dispatchers = self.num_dispatchers
num_experts = local_num_experts num_experts = local_num_experts
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
......
...@@ -862,6 +862,7 @@ class FusedMoEParallelConfig: ...@@ -862,6 +862,7 @@ class FusedMoEParallelConfig:
use_ep: bool # whether to use EP or not use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication all2all_backend: str # all2all backend for MoE communication
enable_eplb: bool # whether to enable expert load balancing
@property @property
def use_all2all_kernels(self): def use_all2all_kernels(self):
...@@ -882,6 +883,16 @@ class FusedMoEParallelConfig: ...@@ -882,6 +883,16 @@ class FusedMoEParallelConfig:
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@property
def use_batched_activation_format(self):
return self.use_deepep_ll_kernels or self.use_pplx_kernels
@property
def use_naive_all2all_kernels(self):
return self.use_all2all_kernels and (
self.all2all_backend in ["naive", "allgather_reducescatter"]
)
@staticmethod @staticmethod
def flatten_tp_across_dp_and_pcp( def flatten_tp_across_dp_and_pcp(
tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
...@@ -999,6 +1010,7 @@ class FusedMoEParallelConfig: ...@@ -999,6 +1010,7 @@ class FusedMoEParallelConfig:
ep_rank=0, ep_rank=0,
use_ep=False, use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend, all2all_backend=vllm_parallel_config.all2all_backend,
enable_eplb=vllm_parallel_config.enable_eplb,
) )
# DP + EP / TP + EP / DP + TP + EP # DP + EP / TP + EP / DP + TP + EP
assert use_ep assert use_ep
...@@ -1017,6 +1029,24 @@ class FusedMoEParallelConfig: ...@@ -1017,6 +1029,24 @@ class FusedMoEParallelConfig:
ep_rank=ep_rank, ep_rank=ep_rank,
use_ep=True, use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend, all2all_backend=vllm_parallel_config.all2all_backend,
enable_eplb=vllm_parallel_config.enable_eplb,
)
@classmethod
def make_no_parallel(cls) -> "FusedMoEParallelConfig":
"""For usage in CI/CD and testing."""
return FusedMoEParallelConfig(
tp_size=1,
tp_rank=0,
pcp_size=1,
pcp_rank=0,
dp_size=1,
dp_rank=0,
ep_size=1,
ep_rank=0,
use_ep=False,
all2all_backend="naive",
enable_eplb=False,
) )
...@@ -1026,8 +1056,11 @@ class FusedMoEConfig: ...@@ -1026,8 +1056,11 @@ class FusedMoEConfig:
num_experts: int num_experts: int
experts_per_token: int experts_per_token: int
hidden_dim: int hidden_dim: int
intermediate_size_per_partition: int
num_local_experts: int num_local_experts: int
activation: str
device: torch.device | str
routing_method: RoutingMethodType
moe_parallel_config: FusedMoEParallelConfig moe_parallel_config: FusedMoEParallelConfig
# The activation type. # The activation type.
......
...@@ -7,7 +7,11 @@ import torch ...@@ -7,7 +7,11 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_permute,
moe_unpermute, moe_unpermute,
...@@ -23,6 +27,19 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -23,6 +27,19 @@ from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
apply_moe_activation, apply_moe_activation,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_group_gemm_supported,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -238,29 +255,57 @@ def run_cutlass_moe_fp8( ...@@ -238,29 +255,57 @@ def run_cutlass_moe_fp8(
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
e: int, moe_config: FusedMoEConfig,
n: int,
k: int,
out_dtype: torch.dtype | None,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
device: torch.dtype, max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
): ):
super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
assert quant_config.use_fp8_w8a8 assert quant_config.use_fp8_w8a8
super().__init__(quant_config)
# E: num_experts e = moe_config.num_local_experts
# N: intermediate size per partition n = moe_config.intermediate_size_per_partition
# K: hidden dim k = moe_config.hidden_dim
device = moe_config.device
ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64) ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64) ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64) c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
self.out_dtype = out_dtype self.out_dtype = moe_config.in_dtype
self.ab_strides1 = ab_strides1_c_strides2 self.ab_strides1 = ab_strides1_c_strides2
self.ab_strides2 = ab_strides2 self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1 self.c_strides1 = c_strides1
self.c_strides2 = ab_strides1_c_strides2 self.c_strides2 = ab_strides1_c_strides2
@staticmethod
def _supports_current_device() -> bool:
return cutlass_group_gemm_supported()
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl. # Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate() return TopKWeightAndReduceDelegate()
...@@ -291,7 +336,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -291,7 +336,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens = expert_tokens_meta.expert_num_tokens expert_num_tokens = expert_tokens_meta.expert_num_tokens
use_batched_format = ( use_batched_format = (
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts
) )
in_dtype = hidden_states.dtype in_dtype = hidden_states.dtype
...@@ -324,20 +369,23 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -324,20 +369,23 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp8(CutlassExpertsFp8Base): class CutlassExpertsFp8(CutlassExpertsFp8Base):
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.Standard
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.Standard, def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
mk.FusedMoEActivationFormat.Standard, # CutlassExpertsFp8 does not support expert map, which is
) # needed for STANDARD activation format kernels in DP/EP mode.
# Note that the BATCHED activation format does not use
# the expert map for identifying experts.
return not moe_parallel_config.use_all2all_kernels
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# topk weights and reduction are fused in moe_unpermute cuda kernel # topk weights and reduction are fused in moe_unpermute cuda kernel
...@@ -365,26 +413,16 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): ...@@ -365,26 +413,16 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def __init__( @staticmethod
self, def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
max_experts_per_worker: int, # BATCHED activation format works with EP because
num_dispatchers: int, # expert_map is not used to identify experts (the
*args, # info is encoded/managed by the P/F logic).
**kwargs, return True
):
super().__init__(*args, **kwargs)
assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.BatchedExperts
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return False return False
...@@ -408,14 +446,15 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): ...@@ -408,14 +446,15 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers num_dp = self.num_dispatchers
assert num_dp is not None assert num_dp is not None
experts_per_worker = self.moe_config.num_local_experts
activation_out_dim = self.adjust_N_for_activation(N, activation) activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K)) workspace1 = (experts_per_worker, M * num_dp, max(N, K))
workspace2 = ( workspace2 = (
self.max_experts_per_worker, experts_per_worker,
M * num_dp, M * num_dp,
max(activation_out_dim, K), max(activation_out_dim, K),
) )
output = (self.max_experts_per_worker, M, K) output = (experts_per_worker, M, K)
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
...@@ -601,34 +640,41 @@ def run_cutlass_moe_fp4( ...@@ -601,34 +640,41 @@ def run_cutlass_moe_fp4(
return return
# Split into batched and non-batched
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( @staticmethod
self, def expects_unquantized_inputs(
max_experts_per_worker: int, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
out_dtype: torch.dtype, ) -> bool:
quant_config: FusedMoEQuantConfig, return True
use_batched_format: bool = False,
):
super().__init__(quant_config)
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
@property @staticmethod
def activation_formats( def _supports_current_device() -> bool:
self, return current_platform.has_device_capability((10, 0))
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
if self.use_batched_format: @staticmethod
return ( def _supports_no_act_and_mul() -> bool:
mk.FusedMoEActivationFormat.BatchedExperts, return False
mk.FusedMoEActivationFormat.BatchedExperts,
) @staticmethod
else: def _supports_quant_scheme(
return ( weight_key: QuantKey | None,
mk.FusedMoEActivationFormat.Standard, activation_key: QuantKey | None,
mk.FusedMoEActivationFormat.Standard, ) -> bool:
) return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic)
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# CutlassExpertsFp4 does not support expert map, which is
# needed for STANDARD activation format kernels in EP mode.
return moe_parallel_config.ep_size == 1
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
...@@ -640,7 +686,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -640,7 +686,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return TopKWeightAndReduceNoOP() return TopKWeightAndReduceNoOP()
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return self.out_dtype if self.out_dtype is not None else act_dtype return act_dtype
def workspace_shapes( def workspace_shapes(
self, self,
...@@ -653,18 +699,9 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -653,18 +699,9 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation) workspace1 = (M * topk, max(2 * N, K))
workspace1: tuple[int, ...] = () workspace2 = (M * topk, N)
workspace2: tuple[int, ...] = () output = (M, K)
output: tuple[int, ...] = ()
if self.use_batched_format:
workspace1 = (self.max_experts_per_worker, M, max(N, K))
workspace2 = (self.max_experts_per_worker, M, activation_out_dim)
output = (self.max_experts_per_worker, M, K)
else:
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
output = (M, K)
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
def apply( def apply(
...@@ -869,10 +906,11 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -869,10 +906,11 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
c_strides2: torch.Tensor, c_strides2: torch.Tensor,
s_strides1: torch.Tensor, s_strides1: torch.Tensor,
s_strides2: torch.Tensor, s_strides2: torch.Tensor,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
group_size: int, group_size: int,
): ):
super().__init__(quant_config) super().__init__(moe_config=moe_config, quant_config=quant_config)
self.out_dtype = out_dtype self.out_dtype = out_dtype
self.a_strides1 = a_strides1 self.a_strides1 = a_strides1
self.a_strides2 = a_strides2 self.a_strides2 = a_strides2
...@@ -884,13 +922,46 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -884,13 +922,46 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
self.s_strides2 = s_strides2 self.s_strides2 = s_strides2
self.group_size = group_size self.group_size = group_size
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.Standard
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.Standard, def _supports_current_device() -> bool:
mk.FusedMoEActivationFormat.Standard, raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_activation(activation: str) -> bool:
raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
) )
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
...@@ -947,7 +1018,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -947,7 +1018,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens = None expert_num_tokens = None
use_batched_format = ( use_batched_format = (
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts
) )
assert not use_batched_format, "batched format not supported" assert not use_batched_format, "batched format not supported"
...@@ -1003,6 +1074,7 @@ def cutlass_moe_w4a8_fp8( ...@@ -1003,6 +1074,7 @@ def cutlass_moe_w4a8_fp8(
s_strides1: torch.Tensor, s_strides1: torch.Tensor,
s_strides2: torch.Tensor, s_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
activation: str = "silu", activation: str = "silu",
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
...@@ -1076,6 +1148,7 @@ def cutlass_moe_w4a8_fp8( ...@@ -1076,6 +1148,7 @@ def cutlass_moe_w4a8_fp8(
c_strides2=c_strides2, c_strides2=c_strides2,
s_strides1=s_strides1, s_strides1=s_strides1,
s_strides2=s_strides2, s_strides2=s_strides2,
moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
group_size=group_size, group_size=group_size,
), ),
......
...@@ -6,17 +6,15 @@ import torch ...@@ -6,17 +6,15 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M, compute_aligned_M,
deepgemm_moe_permute, deepgemm_moe_permute,
deepgemm_unpermute_and_reduce, deepgemm_unpermute_and_reduce,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
...@@ -26,9 +24,15 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -26,9 +24,15 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8_packed_for_deepgemm, per_token_group_quant_fp8_packed_for_deepgemm,
silu_mul_per_token_group_quant_fp8_colmajor, silu_mul_per_token_group_quant_fp8_colmajor,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
)
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT, DeepGemmQuantScaleFMT,
get_mk_alignment_for_contiguous_layout, get_mk_alignment_for_contiguous_layout,
is_deep_gemm_supported,
m_grouped_fp8_gemm_nt_contiguous, m_grouped_fp8_gemm_nt_contiguous,
) )
from vllm.utils.import_utils import has_deep_gemm from vllm.utils.import_utils import has_deep_gemm
...@@ -109,21 +113,42 @@ def _valid_deep_gemm( ...@@ -109,21 +113,42 @@ def _valid_deep_gemm(
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig): def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config) super().__init__(moe_config=moe_config, quant_config=quant_config)
assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout() assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
assert quant_config.quant_dtype == torch.float8_e4m3fn assert quant_config.quant_dtype == torch.float8_e4m3fn
assert not quant_config.per_act_token_quant assert not quant_config.per_act_token_quant
assert not quant_config.per_out_ch_quant assert not quant_config.per_out_ch_quant
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.Standard
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.Standard, def _supports_current_device() -> bool:
mk.FusedMoEActivationFormat.Standard, return is_deep_gemm_supported()
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
...@@ -283,82 +308,3 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -283,82 +308,3 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_map=expert_map, expert_map=expert_map,
output=output, output=output,
) )
def deep_gemm_moe_fp8(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with DeepGemm
grouped gemm.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1 (torch.Tensor): The first set of fp8 quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2 (torch.Tensor): The second set of fp8 quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=get_mk_alignment_for_contiguous_layout(),
)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
DeepGemmExperts(quant_config),
)
return fn(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
...@@ -6,6 +6,8 @@ from abc import ABC, abstractmethod ...@@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
...@@ -16,18 +18,78 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): ...@@ -16,18 +18,78 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
experts: mk.FusedMoEPermuteExpertsUnpermute, experts: mk.FusedMoEPermuteExpertsUnpermute,
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute, fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
): ):
super().__init__(experts.quant_config) super().__init__(
moe_config=experts.moe_config, quant_config=experts.quant_config
)
self.fallback_experts = fallback_experts self.fallback_experts = fallback_experts
self.experts = experts self.experts = experts
@property @staticmethod
def activation_formats( def get_clses() -> tuple[
self, type[mk.FusedMoEPermuteExpertsUnpermute],
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: type[mk.FusedMoEPermuteExpertsUnpermute],
assert ( ]:
self.fallback_experts.activation_formats == self.experts.activation_formats """
Get the cls for the experts and fallback experts.
Subclasses should implement this method, so that
we have a consistent way to call the _supports_*
class methods below.
"""
raise NotImplementedError(
"Subclasses must return the cls for the experts and fallback experts."
)
@classmethod
def activation_format(
cls: type["FallbackExperts"],
) -> mk.FusedMoEActivationFormat:
experts_cls, fallback_cls = cls.get_clses()
assert experts_cls.activation_format() == fallback_cls.activation_format()
return experts_cls.activation_format()
@classmethod
def _supports_current_device(cls) -> bool:
experts_cls, fallback_cls = cls.get_clses()
return (
experts_cls._supports_current_device()
and fallback_cls._supports_current_device()
)
@classmethod
def _supports_no_act_and_mul(cls) -> bool:
experts_cls, fallback_cls = cls.get_clses()
return (
experts_cls._supports_no_act_and_mul()
and fallback_cls._supports_no_act_and_mul()
) )
return self.fallback_experts.activation_formats
@classmethod
def _supports_quant_scheme(
cls,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
experts_cls, fallback_cls = cls.get_clses()
return experts_cls._supports_quant_scheme(
weight_key, activation_key
) and fallback_cls._supports_quant_scheme(weight_key, activation_key)
@classmethod
def _supports_activation(cls, activation: str) -> bool:
experts_cls, fallback_cls = cls.get_clses()
return experts_cls._supports_activation(
activation
) and fallback_cls._supports_activation(activation)
@classmethod
def _supports_parallel_config(
cls, moe_parallel_config: FusedMoEParallelConfig
) -> bool:
experts_cls, fallback_cls = cls.get_clses()
return experts_cls._supports_parallel_config(
moe_parallel_config
) and fallback_cls._supports_parallel_config(moe_parallel_config)
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
assert ( assert (
......
...@@ -6,13 +6,22 @@ import torch ...@@ -6,13 +6,22 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import ( from vllm.utils.flashinfer import (
flashinfer_cutedsl_grouped_gemm_nt_masked, flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
scaled_fp4_grouped_quantize, scaled_fp4_grouped_quantize,
silu_and_mul_scaled_nvfp4_experts_quantize, silu_and_mul_scaled_nvfp4_experts_quantize,
) )
...@@ -20,54 +29,54 @@ from vllm.utils.flashinfer import ( ...@@ -20,54 +29,54 @@ from vllm.utils.flashinfer import (
logger = init_logger(__name__) logger = init_logger(__name__)
def is_valid_flashinfer_cutedsl_fused_moe(
hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
) -> bool:
"""
Check if the given problem size is supported by the FlashInfer CuteDSL MoE
kernel.
"""
if not has_flashinfer_cutedsl_grouped_gemm_nt_masked():
logger.debug_once(
"FlashInferCuteDSLExperts disabled: "
"flashinfer_cutedsl_fused_moe not available."
)
return False
# Data type checks
if (
w1.dtype != torch.uint8
or w2.dtype != torch.uint8
or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16]
):
logger.debug_once(
"FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 "
f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
f"float32, float16, or bfloat16 (got {hidden_states.dtype})."
)
return False
return True
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
out_dtype: torch.dtype, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
): ):
super().__init__(quant_config) super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
assert quant_config.quant_dtype == "nvfp4", ( assert quant_config.quant_dtype == "nvfp4", (
"Only nvfp4 quantization are currently supported." "Only nvfp4 quantization are currently supported."
) )
self.out_dtype = out_dtype self.out_dtype = moe_config.in_dtype
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.BatchedExperts
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.BatchedExperts, def _supports_current_device() -> bool:
mk.FusedMoEActivationFormat.BatchedExperts, return current_platform.is_device_capability_family(100)
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kNvfp4Static, kNvfp4Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
......
...@@ -5,13 +5,22 @@ import torch ...@@ -5,13 +5,22 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 FusedMoEParallelConfig,
create_flashinfer_prepare_finalize, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import ( from vllm.utils.flashinfer import (
flashinfer_cutlass_fused_moe, flashinfer_cutlass_fused_moe,
has_flashinfer_cutlass_fused_moe, has_flashinfer_cutlass_fused_moe,
...@@ -50,40 +59,100 @@ def is_valid_flashinfer_cutlass_fused_moe( ...@@ -50,40 +59,100 @@ def is_valid_flashinfer_cutlass_fused_moe(
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
out_dtype: torch.dtype, moe_config: mk.FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
ep_rank: int = 0,
ep_size: int = 1,
tp_rank: int = 0,
tp_size: int = 1,
use_dp: bool = False,
use_deepseek_fp8_block_scale: bool = False,
): ):
super().__init__(quant_config) super().__init__(moe_config, quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), ( assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
"Only nvfp4, fp8, bfloat16 and" "Only nvfp4, fp8, bfloat16 and"
" float16 quantization are currently supported." " float16 quantization are currently supported."
) )
self.ep_rank = ep_rank self.ep_rank = moe_config.moe_parallel_config.ep_rank
self.ep_size = ep_size self.ep_size = moe_config.moe_parallel_config.ep_size
self.tp_rank = tp_rank self.tp_rank = moe_config.moe_parallel_config.tp_rank
self.tp_size = tp_size self.tp_size = moe_config.moe_parallel_config.tp_size
self.out_dtype = out_dtype self.out_dtype = moe_config.in_dtype
self.use_dp = use_dp self.use_dp = moe_config.moe_parallel_config.dp_size > 1
# Enables DeepSeek-style FP8 block-scale path: # Enables DeepSeek-style FP8 block-scale path:
# - pass per-block weight scales to the kernel # - pass per-block weight scales to the kernel
# - skip input activation quantization (kernel applies scaling) # - skip input activation quantization (kernel applies scaling)
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
@property @staticmethod
def activation_formats( def expects_unquantized_inputs(
self, moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: ) -> bool:
# NVFP4 TP kernels and FP8 block-quantized kernels apply
# input quantization inside FusedMoEPermuteExpertsUnpermute.
return (
quant_config.use_nvfp4_w4a4
and not moe_config.moe_parallel_config.use_all2all_kernels
) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized)
@staticmethod
def _supports_current_device() -> bool:
return ( return (
mk.FusedMoEActivationFormat.Standard, current_platform.is_cuda()
mk.FusedMoEActivationFormat.Standard, and (
current_platform.is_device_capability((9, 0))
or current_platform.is_device_capability_family(100)
)
and has_flashinfer_cutlass_fused_moe()
) )
@staticmethod
def _supports_no_act_and_mul() -> bool:
return True
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# The following are supported by FlashInferExperts:
# * unquantized
# * fp8 static per-tensor on 9.0+
# * fp8 block on 9.0
# * nvfp4 on 10.0+
p = current_platform
scheme = (weight_key, activation_key)
return (
(
scheme
in [
(None, None),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
)
or (
(scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym))
and (p.is_device_capability((9, 0)))
)
or (
(scheme == (kNvfp4Static, kNvfp4Dynamic))
and (p.is_device_capability_family(100))
)
)
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "relu2_no_mul"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# FLASHINFER_CUTLASS currently uses its down P/F, which does not
# work with SP. This will be removed in follow up after we get
# rid of the FlashInfer specific P/F function.
return (
moe_parallel_config.dp_size == 1
or moe_parallel_config.dp_size == moe_parallel_config.ep_size
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
...@@ -231,85 +300,3 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -231,85 +300,3 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# No support for LoRA in flashinfer_cutlass_fused_moe. # No support for LoRA in flashinfer_cutlass_fused_moe.
# See TODOs in flashinfer functions runMoe and runMoeMinLantency. # See TODOs in flashinfer functions runMoe and runMoeMinLantency.
raise NotImplementedError("LoRA is not supported for flashinfer_cutlass_moe") raise NotImplementedError("LoRA is not supported for flashinfer_cutlass_moe")
def flashinfer_cutlass_moe_fp4(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
create_flashinfer_prepare_finalize(
use_dp=False, use_nvfp4=True, enable_alltoallv=False
),
FlashInferExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
use_dp=False,
),
)
return fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
def flashinfer_cutlass_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
tp_rank: int = 0,
tp_size: int = 1,
ep_rank: int = 0,
ep_size: int = 1,
use_dp: bool = False,
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
create_flashinfer_prepare_finalize(use_dp=use_dp),
FlashInferExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
tp_rank=tp_rank,
tp_size=tp_size,
ep_rank=ep_rank,
ep_size=ep_size,
),
)
return fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
...@@ -3,7 +3,12 @@ ...@@ -3,7 +3,12 @@
import torch import torch
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim, calculate_tile_tokens_dim,
...@@ -11,8 +16,107 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( ...@@ -11,8 +16,107 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
#
# Methods used by the oracle for kernel selection.
#
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)
def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nanotron-Mini)."""
return False
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 per-tensor and Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: str) -> bool:
"""Supports silu activation only."""
return activation in ["silu"]
def _supports_routing_method(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
routing_method: RoutingMethodType,
) -> bool:
"""Monolithic kernels need to express router support."""
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(rob): kernel requires Llama4.
return routing_method == RoutingMethodType.Llama4
else:
raise ValueError("Unsupported quantization scheme.")
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Supports TRTLLM Kernel does not support EPLB."""
return not moe_parallel_config.enable_eplb
def is_supported_config_trtllm(
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
"""
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not _supports_current_device():
return False, _make_reason("current device")
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not _supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not _supports_quant_scheme(weight_key, activation_key):
return False, _make_reason("quantization scheme")
elif not _supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason("parallel config")
elif not _supports_routing_method(
weight_key, activation_key, moe_config.routing_method
):
return False, _make_reason("routing method")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason("activation format")
return True, None
def flashinfer_fused_moe_blockscale_fp8( def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor, routing_logits: torch.Tensor,
......
...@@ -5,7 +5,11 @@ ...@@ -5,7 +5,11 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
...@@ -17,7 +21,17 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -17,7 +21,17 @@ from vllm.model_executor.layers.fused_moe.utils import (
normalize_batched_scales_shape, normalize_batched_scales_shape,
normalize_scales_shape, normalize_scales_shape,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
group_broadcast,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -633,25 +647,62 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -633,25 +647,62 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int, max_num_tokens: int,
num_dispatchers: int, num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
): ):
super().__init__(quant_config) super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
assert not self.quant_config.use_int8_w8a8, "NYI" assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI" assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI" assert not self.quant_config.use_int4_w4a16, "NYI"
assert self.quant_config.ocp_mx_scheme is None, "NYI" assert self.quant_config.ocp_mx_scheme is None, "NYI"
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.BatchedExperts
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.BatchedExperts, def _supports_current_device() -> bool:
mk.FusedMoEActivationFormat.BatchedExperts, raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_activation(activation: str) -> bool:
raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called."
) )
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
...@@ -675,6 +726,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -675,6 +726,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
num_dp = self.num_dispatchers num_dp = self.num_dispatchers
num_experts = local_num_experts num_experts = local_num_experts
workspace13 = (num_experts, self.max_num_tokens * num_dp, K) workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
...@@ -826,29 +879,69 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -826,29 +879,69 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int, max_num_tokens: int,
num_dispatchers: int, num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
): ):
super().__init__(quant_config) super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
assert not self.quant_config.use_int8_w8a8, "NYI" assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI" assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI" assert not self.quant_config.use_int4_w4a16, "NYI"
assert self.quant_config.ocp_mx_scheme is None, "NYI" assert self.quant_config.ocp_mx_scheme is None, "NYI"
assert max_num_tokens > 0
assert num_dispatchers > 0
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.BatchedExperts
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.BatchedExperts, def _supports_current_device() -> bool:
mk.FusedMoEActivationFormat.BatchedExperts, return current_platform.is_cuda_alike()
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
p = current_platform
device_supports_fp8 = (p.is_rocm() and p.rocm.on_gfx9()) or (
p.is_cuda() and p.has_device_capability((8, 9))
)
SUPPORTED_W_A_FP8 = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
]
return (weight_key, activation_key) == (None, None) or (
device_supports_fp8 and (weight_key, activation_key) in SUPPORTED_W_A_FP8
) )
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in [
"silu",
"gelu",
"swigluoai",
"silu_no_mul",
"gelu_no_mul",
"relu2_no_mul",
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return False return False
...@@ -870,6 +963,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -870,6 +963,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
num_dp = self.num_dispatchers num_dp = self.num_dispatchers
num_experts = local_num_experts num_experts = local_num_experts
max_num_tokens = self.max_num_tokens max_num_tokens = self.max_num_tokens
......
...@@ -8,7 +8,11 @@ import torch ...@@ -8,7 +8,11 @@ import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size, batched_moe_align_block_size,
moe_align_block_size, moe_align_block_size,
...@@ -27,6 +31,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -27,6 +31,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_moe_intermediate_size, marlin_moe_intermediate_size,
marlin_quant_input, marlin_quant_input,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kNvfp4Static,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
...@@ -522,7 +533,10 @@ def batched_fused_marlin_moe( ...@@ -522,7 +533,10 @@ def batched_fused_marlin_moe(
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
w13_g_idx: torch.Tensor | None = None, w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None, w2_g_idx: torch.Tensor | None = None,
w13_g_idx_sort_indices: torch.Tensor | None = None, w13_g_idx_sort_indices: torch.Tensor | None = None,
...@@ -541,7 +555,51 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -541,7 +555,51 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
self.is_k_full = is_k_full self.is_k_full = is_k_full
super().__init__(quant_config) super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
@staticmethod
def _supports_current_device() -> bool:
p = current_platform
return p.is_cuda() and p.has_device_capability((7, 5))
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# TODO(rob): add int4, mxfp4, int8 as integrations
# are migrated to use the oracle one-by-one.
SUPPORTED_W = [
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kNvfp4Static,
]
return weight_key in SUPPORTED_W
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in [
"silu",
"gelu",
"swigluoai",
"silu_no_mul",
"gelu_no_mul",
"relu2_no_mul",
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
@property @property
def quant_type_id(self) -> int: def quant_type_id(self) -> int:
...@@ -587,38 +645,15 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -587,38 +645,15 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
class MarlinExperts(MarlinExpertsBase): class MarlinExperts(MarlinExpertsBase):
def __init__(
self,
quant_config: FusedMoEQuantConfig,
w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None,
w13_g_idx_sort_indices: torch.Tensor | None = None,
w2_g_idx_sort_indices: torch.Tensor | None = None,
is_k_full: bool = True,
):
super().__init__(
quant_config,
w13_g_idx,
w2_g_idx,
w13_g_idx_sort_indices,
w2_g_idx_sort_indices,
is_k_full,
)
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP() return TopKWeightAndReduceNoOP()
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.Standard
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
...@@ -714,9 +749,10 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -714,9 +749,10 @@ class MarlinExperts(MarlinExpertsBase):
class BatchedMarlinExperts(MarlinExpertsBase): class BatchedMarlinExperts(MarlinExpertsBase):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int, max_num_tokens: int,
num_dispatchers: int, num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
w13_g_idx: torch.Tensor | None = None, w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None, w2_g_idx: torch.Tensor | None = None,
w13_g_idx_sort_indices: torch.Tensor | None = None, w13_g_idx_sort_indices: torch.Tensor | None = None,
...@@ -724,15 +760,16 @@ class BatchedMarlinExperts(MarlinExpertsBase): ...@@ -724,15 +760,16 @@ class BatchedMarlinExperts(MarlinExpertsBase):
is_k_full: bool = True, is_k_full: bool = True,
): ):
super().__init__( super().__init__(
quant_config, moe_config=moe_config,
w13_g_idx, quant_config=quant_config,
w2_g_idx, max_num_tokens=max_num_tokens,
w13_g_idx_sort_indices, num_dispatchers=num_dispatchers,
w2_g_idx_sort_indices, w13_g_idx=w13_g_idx,
is_k_full, w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=is_k_full,
) )
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
...@@ -740,14 +777,9 @@ class BatchedMarlinExperts(MarlinExpertsBase): ...@@ -740,14 +777,9 @@ class BatchedMarlinExperts(MarlinExpertsBase):
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceDelegate() return TopKWeightAndReduceDelegate()
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.BatchedExperts
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return False return False
...@@ -763,9 +795,11 @@ class BatchedMarlinExperts(MarlinExpertsBase): ...@@ -763,9 +795,11 @@ class BatchedMarlinExperts(MarlinExpertsBase):
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str, activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
num_dispatchers = self.num_dispatchers num_dispatchers = self.num_dispatchers
num_experts = local_num_experts num_experts = local_num_experts
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens max_num_tokens = self.max_num_tokens
workspace13 = (num_experts * max_num_tokens * num_dispatchers, max(K, N * 2)) workspace13 = (num_experts * max_num_tokens * num_dispatchers, max(K, N * 2))
workspace2 = (num_experts * max_num_tokens * num_dispatchers, N) workspace2 = (num_experts * max_num_tokens * num_dispatchers, N)
output = (num_experts, max_num_tokens * num_dispatchers, K) output = (num_experts, max_num_tokens * num_dispatchers, K)
......
...@@ -19,13 +19,11 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -19,13 +19,11 @@ from vllm.model_executor.layers.batch_invariant import (
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
_get_config_dtype_str, _get_config_dtype_str,
) )
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm,
deep_gemm_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size, moe_align_block_size,
) )
...@@ -44,9 +42,16 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -44,9 +42,16 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1534,66 +1539,36 @@ def fused_experts( ...@@ -1534,66 +1539,36 @@ def fused_experts(
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
quant_config: FusedMoEQuantConfig | None = None, quant_config: FusedMoEQuantConfig | None = None,
allow_deep_gemm: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if quant_config is None: if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
# For now, disable DeepGemm for small N (<= 512) until better return dispatch_fused_experts_func(inplace)(
# permute/unpermute ops are available. hidden_states=hidden_states,
# However, on B200, we use DeepGemm for all cases because they only support w1=w1,
# E8M0 scale, which means we requantize the weight and input to the specific w2=w2,
# scale. Fallen back to cutlass or triton for some cases would cause topk_weights=topk_weights,
# accuracy issue. topk_ids=topk_ids,
if ( activation=activation,
allow_deep_gemm apply_router_weight_on_input=apply_router_weight_on_input,
and quant_config.use_fp8_w8a8 use_fp8_w8a8=quant_config.use_fp8_w8a8,
and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2)) use_int8_w8a8=quant_config.use_int8_w8a8,
): use_int8_w8a16=quant_config.use_int8_w8a16,
assert quant_config is not None use_int4_w4a16=quant_config.use_int4_w4a16,
return deep_gemm_moe_fp8( ocp_mx_scheme=quant_config.ocp_mx_scheme,
hidden_states=hidden_states, per_channel_quant=quant_config.per_act_token_quant,
w1=w1, global_num_experts=global_num_experts,
w2=w2, expert_map=expert_map,
topk_weights=topk_weights, w1_scale=quant_config.w1_scale,
topk_ids=topk_ids, w2_scale=quant_config.w2_scale,
inplace=inplace, w1_zp=quant_config.w1_zp,
activation=activation, w2_zp=quant_config.w2_zp,
global_num_experts=global_num_experts, a1_scale=quant_config.a1_scale,
expert_map=expert_map, a2_scale=quant_config.a2_scale,
w1_scale=quant_config.w1_scale, block_shape=quant_config.block_shape,
w2_scale=quant_config.w2_scale, w1_bias=quant_config.w1_bias,
a1_scale=quant_config.a1_scale, w2_bias=quant_config.w2_bias,
a2_scale=quant_config.a2_scale, )
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
return dispatch_fused_experts_func(inplace)(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=quant_config.use_fp8_w8a8,
use_int8_w8a8=quant_config.use_int8_w8a8,
use_int8_w8a16=quant_config.use_int8_w8a16,
use_int4_w4a16=quant_config.use_int4_w4a16,
ocp_mx_scheme=quant_config.ocp_mx_scheme,
per_channel_quant=quant_config.per_act_token_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
w1_zp=quant_config.w1_zp,
w2_zp=quant_config.w2_zp,
a1_scale=quant_config.a1_scale,
a2_scale=quant_config.a2_scale,
block_shape=quant_config.block_shape,
w1_bias=quant_config.w1_bias,
w2_bias=quant_config.w2_bias,
)
def _get_config_quant_dtype( def _get_config_quant_dtype(
...@@ -1924,19 +1899,53 @@ def fused_experts_impl( ...@@ -1924,19 +1899,53 @@ def fused_experts_impl(
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
): ):
super().__init__(quant_config) super().__init__(moe_config, quant_config)
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.Standard
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.Standard, def _supports_current_device() -> bool:
mk.FusedMoEActivationFormat.Standard, return current_platform.is_cuda_alike()
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
p = current_platform
device_supports_fp8 = (p.is_rocm() and p.rocm.on_gfx9()) or (
p.is_cuda() and p.has_device_capability((8, 9))
) )
if not device_supports_fp8:
return (weight_key, activation_key) == (None, None)
SUPPORTED_W_A = [
(None, None),
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
...@@ -2111,11 +2120,43 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2111,11 +2120,43 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
class TritonWNA16Experts(TritonExperts): class TritonWNA16Experts(TritonExperts):
def __init__( @staticmethod
self, def _supports_current_device() -> bool:
quant_config: FusedMoEQuantConfig, raise NotImplementedError(
): "TritonWNA16Experts is not yet used by an Oracle. "
super().__init__(quant_config) "This method should not be called."
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_activation(activation: str) -> bool:
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
def apply( def apply(
self, self,
...@@ -2254,10 +2295,12 @@ class TritonWNA16Experts(TritonExperts): ...@@ -2254,10 +2295,12 @@ class TritonWNA16Experts(TritonExperts):
def modular_triton_fused_moe( def modular_triton_fused_moe(
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel( return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
TritonExperts(quant_config), TritonExperts(moe_config, quant_config),
shared_experts, shared_experts,
) )
...@@ -9,12 +9,16 @@ from vllm import _custom_ops as ops ...@@ -9,12 +9,16 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
...@@ -241,8 +245,43 @@ def make_routing_data( ...@@ -241,8 +245,43 @@ def make_routing_data(
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig): @staticmethod
super().__init__(quant_config) def _supports_current_device() -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_activation(activation: str) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
...@@ -297,19 +336,9 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -297,19 +336,9 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
class OAITritonExperts(BaseOAITritonExperts): class OAITritonExperts(BaseOAITritonExperts):
def __init__(self, quant_config: FusedMoEQuantConfig): @staticmethod
# TODO (varun) : Enable activation quantization def activation_format() -> mk.FusedMoEActivationFormat:
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" return mk.FusedMoEActivationFormat.Standard
super().__init__(quant_config)
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
...@@ -391,19 +420,9 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): ...@@ -391,19 +420,9 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
One use case for it is to inject LoRA modules on the activation and moe_sum. One use case for it is to inject LoRA modules on the activation and moe_sum.
""" """
def __init__(self, quant_config: FusedMoEQuantConfig): @staticmethod
# TODO (varun) : Enable activation quantization def activation_format() -> mk.FusedMoEActivationFormat:
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" return mk.FusedMoEActivationFormat.Standard
super().__init__(quant_config)
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
......
...@@ -330,7 +330,6 @@ class FusedMoE(CustomOp): ...@@ -330,7 +330,6 @@ class FusedMoE(CustomOp):
is_sequence_parallel=False, is_sequence_parallel=False,
expert_mapping: list[tuple[str, str, int, str]] | None = None, expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None, n_shared_experts: int | None = None,
routing_method_type: RoutingMethodType | None = None,
router_logits_dtype: torch.dtype | None = None, router_logits_dtype: torch.dtype | None = None,
): ):
super().__init__() super().__init__()
...@@ -519,10 +518,43 @@ class FusedMoE(CustomOp): ...@@ -519,10 +518,43 @@ class FusedMoE(CustomOp):
self.apply_router_weight_on_input = apply_router_weight_on_input self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation self.activation = activation
# TODO(bnell): in next PR move capture back to layer
capture: Callable[[torch.Tensor], None] | None = None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids)
self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,
eplb_state=self.eplb_state,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=self.num_fused_shared_experts,
enable_eplb=enable_eplb,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
capture=capture,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
self.moe_config: FusedMoEConfig = FusedMoEConfig( self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts, num_experts=self.global_num_experts,
experts_per_token=top_k, experts_per_token=top_k,
hidden_dim=hidden_size, hidden_dim=hidden_size,
intermediate_size_per_partition=self.intermediate_size_per_partition,
num_local_experts=self.local_num_experts, num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config, moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype, in_dtype=moe_in_dtype,
...@@ -531,6 +563,9 @@ class FusedMoE(CustomOp): ...@@ -531,6 +563,9 @@ class FusedMoE(CustomOp):
has_bias=has_bias, has_bias=has_bias,
is_act_and_mul=is_act_and_mul, is_act_and_mul=is_act_and_mul,
is_lora_enabled=vllm_config.lora_config is not None, is_lora_enabled=vllm_config.lora_config is not None,
activation=activation,
device=vllm_config.device_config.device,
routing_method=self.routing_method_type,
) )
self.moe_config_use_flashinfer_cutlass_kernels = ( self.moe_config_use_flashinfer_cutlass_kernels = (
self.moe_config.use_flashinfer_cutlass_kernels self.moe_config.use_flashinfer_cutlass_kernels
...@@ -594,39 +629,6 @@ class FusedMoE(CustomOp): ...@@ -594,39 +629,6 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: torch.Tensor | None = None self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None
# TODO(bnell): in next PR move capture back to layer
capture: Callable[[torch.Tensor], None] | None = None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids)
self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,
eplb_state=self.eplb_state,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=self.num_fused_shared_experts,
enable_eplb=enable_eplb,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
routing_method_type=routing_method_type,
capture=capture,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
# Note: maybe_init_modular_kernel should only be called by # Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model. # prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it # This is called after all weight loading and post-processing, so it
......
...@@ -13,6 +13,7 @@ import vllm.envs as envs ...@@ -13,6 +13,7 @@ import vllm.envs as envs
from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
...@@ -22,6 +23,9 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -22,6 +23,9 @@ from vllm.model_executor.layers.fused_moe.utils import (
count_expert_num_tokens, count_expert_num_tokens,
disable_inplace, disable_inplace,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import ( from vllm.v1.worker.ubatching import (
dbo_enabled, dbo_enabled,
...@@ -374,18 +378,51 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -374,18 +378,51 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
): ):
""" """
moe_config: MoE layer configuration.
quant_config: Quantization parameters for this experts instance. quant_config: Quantization parameters for this experts instance.
""" """
if self.activation_format() == FusedMoEActivationFormat.Standard and (
max_num_tokens is not None or num_dispatchers is not None
):
raise ValueError(
"max_num_tokens and num_dispatchers should only be set for "
"BatchedExperts activation format."
)
elif self.activation_format() == FusedMoEActivationFormat.BatchedExperts and (
max_num_tokens is None or num_dispatchers is None
):
raise ValueError(
"max_num_tokens and num_dispatchers must be set for "
"BatchedExperts activation format."
)
self.moe_config = moe_config
self.quant_config = quant_config self.quant_config = quant_config
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@property @staticmethod
def expects_unquantized_inputs(
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
"""
Whether or not the PrepareFinalize should defer input quantization
in the prepare step. If True, then the Experts kernel will
execute the input quantization itself.
Sample subclasses that override are AITER and FlashInfer CUTLASS.
"""
return False
@staticmethod
@abstractmethod @abstractmethod
def activation_formats( def activation_format() -> FusedMoEActivationFormat:
self,
) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
""" """
A property which is a tuple of the input and output activation formats A property which is a tuple of the input and output activation formats
for the 'apply' method. for the 'apply' method.
...@@ -435,6 +472,78 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -435,6 +472,78 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
return E, M, N, K, topk return E, M, N, K, topk
#
# Various helpers for registering support for various features.
# Used by the oracle to select a particular kernel for a deployment.
#
@staticmethod
def is_supported_config(
cls: type["FusedMoEPermuteExpertsUnpermute"],
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not cls._supports_current_device():
return False, _make_reason("current device")
elif not (moe_config.is_act_and_mul or cls._supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not cls._supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not cls._supports_quant_scheme(weight_key, activation_key):
return False, _make_reason("quantization scheme")
elif not cls._supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason("parallel config")
elif activation_format != cls.activation_format():
return False, _make_reason(f"{activation_format.value} activation format")
return True, None
@staticmethod
@abstractmethod
def _supports_current_device() -> bool:
"""
Whether the kernel supports the current device type
(compute cability and current platform).
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_no_act_and_mul() -> bool:
"""
Whether the kernel supports act_and_mul=False, i.e.
non-gated MoE models like Nemotron-Nano.
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_activation(activation: str) -> bool:
"""
Whether the kernel supports a particular act function.
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""
Whether the kernel supports deployment in expert parallel.
"""
raise NotImplementedError
# #
# Various helpers for accessing quantization parameters from the # Various helpers for accessing quantization parameters from the
# quant_config. # quant_config.
...@@ -715,12 +824,12 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -715,12 +824,12 @@ class FusedMoEModularKernel(torch.nn.Module):
self._post_init_setup() self._post_init_setup()
assert ( assert (
prepare_finalize.activation_format == fused_experts.activation_formats[0] prepare_finalize.activation_format == fused_experts.activation_format()
), ( ), (
f"{prepare_finalize.__class__.__name__}." f"{prepare_finalize.__class__.__name__}."
f"{prepare_finalize.activation_format} == " f"{prepare_finalize.activation_format} == "
f"{fused_experts.__class__.__name__}." f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}" f"{fused_experts.activation_format()}"
) )
def _post_init_setup(self): def _post_init_setup(self):
......
...@@ -14,21 +14,11 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -14,21 +14,11 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config, nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config, nvfp4_w4a16_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP, MoEPrepareAndFinalizeNoEP,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_flashinfer_fp4_cutedsl_moe_available, is_supported_config_trtllm,
is_flashinfer_fp4_cutlass_moe_available,
prepare_nvfp4_moe_layer_for_fi_or_cutlass, prepare_nvfp4_moe_layer_for_fi_or_cutlass,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
...@@ -36,27 +26,26 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( ...@@ -36,27 +26,26 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
get_flashinfer_moe_backend, get_flashinfer_moe_backend,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
is_fp4_marlin_supported,
prepare_nvfp4_moe_layer_for_marlin, prepare_nvfp4_moe_layer_for_marlin,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported, QuantKey,
) )
logger = init_logger(__name__) logger = init_logger(__name__)
class NvFp4MoeBackend(Enum): class NvFp4MoeBackend(Enum):
FLASHINFER_CUTLASS = "FlashInfer CUTLASS" FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
FLASHINFER_TRTLLM = "FlashInfer TRTLLM" FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL" FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
VLLM_CUTLASS = "vLLM CUTASS" VLLM_CUTLASS = "VLLM_CUTLASS"
MARLIN = "vLLM MARLIN" MARLIN = "MARLIN"
FLASHINFER_NVFP4_MOE_BACKENDS = [ FLASHINFER_NVFP4_MOE_BACKENDS = [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM, NvFp4MoeBackend.FLASHINFER_TRTLLM,
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_CUTEDSL, NvFp4MoeBackend.FLASHINFER_CUTEDSL,
] ]
...@@ -72,44 +61,208 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: ...@@ -72,44 +61,208 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
# of all experts in Expert Parallel Mode when all experts are not # of all experts in Expert Parallel Mode when all experts are not
# on the same rank. # on the same rank.
return backend in [ return backend in FLASHINFER_NVFP4_MOE_BACKENDS
NvFp4MoeBackend.FLASHINFER_CUTLASS,
def backend_to_kernel_cls(
backend: NvFp4MoeBackend,
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
raise NotImplementedError(
"FLASHINFER_TRTLLM doesn't support Modular Kernel Interface"
)
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
return FlashInferExperts
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
FlashInferCuteDSLExperts,
)
return FlashInferCuteDSLExperts
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
return CutlassExpertsFp4
elif backend == NvFp4MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
return MarlinExperts
else:
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
def select_nvfp4_moe_backend(
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
"""
Select the primary NvFP4 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# NOTE: the kernels are selected in the following order.
AVAILABLE_BACKENDS = [
NvFp4MoeBackend.FLASHINFER_TRTLLM, NvFp4MoeBackend.FLASHINFER_TRTLLM,
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.VLLM_CUTLASS,
NvFp4MoeBackend.MARLIN,
] ]
# NOTE(rob): this is kind of a hack. We need to peak into
# the prepare-finalize selection to determine if we are using
# the batched or standard expert format.
use_batched = (
config.moe_parallel_config.use_deepep_ll_kernels
or config.moe_parallel_config.use_pplx_kernels
)
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if use_batched
else mk.FusedMoEActivationFormat.Standard
)
def select_nvfp4_moe_backend() -> NvFp4MoeBackend:
def _make_log_backend(backend: NvFp4MoeBackend): def _make_log_backend(backend: NvFp4MoeBackend):
return f"Using {backend.value} backend for NvFp4 MoE" available_backend_strs = [b.value for b in AVAILABLE_BACKENDS]
return (
f"Using '{backend.value}' NvFp4 MoE backend out "
f"of potential backends: {available_backend_strs}."
)
if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN: def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str:
allow_flashinfer = ( if reason:
is_flashinfer_fp4_cutlass_moe_available() return (
or is_flashinfer_fp4_cutedsl_moe_available() f"NvFp4 MoE backend '{backend.value}' does not support the "
f"deployment configuration since {reason}."
)
else:
return (
f"NvFp4 MoE backend '{backend.value}' does not support the "
"deployment configuration."
)
def _return_or_raise(
backend: NvFp4MoeBackend,
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
) )
if allow_flashinfer and envs.VLLM_USE_FLASHINFER_MOE_FP4: if supported:
backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()] logger.info_once(_make_log_backend(backend))
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
# If the user rejects FlashInfer remove those backends.
for b in FLASHINFER_NVFP4_MOE_BACKENDS:
AVAILABLE_BACKENDS.remove(b)
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it.
fi_backend = get_flashinfer_moe_backend()
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = NvFp4MoeBackend.FLASHINFER_TRTLLM
supported, reason = is_supported_config_trtllm(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend))
return backend, None
else:
raise ValueError(_make_log_unsupported(backend, reason))
else:
backend = fi_2_vllm_backend_map[fi_backend]
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
else: else:
backend = NvFp4MoeBackend.VLLM_CUTLASS # If the user is not explicit about the backend, try each.
elif is_fp4_marlin_supported(): for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
k_cls = None
supported, reason = is_supported_config_trtllm(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, None
else:
logger.debug_once(
_make_log_unsupported(backend, reason), scope="local"
)
raise NotImplementedError(
"Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
"FlashInfer NVFP4 MoE backend supports the configuration."
)
if envs.VLLM_TEST_FORCE_FP8_MARLIN:
backend = NvFp4MoeBackend.MARLIN backend = NvFp4MoeBackend.MARLIN
else: return _return_or_raise(
raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.") backend, config, weight_key, activation_key, activation_format
# Log warning if FI backend requested but not available.
if (
backend not in FLASHINFER_NVFP4_MOE_BACKENDS
and envs.VLLM_USE_FLASHINFER_MOE_FP4
):
logger.warning_once(
"Requested FlashInfer backend for NvFp4 MoE, but it's not available. "
"Falling back to %s for NvFp4 MoE",
backend.value,
scope="local",
) )
else:
logger.info_once(_make_log_backend(backend), scope="local") # Select kernels in order of backend.
return backend for backend in AVAILABLE_BACKENDS:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
k_cls = None # type: ignore[assignment]
supported, reason = is_supported_config_trtllm(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError(
"No NvFp4 MoE backend supports the deployment configuration."
)
def convert_to_nvfp4_moe_kernel_format( def convert_to_nvfp4_moe_kernel_format(
...@@ -238,55 +391,69 @@ def make_nvfp4_moe_quant_config( ...@@ -238,55 +391,69 @@ def make_nvfp4_moe_quant_config(
) )
def make_nvfp4_moe_kernel( def make_nvfp4_moe_kernel_for_mkm(
backend: NvFp4MoeBackend,
quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
) -> mk.FusedMoEModularKernel | None: quant_config: FusedMoEQuantConfig,
assert moe_config.dp_size == 1 experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> mk.FusedMoEPermuteExpertsUnpermute:
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
)
UNSUPPORTED_BACKENDS = [ logger.debug_once("Using %s", experts.__class__.__name__)
# TRTLLM does not use the modular kernl abstraction. return experts
NvFp4MoeBackend.FLASHINFER_TRTLLM,
# CUTEDSL is used with BATCHED (masked) format only.
# TODO: add here once we support dp/ep via the oracle.
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
]
if backend in UNSUPPORTED_BACKENDS:
return None
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: def make_nvfp4_moe_kernel(
return mk.FusedMoEModularKernel( moe_quant_config: FusedMoEQuantConfig,
MoEPrepareAndFinalizeNoEP(defer_input_quant=True), moe_config: FusedMoEConfig,
FlashInferExperts( experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
out_dtype=moe_config.in_dtype, ) -> mk.FusedMoEModularKernel:
quant_config=quant_config, # TODO(rob): unify after we merge tp and dp/ep.
ep_rank=moe_config.ep_rank, if (
ep_size=moe_config.ep_size, moe_config.moe_parallel_config.use_all2all_kernels
tp_rank=moe_config.tp_rank, and moe_config.moe_parallel_config.all2all_backend
tp_size=moe_config.tp_size, not in ["allgather_reducescatter", "naive"]
use_dp=False, ):
use_deepseek_fp8_block_scale=False, raise ValueError(
), "NvFP4 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
) )
elif backend == NvFp4MoeBackend.VLLM_CUTLASS: # Create Prepare/Finalize.
return mk.FusedMoEModularKernel( prepare_finalize = MoEPrepareAndFinalizeNoEP(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True), defer_input_quant=experts_cls.expects_unquantized_inputs(
CutlassExpertsFp4( moe_config, moe_quant_config
out_dtype=moe_config.in_dtype, ),
# TODO(rob): see what impact this has on expert map? )
max_experts_per_worker=moe_config.num_experts,
quant_config=quant_config,
),
)
elif backend == NvFp4MoeBackend.MARLIN: # Create Experts.
return mk.FusedMoEModularKernel( experts = experts_cls(
MoEPrepareAndFinalizeNoEP(), moe_config=moe_config,
MarlinExperts(quant_config=quant_config), quant_config=moe_quant_config,
) )
else: # NOTE(rob): we only want the mk to control the shared_expert
raise ValueError(f"Unknown NvFp4 MoE backend: {backend}") # if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class.
kernel = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts=None,
moe_parallel_config=moe_config.moe_parallel_config,
)
# TODO(rob): update inplace logic to be part of the kernel.
return kernel
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment