Unverified Commit 42135d68 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

parent e14467be
......@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Quant,
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
......@@ -41,7 +41,7 @@ silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
torch.ops._C, "silu_and_mul_nvfp4_quant"
)
if silu_and_mul_nvfp4_quant_supported:
FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
class ActivationQuantPattern(ABC):
......@@ -129,7 +129,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
"""
def __init__(self) -> None:
super().__init__(kNvfp4Quant)
super().__init__(kNvfp4Dynamic)
def get_inputs(self) -> list[torch.Tensor]:
result = self.empty_quant(5, 32)
......
......@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Quant,
kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
......@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
......
......@@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Quant,
kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
......@@ -217,7 +217,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
super().__init__(layer, kNvfp4Quant, dtype)
super().__init__(layer, kNvfp4Dynamic, dtype)
def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
......
......@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Quant,
kNvfp4Dynamic,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
......@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
......
......@@ -7,11 +7,20 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
......@@ -19,6 +28,7 @@ from vllm.utils.deep_gemm import (
fp8_m_grouped_gemm_nt_masked,
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv, round_up
......@@ -253,29 +263,52 @@ def persistent_masked_m_silu_mul_quant(
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
):
"""
max_num_tokens: Maximum number of tokens from a DP Rank
num_dispatchers: The number of DP dispatchers.
quant_config: Quantization configuration
"""
super().__init__(quant_config)
super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
assert self.block_shape == get_mk_alignment_for_contiguous_layout()
assert self.quant_config.use_fp8_w8a8
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
@staticmethod
def _supports_current_device() -> bool:
return is_deep_gemm_supported()
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [(kFp8Static128BlockSym, kFp8Dynamic128Sym)]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_chunking(self) -> bool:
return False
......@@ -310,6 +343,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
......
......@@ -862,6 +862,7 @@ class FusedMoEParallelConfig:
use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication
enable_eplb: bool # whether to enable expert load balancing
@property
def use_all2all_kernels(self):
......@@ -882,6 +883,16 @@ class FusedMoEParallelConfig:
def use_deepep_ll_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@property
def use_batched_activation_format(self):
return self.use_deepep_ll_kernels or self.use_pplx_kernels
@property
def use_naive_all2all_kernels(self):
return self.use_all2all_kernels and (
self.all2all_backend in ["naive", "allgather_reducescatter"]
)
@staticmethod
def flatten_tp_across_dp_and_pcp(
tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
......@@ -999,6 +1010,7 @@ class FusedMoEParallelConfig:
ep_rank=0,
use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend,
enable_eplb=vllm_parallel_config.enable_eplb,
)
# DP + EP / TP + EP / DP + TP + EP
assert use_ep
......@@ -1017,6 +1029,24 @@ class FusedMoEParallelConfig:
ep_rank=ep_rank,
use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend,
enable_eplb=vllm_parallel_config.enable_eplb,
)
@classmethod
def make_no_parallel(cls) -> "FusedMoEParallelConfig":
"""For usage in CI/CD and testing."""
return FusedMoEParallelConfig(
tp_size=1,
tp_rank=0,
pcp_size=1,
pcp_rank=0,
dp_size=1,
dp_rank=0,
ep_size=1,
ep_rank=0,
use_ep=False,
all2all_backend="naive",
enable_eplb=False,
)
......@@ -1026,8 +1056,11 @@ class FusedMoEConfig:
num_experts: int
experts_per_token: int
hidden_dim: int
intermediate_size_per_partition: int
num_local_experts: int
activation: str
device: torch.device | str
routing_method: RoutingMethodType
moe_parallel_config: FusedMoEParallelConfig
# The activation type.
......
......@@ -7,7 +7,11 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute,
moe_unpermute,
......@@ -23,6 +27,19 @@ from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
apply_moe_activation,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_group_gemm_supported,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
......@@ -238,29 +255,57 @@ def run_cutlass_moe_fp8(
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
e: int,
n: int,
k: int,
out_dtype: torch.dtype | None,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
device: torch.dtype,
max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
):
super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
assert quant_config.use_fp8_w8a8
super().__init__(quant_config)
# E: num_experts
# N: intermediate size per partition
# K: hidden dim
e = moe_config.num_local_experts
n = moe_config.intermediate_size_per_partition
k = moe_config.hidden_dim
device = moe_config.device
ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
self.out_dtype = out_dtype
self.out_dtype = moe_config.in_dtype
self.ab_strides1 = ab_strides1_c_strides2
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = ab_strides1_c_strides2
@staticmethod
def _supports_current_device() -> bool:
return cutlass_group_gemm_supported()
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
......@@ -291,7 +336,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens = expert_tokens_meta.expert_num_tokens
use_batched_format = (
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts
)
in_dtype = hidden_states.dtype
......@@ -324,20 +369,23 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp8(CutlassExpertsFp8Base):
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# CutlassExpertsFp8 does not support expert map, which is
# needed for STANDARD activation format kernels in DP/EP mode.
# Note that the BATCHED activation format does not use
# the expert map for identifying experts.
return not moe_parallel_config.use_all2all_kernels
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# topk weights and reduction are fused in moe_unpermute cuda kernel
......@@ -365,26 +413,16 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
max_experts_per_worker: int,
num_dispatchers: int,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# BATCHED activation format works with EP because
# expert_map is not used to identify experts (the
# info is encoded/managed by the P/F logic).
return True
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def supports_chunking(self) -> bool:
return False
......@@ -408,14 +446,15 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
assert num_dp is not None
experts_per_worker = self.moe_config.num_local_experts
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
workspace1 = (experts_per_worker, M * num_dp, max(N, K))
workspace2 = (
self.max_experts_per_worker,
experts_per_worker,
M * num_dp,
max(activation_out_dim, K),
)
output = (self.max_experts_per_worker, M, K)
output = (experts_per_worker, M, K)
return (workspace1, workspace2, output)
......@@ -601,34 +640,41 @@ def run_cutlass_moe_fp4(
return
# Split into batched and non-batched
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_experts_per_worker: int,
out_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
use_batched_format: bool = False,
):
super().__init__(quant_config)
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
@staticmethod
def expects_unquantized_inputs(
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
return True
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
if self.use_batched_format:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
else:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
@staticmethod
def _supports_current_device() -> bool:
return current_platform.has_device_capability((10, 0))
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic)
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# CutlassExpertsFp4 does not support expert map, which is
# needed for STANDARD activation format kernels in EP mode.
return moe_parallel_config.ep_size == 1
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_expert_map(self) -> bool:
return False
......@@ -640,7 +686,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return TopKWeightAndReduceNoOP()
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return self.out_dtype if self.out_dtype is not None else act_dtype
return act_dtype
def workspace_shapes(
self,
......@@ -653,18 +699,9 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
if self.use_batched_format:
workspace1 = (self.max_experts_per_worker, M, max(N, K))
workspace2 = (self.max_experts_per_worker, M, activation_out_dim)
output = (self.max_experts_per_worker, M, K)
else:
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
output = (M, K)
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
output = (M, K)
return (workspace1, workspace2, output)
def apply(
......@@ -869,10 +906,11 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
c_strides2: torch.Tensor,
s_strides1: torch.Tensor,
s_strides2: torch.Tensor,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
group_size: int,
):
super().__init__(quant_config)
super().__init__(moe_config=moe_config, quant_config=quant_config)
self.out_dtype = out_dtype
self.a_strides1 = a_strides1
self.a_strides2 = a_strides2
......@@ -884,13 +922,46 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
self.s_strides2 = s_strides2
self.group_size = group_size
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_activation(activation: str) -> bool:
raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
"This method should not be called."
)
def supports_chunking(self) -> bool:
......@@ -947,7 +1018,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens = None
use_batched_format = (
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts
)
assert not use_batched_format, "batched format not supported"
......@@ -1003,6 +1074,7 @@ def cutlass_moe_w4a8_fp8(
s_strides1: torch.Tensor,
s_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
activation: str = "silu",
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
......@@ -1076,6 +1148,7 @@ def cutlass_moe_w4a8_fp8(
c_strides2=c_strides2,
s_strides1=s_strides1,
s_strides2=s_strides2,
moe_config=moe_config,
quant_config=quant_config,
group_size=group_size,
),
......
......@@ -6,17 +6,15 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M,
deepgemm_moe_permute,
deepgemm_unpermute_and_reduce,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
......@@ -26,9 +24,15 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8_packed_for_deepgemm,
silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
)
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_supported,
m_grouped_fp8_gemm_nt_contiguous,
)
from vllm.utils.import_utils import has_deep_gemm
......@@ -109,21 +113,42 @@ def _valid_deep_gemm(
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
super().__init__(moe_config=moe_config, quant_config=quant_config)
assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
assert quant_config.quant_dtype == torch.float8_e4m3fn
assert not quant_config.per_act_token_quant
assert not quant_config.per_out_ch_quant
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
return is_deep_gemm_supported()
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_chunking(self) -> bool:
return True
......@@ -283,82 +308,3 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_map=expert_map,
output=output,
)
def deep_gemm_moe_fp8(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with DeepGemm
grouped gemm.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1 (torch.Tensor): The first set of fp8 quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2 (torch.Tensor): The second set of fp8 quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=get_mk_alignment_for_contiguous_layout(),
)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
DeepGemmExperts(quant_config),
)
return fn(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
......@@ -16,18 +18,78 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
experts: mk.FusedMoEPermuteExpertsUnpermute,
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
):
super().__init__(experts.quant_config)
super().__init__(
moe_config=experts.moe_config, quant_config=experts.quant_config
)
self.fallback_experts = fallback_experts
self.experts = experts
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
assert (
self.fallback_experts.activation_formats == self.experts.activation_formats
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
]:
"""
Get the cls for the experts and fallback experts.
Subclasses should implement this method, so that
we have a consistent way to call the _supports_*
class methods below.
"""
raise NotImplementedError(
"Subclasses must return the cls for the experts and fallback experts."
)
@classmethod
def activation_format(
cls: type["FallbackExperts"],
) -> mk.FusedMoEActivationFormat:
experts_cls, fallback_cls = cls.get_clses()
assert experts_cls.activation_format() == fallback_cls.activation_format()
return experts_cls.activation_format()
@classmethod
def _supports_current_device(cls) -> bool:
experts_cls, fallback_cls = cls.get_clses()
return (
experts_cls._supports_current_device()
and fallback_cls._supports_current_device()
)
@classmethod
def _supports_no_act_and_mul(cls) -> bool:
experts_cls, fallback_cls = cls.get_clses()
return (
experts_cls._supports_no_act_and_mul()
and fallback_cls._supports_no_act_and_mul()
)
return self.fallback_experts.activation_formats
@classmethod
def _supports_quant_scheme(
cls,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
experts_cls, fallback_cls = cls.get_clses()
return experts_cls._supports_quant_scheme(
weight_key, activation_key
) and fallback_cls._supports_quant_scheme(weight_key, activation_key)
@classmethod
def _supports_activation(cls, activation: str) -> bool:
experts_cls, fallback_cls = cls.get_clses()
return experts_cls._supports_activation(
activation
) and fallback_cls._supports_activation(activation)
@classmethod
def _supports_parallel_config(
cls, moe_parallel_config: FusedMoEParallelConfig
) -> bool:
experts_cls, fallback_cls = cls.get_clses()
return experts_cls._supports_parallel_config(
moe_parallel_config
) and fallback_cls._supports_parallel_config(moe_parallel_config)
def supports_chunking(self) -> bool:
assert (
......
......@@ -6,13 +6,22 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
scaled_fp4_grouped_quantize,
silu_and_mul_scaled_nvfp4_experts_quantize,
)
......@@ -20,54 +29,54 @@ from vllm.utils.flashinfer import (
logger = init_logger(__name__)
def is_valid_flashinfer_cutedsl_fused_moe(
hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
) -> bool:
"""
Check if the given problem size is supported by the FlashInfer CuteDSL MoE
kernel.
"""
if not has_flashinfer_cutedsl_grouped_gemm_nt_masked():
logger.debug_once(
"FlashInferCuteDSLExperts disabled: "
"flashinfer_cutedsl_fused_moe not available."
)
return False
# Data type checks
if (
w1.dtype != torch.uint8
or w2.dtype != torch.uint8
or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16]
):
logger.debug_once(
"FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 "
f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
f"float32, float16, or bfloat16 (got {hidden_states.dtype})."
)
return False
return True
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
out_dtype: torch.dtype,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
):
super().__init__(quant_config)
super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
assert quant_config.quant_dtype == "nvfp4", (
"Only nvfp4 quantization are currently supported."
)
self.out_dtype = out_dtype
self.out_dtype = moe_config.in_dtype
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
@staticmethod
def _supports_current_device() -> bool:
return current_platform.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kNvfp4Static, kNvfp4Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_expert_map(self) -> bool:
return False
......
......@@ -5,13 +5,22 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
flashinfer_cutlass_fused_moe,
has_flashinfer_cutlass_fused_moe,
......@@ -50,40 +59,100 @@ def is_valid_flashinfer_cutlass_fused_moe(
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
out_dtype: torch.dtype,
moe_config: mk.FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
ep_rank: int = 0,
ep_size: int = 1,
tp_rank: int = 0,
tp_size: int = 1,
use_dp: bool = False,
use_deepseek_fp8_block_scale: bool = False,
):
super().__init__(quant_config)
super().__init__(moe_config, quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
"Only nvfp4, fp8, bfloat16 and"
" float16 quantization are currently supported."
)
self.ep_rank = ep_rank
self.ep_size = ep_size
self.tp_rank = tp_rank
self.tp_size = tp_size
self.out_dtype = out_dtype
self.use_dp = use_dp
self.ep_rank = moe_config.moe_parallel_config.ep_rank
self.ep_size = moe_config.moe_parallel_config.ep_size
self.tp_rank = moe_config.moe_parallel_config.tp_rank
self.tp_size = moe_config.moe_parallel_config.tp_size
self.out_dtype = moe_config.in_dtype
self.use_dp = moe_config.moe_parallel_config.dp_size > 1
# Enables DeepSeek-style FP8 block-scale path:
# - pass per-block weight scales to the kernel
# - skip input activation quantization (kernel applies scaling)
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
@staticmethod
def expects_unquantized_inputs(
moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
# NVFP4 TP kernels and FP8 block-quantized kernels apply
# input quantization inside FusedMoEPermuteExpertsUnpermute.
return (
quant_config.use_nvfp4_w4a4
and not moe_config.moe_parallel_config.use_all2all_kernels
) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized)
@staticmethod
def _supports_current_device() -> bool:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
current_platform.is_cuda()
and (
current_platform.is_device_capability((9, 0))
or current_platform.is_device_capability_family(100)
)
and has_flashinfer_cutlass_fused_moe()
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return True
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# The following are supported by FlashInferExperts:
# * unquantized
# * fp8 static per-tensor on 9.0+
# * fp8 block on 9.0
# * nvfp4 on 10.0+
p = current_platform
scheme = (weight_key, activation_key)
return (
(
scheme
in [
(None, None),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
)
or (
(scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym))
and (p.is_device_capability((9, 0)))
)
or (
(scheme == (kNvfp4Static, kNvfp4Dynamic))
and (p.is_device_capability_family(100))
)
)
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "relu2_no_mul"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# FLASHINFER_CUTLASS currently uses its down P/F, which does not
# work with SP. This will be removed in follow up after we get
# rid of the FlashInfer specific P/F function.
return (
moe_parallel_config.dp_size == 1
or moe_parallel_config.dp_size == moe_parallel_config.ep_size
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_expert_map(self) -> bool:
return False
......@@ -231,85 +300,3 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# No support for LoRA in flashinfer_cutlass_fused_moe.
# See TODOs in flashinfer functions runMoe and runMoeMinLantency.
raise NotImplementedError("LoRA is not supported for flashinfer_cutlass_moe")
def flashinfer_cutlass_moe_fp4(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
create_flashinfer_prepare_finalize(
use_dp=False, use_nvfp4=True, enable_alltoallv=False
),
FlashInferExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
use_dp=False,
),
)
return fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
def flashinfer_cutlass_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
tp_rank: int = 0,
tp_size: int = 1,
ep_rank: int = 0,
ep_size: int = 1,
use_dp: bool = False,
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
create_flashinfer_prepare_finalize(use_dp=use_dp),
FlashInferExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
tp_rank=tp_rank,
tp_size=tp_size,
ep_rank=ep_rank,
ep_size=ep_size,
),
)
return fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
......@@ -3,7 +3,12 @@
import torch
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim,
......@@ -11,8 +16,107 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
#
# Methods used by the oracle for kernel selection.
#
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)
def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nanotron-Mini)."""
return False
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 per-tensor and Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: str) -> bool:
"""Supports silu activation only."""
return activation in ["silu"]
def _supports_routing_method(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
routing_method: RoutingMethodType,
) -> bool:
"""Monolithic kernels need to express router support."""
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(rob): kernel requires Llama4.
return routing_method == RoutingMethodType.Llama4
else:
raise ValueError("Unsupported quantization scheme.")
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Supports TRTLLM Kernel does not support EPLB."""
return not moe_parallel_config.enable_eplb
def is_supported_config_trtllm(
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
"""
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not _supports_current_device():
return False, _make_reason("current device")
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not _supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not _supports_quant_scheme(weight_key, activation_key):
return False, _make_reason("quantization scheme")
elif not _supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason("parallel config")
elif not _supports_routing_method(
weight_key, activation_key, moe_config.routing_method
):
return False, _make_reason("routing method")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason("activation format")
return True, None
def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor,
......
......@@ -5,7 +5,11 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
......@@ -17,7 +21,17 @@ from vllm.model_executor.layers.fused_moe.utils import (
normalize_batched_scales_shape,
normalize_scales_shape,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
group_broadcast,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
......@@ -633,25 +647,62 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
):
super().__init__(quant_config)
super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert self.quant_config.ocp_mx_scheme is None, "NYI"
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
@staticmethod
def _supports_current_device() -> bool:
raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_activation(activation: str) -> bool:
raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"NaiveBatchedExperts is not yet used by an Oracle. "
"This method should not be called."
)
def supports_chunking(self) -> bool:
......@@ -675,6 +726,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
num_dp = self.num_dispatchers
num_experts = local_num_experts
workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
......@@ -826,29 +879,69 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
):
super().__init__(quant_config)
super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert self.quant_config.ocp_mx_scheme is None, "NYI"
assert max_num_tokens > 0
assert num_dispatchers > 0
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
@staticmethod
def _supports_current_device() -> bool:
return current_platform.is_cuda_alike()
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
p = current_platform
device_supports_fp8 = (p.is_rocm() and p.rocm.on_gfx9()) or (
p.is_cuda() and p.has_device_capability((8, 9))
)
SUPPORTED_W_A_FP8 = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
]
return (weight_key, activation_key) == (None, None) or (
device_supports_fp8 and (weight_key, activation_key) in SUPPORTED_W_A_FP8
)
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in [
"silu",
"gelu",
"swigluoai",
"silu_no_mul",
"gelu_no_mul",
"relu2_no_mul",
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_chunking(self) -> bool:
return False
......@@ -870,6 +963,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
num_dp = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = self.max_num_tokens
......
......@@ -8,7 +8,11 @@ import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size,
moe_align_block_size,
......@@ -27,6 +31,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_moe_intermediate_size,
marlin_quant_input,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
......@@ -522,7 +533,10 @@ def batched_fused_marlin_moe(
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None,
w13_g_idx_sort_indices: torch.Tensor | None = None,
......@@ -541,7 +555,51 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
self.is_k_full = is_k_full
super().__init__(quant_config)
super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
@staticmethod
def _supports_current_device() -> bool:
p = current_platform
return p.is_cuda() and p.has_device_capability((7, 5))
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# TODO(rob): add int4, mxfp4, int8 as integrations
# are migrated to use the oracle one-by-one.
SUPPORTED_W = [
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kNvfp4Static,
]
return weight_key in SUPPORTED_W
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in [
"silu",
"gelu",
"swigluoai",
"silu_no_mul",
"gelu_no_mul",
"relu2_no_mul",
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
@property
def quant_type_id(self) -> int:
......@@ -587,38 +645,15 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
class MarlinExperts(MarlinExpertsBase):
def __init__(
self,
quant_config: FusedMoEQuantConfig,
w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None,
w13_g_idx_sort_indices: torch.Tensor | None = None,
w2_g_idx_sort_indices: torch.Tensor | None = None,
is_k_full: bool = True,
):
super().__init__(
quant_config,
w13_g_idx,
w2_g_idx,
w13_g_idx_sort_indices,
w2_g_idx_sort_indices,
is_k_full,
)
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
......@@ -714,9 +749,10 @@ class MarlinExperts(MarlinExpertsBase):
class BatchedMarlinExperts(MarlinExpertsBase):
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None,
w13_g_idx_sort_indices: torch.Tensor | None = None,
......@@ -724,15 +760,16 @@ class BatchedMarlinExperts(MarlinExpertsBase):
is_k_full: bool = True,
):
super().__init__(
quant_config,
w13_g_idx,
w2_g_idx,
w13_g_idx_sort_indices,
w2_g_idx_sort_indices,
is_k_full,
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=is_k_full,
)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
def supports_expert_map(self) -> bool:
return True
......@@ -740,14 +777,9 @@ class BatchedMarlinExperts(MarlinExpertsBase):
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceDelegate()
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def supports_chunking(self) -> bool:
return False
......@@ -763,9 +795,11 @@ class BatchedMarlinExperts(MarlinExpertsBase):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
max_num_tokens = self.max_num_tokens
workspace13 = (num_experts * max_num_tokens * num_dispatchers, max(K, N * 2))
workspace2 = (num_experts * max_num_tokens * num_dispatchers, N)
output = (num_experts, max_num_tokens * num_dispatchers, K)
......
......@@ -19,13 +19,11 @@ from vllm.model_executor.layers.batch_invariant import (
)
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm,
deep_gemm_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
......@@ -44,9 +42,16 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)
......@@ -1534,66 +1539,36 @@ def fused_experts(
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
quant_config: FusedMoEQuantConfig | None = None,
allow_deep_gemm: bool = False,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# However, on B200, we use DeepGemm for all cases because they only support
# E8M0 scale, which means we requantize the weight and input to the specific
# scale. Fallen back to cutlass or triton for some cases would cause
# accuracy issue.
if (
allow_deep_gemm
and quant_config.use_fp8_w8a8
and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))
):
assert quant_config is not None
return deep_gemm_moe_fp8(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
a1_scale=quant_config.a1_scale,
a2_scale=quant_config.a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
return dispatch_fused_experts_func(inplace)(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=quant_config.use_fp8_w8a8,
use_int8_w8a8=quant_config.use_int8_w8a8,
use_int8_w8a16=quant_config.use_int8_w8a16,
use_int4_w4a16=quant_config.use_int4_w4a16,
ocp_mx_scheme=quant_config.ocp_mx_scheme,
per_channel_quant=quant_config.per_act_token_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
w1_zp=quant_config.w1_zp,
w2_zp=quant_config.w2_zp,
a1_scale=quant_config.a1_scale,
a2_scale=quant_config.a2_scale,
block_shape=quant_config.block_shape,
w1_bias=quant_config.w1_bias,
w2_bias=quant_config.w2_bias,
)
return dispatch_fused_experts_func(inplace)(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=quant_config.use_fp8_w8a8,
use_int8_w8a8=quant_config.use_int8_w8a8,
use_int8_w8a16=quant_config.use_int8_w8a16,
use_int4_w4a16=quant_config.use_int4_w4a16,
ocp_mx_scheme=quant_config.ocp_mx_scheme,
per_channel_quant=quant_config.per_act_token_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
w1_zp=quant_config.w1_zp,
w2_zp=quant_config.w2_zp,
a1_scale=quant_config.a1_scale,
a2_scale=quant_config.a2_scale,
block_shape=quant_config.block_shape,
w1_bias=quant_config.w1_bias,
w2_bias=quant_config.w2_bias,
)
def _get_config_quant_dtype(
......@@ -1924,19 +1899,53 @@ def fused_experts_impl(
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
super().__init__(quant_config)
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
super().__init__(moe_config, quant_config)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
return current_platform.is_cuda_alike()
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
p = current_platform
device_supports_fp8 = (p.is_rocm() and p.rocm.on_gfx9()) or (
p.is_cuda() and p.has_device_capability((8, 9))
)
if not device_supports_fp8:
return (weight_key, activation_key) == (None, None)
SUPPORTED_W_A = [
(None, None),
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_chunking(self) -> bool:
return True
......@@ -2111,11 +2120,43 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
class TritonWNA16Experts(TritonExperts):
def __init__(
self,
quant_config: FusedMoEQuantConfig,
):
super().__init__(quant_config)
@staticmethod
def _supports_current_device() -> bool:
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_activation(activation: str) -> bool:
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"TritonWNA16Experts is not yet used by an Oracle. "
"This method should not be called."
)
def apply(
self,
......@@ -2254,10 +2295,12 @@ class TritonWNA16Experts(TritonExperts):
def modular_triton_fused_moe(
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(quant_config),
TritonExperts(moe_config, quant_config),
shared_experts,
)
......@@ -9,12 +9,16 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels
......@@ -241,8 +245,43 @@ def make_routing_data(
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)
@staticmethod
def _supports_current_device() -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_activation(activation: str) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"OAITritonExperts is not yet used by an Oracle. "
"This method should not be called."
)
def supports_expert_map(self) -> bool:
return True
......@@ -297,19 +336,9 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
class OAITritonExperts(BaseOAITritonExperts):
def __init__(self, quant_config: FusedMoEQuantConfig):
# TODO (varun) : Enable activation quantization
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
super().__init__(quant_config)
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
......@@ -391,19 +420,9 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
One use case for it is to inject LoRA modules on the activation and moe_sum.
"""
def __init__(self, quant_config: FusedMoEQuantConfig):
# TODO (varun) : Enable activation quantization
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
super().__init__(quant_config)
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
......
......@@ -330,7 +330,6 @@ class FusedMoE(CustomOp):
is_sequence_parallel=False,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
routing_method_type: RoutingMethodType | None = None,
router_logits_dtype: torch.dtype | None = None,
):
super().__init__()
......@@ -519,10 +518,43 @@ class FusedMoE(CustomOp):
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
# TODO(bnell): in next PR move capture back to layer
capture: Callable[[torch.Tensor], None] | None = None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids)
self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,
eplb_state=self.eplb_state,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=self.num_fused_shared_experts,
enable_eplb=enable_eplb,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
capture=capture,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
intermediate_size_per_partition=self.intermediate_size_per_partition,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype,
......@@ -531,6 +563,9 @@ class FusedMoE(CustomOp):
has_bias=has_bias,
is_act_and_mul=is_act_and_mul,
is_lora_enabled=vllm_config.lora_config is not None,
activation=activation,
device=vllm_config.device_config.device,
routing_method=self.routing_method_type,
)
self.moe_config_use_flashinfer_cutlass_kernels = (
self.moe_config.use_flashinfer_cutlass_kernels
......@@ -594,39 +629,6 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
# TODO(bnell): in next PR move capture back to layer
capture: Callable[[torch.Tensor], None] | None = None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids)
self.router = create_fused_moe_router(
top_k=top_k,
global_num_experts=self.global_num_experts,
eplb_state=self.eplb_state,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=self.num_fused_shared_experts,
enable_eplb=enable_eplb,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
routing_method_type=routing_method_type,
capture=capture,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it
......
......@@ -13,6 +13,7 @@ import vllm.envs as envs
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
......@@ -22,6 +23,9 @@ from vllm.model_executor.layers.fused_moe.utils import (
count_expert_num_tokens,
disable_inplace,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import (
dbo_enabled,
......@@ -374,18 +378,51 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
):
"""
moe_config: MoE layer configuration.
quant_config: Quantization parameters for this experts instance.
"""
if self.activation_format() == FusedMoEActivationFormat.Standard and (
max_num_tokens is not None or num_dispatchers is not None
):
raise ValueError(
"max_num_tokens and num_dispatchers should only be set for "
"BatchedExperts activation format."
)
elif self.activation_format() == FusedMoEActivationFormat.BatchedExperts and (
max_num_tokens is None or num_dispatchers is None
):
raise ValueError(
"max_num_tokens and num_dispatchers must be set for "
"BatchedExperts activation format."
)
self.moe_config = moe_config
self.quant_config = quant_config
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@property
@staticmethod
def expects_unquantized_inputs(
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
"""
Whether or not the PrepareFinalize should defer input quantization
in the prepare step. If True, then the Experts kernel will
execute the input quantization itself.
Sample subclasses that override are AITER and FlashInfer CUTLASS.
"""
return False
@staticmethod
@abstractmethod
def activation_formats(
self,
) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
def activation_format() -> FusedMoEActivationFormat:
"""
A property which is a tuple of the input and output activation formats
for the 'apply' method.
......@@ -435,6 +472,78 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
return E, M, N, K, topk
#
# Various helpers for registering support for various features.
# Used by the oracle to select a particular kernel for a deployment.
#
@staticmethod
def is_supported_config(
cls: type["FusedMoEPermuteExpertsUnpermute"],
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not cls._supports_current_device():
return False, _make_reason("current device")
elif not (moe_config.is_act_and_mul or cls._supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not cls._supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not cls._supports_quant_scheme(weight_key, activation_key):
return False, _make_reason("quantization scheme")
elif not cls._supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason("parallel config")
elif activation_format != cls.activation_format():
return False, _make_reason(f"{activation_format.value} activation format")
return True, None
@staticmethod
@abstractmethod
def _supports_current_device() -> bool:
"""
Whether the kernel supports the current device type
(compute cability and current platform).
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_no_act_and_mul() -> bool:
"""
Whether the kernel supports act_and_mul=False, i.e.
non-gated MoE models like Nemotron-Nano.
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_activation(activation: str) -> bool:
"""
Whether the kernel supports a particular act function.
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""
Whether the kernel supports deployment in expert parallel.
"""
raise NotImplementedError
#
# Various helpers for accessing quantization parameters from the
# quant_config.
......@@ -715,12 +824,12 @@ class FusedMoEModularKernel(torch.nn.Module):
self._post_init_setup()
assert (
prepare_finalize.activation_format == fused_experts.activation_formats[0]
prepare_finalize.activation_format == fused_experts.activation_format()
), (
f"{prepare_finalize.__class__.__name__}."
f"{prepare_finalize.activation_format} == "
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}"
f"{fused_experts.activation_format()}"
)
def _post_init_setup(self):
......
......@@ -14,21 +14,11 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_flashinfer_fp4_cutedsl_moe_available,
is_flashinfer_fp4_cutlass_moe_available,
is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
......@@ -36,27 +26,26 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
is_fp4_marlin_supported,
prepare_nvfp4_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
QuantKey,
)
logger = init_logger(__name__)
class NvFp4MoeBackend(Enum):
FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
FLASHINFER_TRTLLM = "FlashInfer TRTLLM"
FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL"
VLLM_CUTLASS = "vLLM CUTASS"
MARLIN = "vLLM MARLIN"
FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
VLLM_CUTLASS = "VLLM_CUTLASS"
MARLIN = "MARLIN"
FLASHINFER_NVFP4_MOE_BACKENDS = [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
]
......@@ -72,44 +61,208 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
# of all experts in Expert Parallel Mode when all experts are not
# on the same rank.
return backend in [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
return backend in FLASHINFER_NVFP4_MOE_BACKENDS
def backend_to_kernel_cls(
backend: NvFp4MoeBackend,
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
raise NotImplementedError(
"FLASHINFER_TRTLLM doesn't support Modular Kernel Interface"
)
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
return FlashInferExperts
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
FlashInferCuteDSLExperts,
)
return FlashInferCuteDSLExperts
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
return CutlassExpertsFp4
elif backend == NvFp4MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
return MarlinExperts
else:
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
def select_nvfp4_moe_backend(
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
"""
Select the primary NvFP4 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# NOTE: the kernels are selected in the following order.
AVAILABLE_BACKENDS = [
NvFp4MoeBackend.FLASHINFER_TRTLLM,
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.VLLM_CUTLASS,
NvFp4MoeBackend.MARLIN,
]
# NOTE(rob): this is kind of a hack. We need to peak into
# the prepare-finalize selection to determine if we are using
# the batched or standard expert format.
use_batched = (
config.moe_parallel_config.use_deepep_ll_kernels
or config.moe_parallel_config.use_pplx_kernels
)
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if use_batched
else mk.FusedMoEActivationFormat.Standard
)
def select_nvfp4_moe_backend() -> NvFp4MoeBackend:
def _make_log_backend(backend: NvFp4MoeBackend):
return f"Using {backend.value} backend for NvFp4 MoE"
available_backend_strs = [b.value for b in AVAILABLE_BACKENDS]
return (
f"Using '{backend.value}' NvFp4 MoE backend out "
f"of potential backends: {available_backend_strs}."
)
if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN:
allow_flashinfer = (
is_flashinfer_fp4_cutlass_moe_available()
or is_flashinfer_fp4_cutedsl_moe_available()
def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str:
if reason:
return (
f"NvFp4 MoE backend '{backend.value}' does not support the "
f"deployment configuration since {reason}."
)
else:
return (
f"NvFp4 MoE backend '{backend.value}' does not support the "
"deployment configuration."
)
def _return_or_raise(
backend: NvFp4MoeBackend,
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if allow_flashinfer and envs.VLLM_USE_FLASHINFER_MOE_FP4:
backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
if supported:
logger.info_once(_make_log_backend(backend))
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
# If the user rejects FlashInfer remove those backends.
for b in FLASHINFER_NVFP4_MOE_BACKENDS:
AVAILABLE_BACKENDS.remove(b)
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it.
fi_backend = get_flashinfer_moe_backend()
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = NvFp4MoeBackend.FLASHINFER_TRTLLM
supported, reason = is_supported_config_trtllm(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend))
return backend, None
else:
raise ValueError(_make_log_unsupported(backend, reason))
else:
backend = fi_2_vllm_backend_map[fi_backend]
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
else:
backend = NvFp4MoeBackend.VLLM_CUTLASS
elif is_fp4_marlin_supported():
# If the user is not explicit about the backend, try each.
for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
k_cls = None
supported, reason = is_supported_config_trtllm(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, None
else:
logger.debug_once(
_make_log_unsupported(backend, reason), scope="local"
)
raise NotImplementedError(
"Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
"FlashInfer NVFP4 MoE backend supports the configuration."
)
if envs.VLLM_TEST_FORCE_FP8_MARLIN:
backend = NvFp4MoeBackend.MARLIN
else:
raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.")
# Log warning if FI backend requested but not available.
if (
backend not in FLASHINFER_NVFP4_MOE_BACKENDS
and envs.VLLM_USE_FLASHINFER_MOE_FP4
):
logger.warning_once(
"Requested FlashInfer backend for NvFp4 MoE, but it's not available. "
"Falling back to %s for NvFp4 MoE",
backend.value,
scope="local",
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
else:
logger.info_once(_make_log_backend(backend), scope="local")
return backend
# Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
k_cls = None # type: ignore[assignment]
supported, reason = is_supported_config_trtllm(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError(
"No NvFp4 MoE backend supports the deployment configuration."
)
def convert_to_nvfp4_moe_kernel_format(
......@@ -238,55 +391,69 @@ def make_nvfp4_moe_quant_config(
)
def make_nvfp4_moe_kernel(
backend: NvFp4MoeBackend,
quant_config: FusedMoEQuantConfig,
def make_nvfp4_moe_kernel_for_mkm(
moe_config: FusedMoEConfig,
) -> mk.FusedMoEModularKernel | None:
assert moe_config.dp_size == 1
quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> mk.FusedMoEPermuteExpertsUnpermute:
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
)
UNSUPPORTED_BACKENDS = [
# TRTLLM does not use the modular kernl abstraction.
NvFp4MoeBackend.FLASHINFER_TRTLLM,
# CUTEDSL is used with BATCHED (masked) format only.
# TODO: add here once we support dp/ep via the oracle.
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
]
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
if backend in UNSUPPORTED_BACKENDS:
return None
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
FlashInferExperts(
out_dtype=moe_config.in_dtype,
quant_config=quant_config,
ep_rank=moe_config.ep_rank,
ep_size=moe_config.ep_size,
tp_rank=moe_config.tp_rank,
tp_size=moe_config.tp_size,
use_dp=False,
use_deepseek_fp8_block_scale=False,
),
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> mk.FusedMoEModularKernel:
# TODO(rob): unify after we merge tp and dp/ep.
if (
moe_config.moe_parallel_config.use_all2all_kernels
and moe_config.moe_parallel_config.all2all_backend
not in ["allgather_reducescatter", "naive"]
):
raise ValueError(
"NvFP4 Oracle should not create non-naive A2A P/F. "
"This should happen via the ModularKernelMethod."
)
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
out_dtype=moe_config.in_dtype,
# TODO(rob): see what impact this has on expert map?
max_experts_per_worker=moe_config.num_experts,
quant_config=quant_config,
),
)
# Create Prepare/Finalize.
prepare_finalize = MoEPrepareAndFinalizeNoEP(
defer_input_quant=experts_cls.expects_unquantized_inputs(
moe_config, moe_quant_config
),
)
elif backend == NvFp4MoeBackend.MARLIN:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config=quant_config),
)
# Create Experts.
experts = experts_cls(
moe_config=moe_config,
quant_config=moe_quant_config,
)
else:
raise ValueError(f"Unknown NvFp4 MoE backend: {backend}")
# NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in
# the new MoE runner class.
kernel = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts=None,
moe_parallel_config=moe_config.moe_parallel_config,
)
# TODO(rob): update inplace logic to be part of the kernel.
return kernel
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment