Unverified Commit 5a93b916 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[MoE Refactor] Integrate Naive Prepare Finalize into MK (#32567)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Signed-off-by: default avatarAmir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avataramirkl94 <203507526+amirkl94@users.noreply.github.com>
parent 6d86fde0
...@@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -53,7 +52,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( ...@@ -53,7 +52,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
...@@ -679,15 +677,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -679,15 +677,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_vllm_cutlass=False, allow_vllm_cutlass=False,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def create_weights( def create_weights(
self, self,
layer: Module, layer: Module,
...@@ -813,7 +802,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -813,7 +802,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel( def _setup_kernel(
self, self,
layer: Module, layer: FusedMoE,
w13: torch.Tensor, w13: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
...@@ -845,16 +834,15 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -845,16 +834,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and ( if self.moe_quant_config:
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel( self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
...@@ -909,33 +897,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -909,33 +897,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: raise ValueError(
return None f"{self.__class__.__name__} uses the new modular kernel initialization "
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: "logic. This function should not be called."
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
) )
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None raise ValueError(
assert self.experts_cls is not None f"{self.__class__.__name__} uses the new modular kernel initialization "
return make_fp8_moe_kernel_for_mkm( "logic. This function should not be called."
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
...@@ -1037,9 +1011,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1037,9 +1011,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None assert self.moe_mk is not None
assert not self.is_monolithic assert not self.is_monolithic
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
......
...@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend, Fp8MoeBackend,
convert_to_fp8_moe_kernel_format, convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel, make_fp8_moe_kernel,
make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config, make_fp8_moe_quant_config,
select_fp8_moe_backend, select_fp8_moe_backend,
) )
...@@ -35,7 +34,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( ...@@ -35,7 +34,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
convert_to_nvfp4_moe_kernel_format, convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend, is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel, make_nvfp4_moe_kernel,
make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config, make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend, select_nvfp4_moe_backend,
) )
...@@ -54,13 +52,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( ...@@ -54,13 +52,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
) )
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe, flashinfer_trtllm_fp4_routed_moe,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
...@@ -739,47 +735,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -739,47 +735,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
activation_key=kFp8StaticTensorSym, activation_key=kFp8StaticTensorSym,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
# TRT LLM not supported with all2all yet. raise ValueError(
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: f"{self.__class__.__name__} uses the new modular kernel initialization "
return None "logic. This function should not be called."
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
) )
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None raise ValueError(
assert self.experts_cls is not None f"{self.__class__.__name__} uses the new modular kernel initialization "
return make_fp8_moe_kernel_for_mkm( "logic. This function should not be called."
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def create_weights( def create_weights(
...@@ -863,7 +835,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -863,7 +835,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel( def _setup_kernel(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
w13: torch.Tensor, w13: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w13_scale: torch.Tensor, w13_scale: torch.Tensor,
...@@ -893,11 +865,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -893,11 +865,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config: if self.moe_quant_config:
assert self.experts_cls is not None assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel( self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
fp8_backend=self.fp8_backend, fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
...@@ -998,8 +972,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -998,8 +972,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
f"but got {layer.activation}" f"but got {layer.activation}"
) )
assert self.kernel is not None assert self.moe_mk is not None
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -1379,50 +1353,27 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1379,50 +1353,27 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation_key=kNvfp4Dynamic, activation_key=kNvfp4Dynamic,
) )
# Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend self.nvfp4_backend
) )
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.kernel is not None:
return self.kernel.prepare_finalize.topk_indices_dtype()
return None
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None: ) -> mk.FusedMoEPrepareAndFinalize | None:
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: raise ValueError(
return None f"{self.__class__.__name__} uses the new modular kernel initialization "
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: "logic. This function should not be called."
# For no-EP case, don't use the MKM framework.
if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
self.moe
) )
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
self, self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize, prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None raise ValueError(
assert self.experts_cls is not None f"{self.__class__.__name__} uses the new modular kernel initialization "
return make_nvfp4_moe_kernel_for_mkm( "logic. This function should not be called."
moe_config=self.moe,
quant_config=self.moe_quant_config,
experts_cls=self.experts_cls,
prepare_finalize=prepare_finalize,
) )
def uses_weight_scale_2_pattern(self) -> bool: def uses_weight_scale_2_pattern(self) -> bool:
...@@ -1547,7 +1498,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1547,7 +1498,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
) )
layer.register_parameter("w2_input_scale", w2_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: FusedMoE) -> None:
""" """
Convert NVFP4 MoE weights into kernel format and setup the kernel. Convert NVFP4 MoE weights into kernel format and setup the kernel.
""" """
...@@ -1599,15 +1550,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1599,15 +1550,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases. # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer) self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config and ( if self.moe_quant_config:
(not self.moe.moe_parallel_config.use_all2all_kernels)
or self.moe.moe_parallel_config.use_naive_all2all_kernels
):
assert self.experts_cls is not None assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel( self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config, moe_quant_config=self.moe_quant_config,
moe_config=self.moe, moe_config=self.moe,
experts_cls=self.experts_cls, experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
) )
@property @property
...@@ -1708,8 +1658,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1708,8 +1658,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
) )
else: else:
assert self.kernel is not None assert self.moe_mk is not None
return self.kernel( return self.moe_mk(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
......
...@@ -15,9 +15,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -15,9 +15,6 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig, FusedMoEParallelConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kNvfp4Dynamic, kNvfp4Dynamic,
...@@ -42,7 +39,6 @@ __all__ = [ ...@@ -42,7 +39,6 @@ __all__ = [
"is_flashinfer_fp4_cutlass_moe_available", "is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available", "is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1", "reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
] ]
# #
...@@ -163,17 +159,6 @@ def reorder_w1w3_to_w3w1( ...@@ -163,17 +159,6 @@ def reorder_w1w3_to_w3w1(
) )
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1
enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv"
return create_flashinfer_prepare_finalize(
use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv
)
def prepare_static_weights_for_trtllm_fp4_moe( def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant, # args_dequant,
# args, # args,
......
...@@ -4,15 +4,8 @@ from enum import Enum ...@@ -4,15 +4,8 @@ from enum import Enum
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
...@@ -163,18 +156,6 @@ def make_fp8_moe_alpha_scales_for_fi( ...@@ -163,18 +156,6 @@ def make_fp8_moe_alpha_scales_for_fi(
return g1_alphas, g2_alphas return g1_alphas, g2_alphas
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
# Propagate block-scale flag so prepare/finalize can skip act quantization
# and inform the kernel to consume per-block weight scales.
return create_flashinfer_prepare_finalize(
use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
)
def get_flashinfer_moe_backend() -> FlashinferMoeBackend: def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = { backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS, "throughput": FlashinferMoeBackend.CUTLASS,
......
...@@ -14,7 +14,6 @@ from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank ...@@ -14,7 +14,6 @@ from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
) )
...@@ -169,9 +168,10 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: ...@@ -169,9 +168,10 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
# modular kernels could invoke deep_gemm_moe_fp8 # modular kernels could invoke deep_gemm_moe_fp8
return True return True
mk: FusedMoEModularKernel = module.quant_method.fused_experts
# Further check if the ModularKernel implementation uses the DeepGemmExperts # Further check if the ModularKernel implementation uses the DeepGemmExperts
return isinstance(mk.fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts)) return isinstance(
module.quant_method.moe_mk, (DeepGemmExperts, TritonOrDeepGemmExperts)
)
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set() FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment