Commit 97995f63 (unverified), authored by Robert Shaw, committed by GitHub
Browse files

[MoE Refactor] Create MK for TRTLLM Kernels (#32564)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
parent 881a6b01
......@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
"""
Useful in the case when some FusedMoEPermuteExpertsUnpermute
Useful in the case when some FusedMoEExpertsModular
implementation does not perform weight application and reduction
but cannot address the needs of all the compatible PrepareAndFinalize
implementations.
......@@ -62,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
if output is None:
return fused_expert_output
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
# MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output`
# tensor.
assert output.size() == fused_expert_output.size(), (
"output shape is expected to match the fused_expert_output shape. "
......
......@@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts):
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEExpertsModular],
]:
return (CutlassExpertsFp8, TritonExperts)
......@@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
# Small batch fallback for sm100.
if self.is_sm100 and hidden_states.shape[0] <= 8:
return self.fallback_experts
......
......@@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts):
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEExpertsModular],
]:
return (DeepGemmExperts, TritonExperts)
......@@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
return self.experts
else:
......
......@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
"""TensorRT-LLM-based fused MoE expert implementation."""
def __init__(
......
......@@ -23,7 +23,7 @@ if current_platform.is_xpu():
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
class XPUExperts(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,
......
......@@ -172,7 +172,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
# Further check if the ModularKernel implementation uses the DeepGemmExperts
return isinstance(
module.quant_method.moe_mk, (DeepGemmExperts, TritonOrDeepGemmExperts)
module.quant_method.moe_kernel, (DeepGemmExperts, TritonOrDeepGemmExperts)
)
......
This diff is collapsed.
......@@ -140,6 +140,7 @@ autotune = _lazy_import_wrapper(
"autotune",
fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
_is_fi_autotuning: bool = False
@functools.cache
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment