Unverified Commit 97995f63, authored by Robert Shaw, committed by GitHub
Browse files

[MoE Refactor] Create MK for TRTLLM Kernels (#32564)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
parent 881a6b01
...@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
""" """
Useful in the case when some FusedMoEPermuteExpertsUnpermute Useful in the case when some FusedMoEExpertsModular
implementation does not perform weight application and reduction implementation does not perform weight application and reduction
but cannot address the needs of all the compatible PrepareAndFinalize but cannot address the needs of all the compatible PrepareAndFinalize
implementations. implementations.
...@@ -62,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): ...@@ -62,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
if output is None: if output is None:
return fused_expert_output return fused_expert_output
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output` # MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output`
# tensor. # tensor.
assert output.size() == fused_expert_output.size(), ( assert output.size() == fused_expert_output.size(), (
"output shape is expected to match the fused_expert_output shape. " "output shape is expected to match the fused_expert_output shape. "
......
...@@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts): ...@@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts):
@staticmethod @staticmethod
def get_clses() -> tuple[ def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute], type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEPermuteExpertsUnpermute], type[mk.FusedMoEExpertsModular],
]: ]:
return (CutlassExpertsFp8, TritonExperts) return (CutlassExpertsFp8, TritonExperts)
...@@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts): ...@@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEExpertsModular:
# Small batch fallback for sm100. # Small batch fallback for sm100.
if self.is_sm100 and hidden_states.shape[0] <= 8: if self.is_sm100 and hidden_states.shape[0] <= 8:
return self.fallback_experts return self.fallback_experts
......
...@@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts): ...@@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts):
@staticmethod @staticmethod
def get_clses() -> tuple[ def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute], type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEPermuteExpertsUnpermute], type[mk.FusedMoEExpertsModular],
]: ]:
return (DeepGemmExperts, TritonExperts) return (DeepGemmExperts, TritonExperts)
...@@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts): ...@@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute: ) -> mk.FusedMoEExpertsModular:
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2): if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
return self.experts return self.experts
else: else:
......
This diff is collapsed.
...@@ -140,6 +140,7 @@ autotune = _lazy_import_wrapper( ...@@ -140,6 +140,7 @@ autotune = _lazy_import_wrapper(
"autotune", "autotune",
fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(), fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
) )
_is_fi_autotuning: bool = False
@functools.cache @functools.cache
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment