Unverified Commit da364615 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernels] Modular kernel refactor (#24812)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent f08919b7
...@@ -83,8 +83,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -83,8 +83,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -92,7 +90,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -92,7 +90,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton # Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set. # even if we fall back to triton later, e.g. if expert maps are set.
...@@ -101,8 +99,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -101,8 +99,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
): ):
assert self.deep_gemm_expert is not None assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes( return self.deep_gemm_expert.workspace_shapes(
a,
aq,
M, M,
N, N,
K, K,
...@@ -113,8 +109,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -113,8 +109,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
) )
else: else:
return self.triton_expert.workspace_shapes( return self.triton_expert.workspace_shapes(
a,
aq,
M, M,
N, N,
K, K,
......
...@@ -52,8 +52,6 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -52,8 +52,6 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
...@@ -61,14 +59,12 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -61,14 +59,12 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer. # The workspaces for this implementation are managed by flashinfer.
# TODO(varun) : workspace1 is could be used as the output tensor. This workspace1 = (0,)
# is error-prone. Allow the `workspace_shapes` to return None workspaces workspace2 = (0,)
workspace1 = (M, K)
workspace2 = (0, 0)
output = (M, K) output = (M, K)
return (workspace1, workspace2, output, a.dtype) return (workspace1, workspace2, output)
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int): def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int):
# Number of tokens in the input tensor. # Number of tokens in the input tensor.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment