Unverified Commit 42135d68 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

parent e14467be
...@@ -123,7 +123,6 @@ def convert_to_unquantized_kernel_format( ...@@ -123,7 +123,6 @@ def convert_to_unquantized_kernel_format(
def make_unquantized_moe_kernel( def make_unquantized_moe_kernel(
layer: torch.nn.Module,
backend: UnquantizedMoeBackend, backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
...@@ -141,12 +140,8 @@ def make_unquantized_moe_kernel( ...@@ -141,12 +140,8 @@ def make_unquantized_moe_kernel(
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
FlashInferExperts( FlashInferExperts(
out_dtype=layer.params_dtype, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
tp_rank=moe_config.moe_parallel_config.tp_rank,
tp_size=moe_config.moe_parallel_config.tp_size,
ep_rank=moe_config.moe_parallel_config.ep_rank,
ep_size=moe_config.moe_parallel_config.ep_size,
), ),
) )
use_inplace = False use_inplace = False
...@@ -157,13 +152,19 @@ def make_unquantized_moe_kernel( ...@@ -157,13 +152,19 @@ def make_unquantized_moe_kernel(
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
AiterExperts(quant_config), AiterExperts(
moe_config=moe_config,
quant_config=quant_config,
),
) )
elif backend == UnquantizedMoeBackend.TRITON: elif backend == UnquantizedMoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe import TritonExperts
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
TritonExperts(quant_config), TritonExperts(
moe_config=moe_config,
quant_config=quant_config,
),
) )
return kernel, use_inplace return kernel, use_inplace
...@@ -9,11 +9,21 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -9,11 +9,21 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
class QuantMethod(IntEnum): class QuantMethod(IntEnum):
...@@ -269,17 +279,49 @@ def rocm_aiter_fused_experts( ...@@ -269,17 +279,49 @@ def rocm_aiter_fused_experts(
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config): @staticmethod
super().__init__(quant_config) def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def expects_unquantized_inputs(
fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
# AITER fused MoE kernels handle input quantization internally.
return True
@property @staticmethod
def activation_formats( def _supports_current_device() -> bool:
self, return rocm_aiter_ops.is_fused_moe_enabled()
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.Standard, def _supports_no_act_and_mul() -> bool:
mk.FusedMoEActivationFormat.Standard, return False
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# TODO(rob): AITER also supports MXFP4, which is not
# yet supported via an Oracle. Once it is, we will add
# MXFP4 to this list.
SUPPORTED_W_A = [
(None, None),
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_expert_map(self): def supports_expert_map(self):
return True return True
......
...@@ -34,6 +34,11 @@ class CustomRoutingRouter(BaseRouter): ...@@ -34,6 +34,11 @@ class CustomRoutingRouter(BaseRouter):
@property @property
def routing_method_type(self) -> RoutingMethodType: def routing_method_type(self) -> RoutingMethodType:
from vllm.model_executor.models.llama4 import Llama4MoE
# NOTE: FLASHINFER_TRTLLM support the Llama4 router.
if self.custom_routing_function == Llama4MoE.custom_routing_function:
return RoutingMethodType.Llama4
return RoutingMethodType.Custom return RoutingMethodType.Custom
def _compute_routing( def _compute_routing(
......
...@@ -261,7 +261,6 @@ class GroupedTopKRouter(BaseRouter): ...@@ -261,7 +261,6 @@ class GroupedTopKRouter(BaseRouter):
num_fused_shared_experts: int = 0, num_fused_shared_experts: int = 0,
enable_eplb: bool = False, enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None,
routing_method_type: RoutingMethodType | None = None,
): ):
super().__init__( super().__init__(
top_k=top_k, top_k=top_k,
...@@ -278,13 +277,12 @@ class GroupedTopKRouter(BaseRouter): ...@@ -278,13 +277,12 @@ class GroupedTopKRouter(BaseRouter):
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
self.num_fused_shared_experts = num_fused_shared_experts self.num_fused_shared_experts = num_fused_shared_experts
# Determine routing method type if scoring_func == "sigmoid":
if routing_method_type is not None:
self._routing_method_type = routing_method_type
elif scoring_func == "sigmoid":
self._routing_method_type = RoutingMethodType.DeepSeekV3 self._routing_method_type = RoutingMethodType.DeepSeekV3
else: else:
self._routing_method_type = RoutingMethodType.TopK # NOTE: this prohibits the FLASHINFER_TRTLLM kernels from
# being selected, since they only support DeepSeek-style.
self._routing_method_type = RoutingMethodType.Unspecified
@property @property
def routing_method_type(self) -> RoutingMethodType: def routing_method_type(self) -> RoutingMethodType:
......
...@@ -6,7 +6,6 @@ import torch ...@@ -6,7 +6,6 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import ( from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
CustomRoutingRouter, CustomRoutingRouter,
...@@ -36,7 +35,6 @@ def create_fused_moe_router( ...@@ -36,7 +35,6 @@ def create_fused_moe_router(
global_num_experts: int, global_num_experts: int,
renormalize: bool = True, renormalize: bool = True,
indices_type_getter: Callable[[], torch.dtype | None] | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None,
routing_method_type: RoutingMethodType | None = None,
# grouped topk parameters # grouped topk parameters
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: int | None = None, num_expert_group: int | None = None,
...@@ -128,7 +126,6 @@ def create_fused_moe_router( ...@@ -128,7 +126,6 @@ def create_fused_moe_router(
num_fused_shared_experts=num_fused_shared_experts, num_fused_shared_experts=num_fused_shared_experts,
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
routing_method_type=routing_method_type,
) )
router.capture = capture router.capture = capture
return router return router
......
...@@ -5,7 +5,10 @@ ...@@ -5,7 +5,10 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
...@@ -17,19 +20,22 @@ class TritonOrCutlassExperts(FallbackExperts): ...@@ -17,19 +20,22 @@ class TritonOrCutlassExperts(FallbackExperts):
def __init__( def __init__(
self, self,
e: int, moe_config: FusedMoEConfig,
n: int,
k: int,
out_dtype: torch.dtype | None,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
device: torch.dtype,
): ):
self.is_sm100 = current_platform.has_device_capability(100) self.is_sm100 = current_platform.has_device_capability(100)
super().__init__( super().__init__(
experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device), experts=CutlassExpertsFp8(moe_config, quant_config),
fallback_experts=TritonExperts(quant_config), fallback_experts=TritonExperts(moe_config, quant_config),
) )
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
]:
return (CutlassExpertsFp8, TritonExperts)
def workspace_shapes( def workspace_shapes(
self, self,
M: int, M: int,
......
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, DeepGemmExperts,
_valid_deep_gemm, _valid_deep_gemm,
...@@ -20,12 +23,19 @@ from vllm.utils.deep_gemm import ( ...@@ -20,12 +23,19 @@ from vllm.utils.deep_gemm import (
class TritonOrDeepGemmExperts(FallbackExperts): class TritonOrDeepGemmExperts(FallbackExperts):
"""DeepGemm with fallback to Triton for low latency shapes.""" """DeepGemm with fallback to Triton for low latency shapes."""
def __init__(self, quant_config: FusedMoEQuantConfig): def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
super().__init__( super().__init__(
experts=DeepGemmExperts(quant_config), experts=DeepGemmExperts(moe_config, quant_config),
fallback_experts=TritonExperts(quant_config), fallback_experts=TritonExperts(moe_config, quant_config),
) )
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
]:
return (DeepGemmExperts, TritonExperts)
def workspace_shapes( def workspace_shapes(
self, self,
M: int, M: int,
......
...@@ -6,37 +6,73 @@ import torch ...@@ -6,37 +6,73 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
moe: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
gemm1_alpha, gemm1_alpha,
gemm1_beta, gemm1_beta,
gemm1_clamp_limit, gemm1_clamp_limit,
max_capture_size, max_capture_size,
): ):
super().__init__(quant_config) super().__init__(moe_config, quant_config)
self.moe = moe
self.gemm1_alpha = gemm1_alpha self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit self.gemm1_clamp_limit = gemm1_clamp_limit
self.max_capture_size = max_capture_size self.max_capture_size = max_capture_size
@property @staticmethod
def activation_formats( def activation_format() -> mk.FusedMoEActivationFormat:
self, return mk.FusedMoEActivationFormat.Standard
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return ( @staticmethod
mk.FusedMoEActivationFormat.Standard, def _supports_current_device() -> bool:
mk.FusedMoEActivationFormat.Standard, raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_activation(activation: str) -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
)
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
raise NotImplementedError(
"TrtLlmGenExperts is not yet used by an Oracle. "
"This method should not be called."
) )
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
...@@ -86,7 +122,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -86,7 +122,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk = topk_ids.size(-1) topk = topk_ids.size(-1)
local_num_experts = w1.size(0) local_num_experts = w1.size(0)
intermediate_size = w2.size(1) intermediate_size = w2.size(1)
local_expert_offset = self.moe.ep_rank * local_num_experts local_expert_offset = self.moe_config.ep_rank * local_num_experts
x_quant = hidden_states x_quant = hidden_states
x_scale = a1q_scale x_scale = a1q_scale
......
...@@ -96,13 +96,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -96,13 +96,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
): ):
logger.debug("BatchedTritonExperts %s", self.moe) logger.debug("BatchedTritonExperts %s", self.moe)
return BatchedTritonExperts( return BatchedTritonExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
max_num_tokens=self.moe.max_num_tokens, max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
) )
else: else:
logger.debug("TritonExperts %s", self.moe) logger.debug("TritonExperts %s", self.moe)
return TritonExperts(self.moe_quant_config) return TritonExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
)
def create_weights( def create_weights(
self, self,
...@@ -192,7 +196,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -192,7 +196,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
self.kernel, self.use_inplace = make_unquantized_moe_kernel( self.kernel, self.use_inplace = make_unquantized_moe_kernel(
layer=layer,
backend=self.unquantized_backend, backend=self.unquantized_backend,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
......
...@@ -739,6 +739,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -739,6 +739,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
return BatchedMarlinExperts( return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx, w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx, w2_g_idx=w2_g_idx,
...@@ -749,6 +750,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -749,6 +750,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
else: else:
# Standard Marlin experts for AWQ # Standard Marlin experts for AWQ
return MarlinExperts( return MarlinExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx, w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx, w2_g_idx=w2_g_idx,
......
...@@ -875,6 +875,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -875,6 +875,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
return BatchedMarlinExperts( return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx, w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx, w2_g_idx=w2_g_idx,
...@@ -885,6 +886,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -885,6 +886,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
else: else:
# Standard Marlin experts for GPTQ # Standard Marlin experts for GPTQ
return MarlinExperts( return MarlinExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx, w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx, w2_g_idx=w2_g_idx,
......
...@@ -853,6 +853,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -853,6 +853,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
moe_config=self.moe,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
...@@ -875,11 +876,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -875,11 +876,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
} }
return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
elif self.mxfp4_backend == Mxfp4Backend.MARLIN: elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
return MarlinExperts(self.moe_quant_config) return MarlinExperts(self.moe, self.moe_quant_config)
elif self.mxfp4_backend == Mxfp4Backend.TRITON: elif self.mxfp4_backend == Mxfp4Backend.TRITON:
if self.moe.is_lora_enabled: if self.moe.is_lora_enabled:
return UnfusedOAITritonExperts(self.moe_quant_config) return UnfusedOAITritonExperts(self.moe, self.moe_quant_config)
return OAITritonExperts(self.moe_quant_config) return OAITritonExperts(self.moe, self.moe_quant_config)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP" f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
......
...@@ -9,10 +9,6 @@ from vllm import envs ...@@ -9,10 +9,6 @@ from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize, create_flashinfer_prepare_finalize,
...@@ -203,33 +199,6 @@ def build_flashinfer_fp8_cutlass_moe_prepare_finalize( ...@@ -203,33 +199,6 @@ def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
) )
def select_cutlass_fp8_gemm_impl(
moe: FusedMoEConfig | None,
quant_config: FusedMoEQuantConfig,
out_dtype: torch.dtype | None = None,
use_deepseek_fp8_block_scale: bool = False,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for fused-MoE layers"""
if moe is not None:
return FlashInferExperts(
out_dtype=moe.in_dtype,
quant_config=quant_config,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
return FlashInferExperts(
out_dtype=out_dtype,
quant_config=quant_config,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
def get_flashinfer_moe_backend() -> FlashinferMoeBackend: def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = { backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS, "throughput": FlashinferMoeBackend.CUTLASS,
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment