Unverified Commit da364615 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernels] Modular kernel refactor (#24812)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent f08919b7
......@@ -83,8 +83,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -92,7 +90,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
......@@ -101,8 +99,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
a,
aq,
M,
N,
K,
......@@ -113,8 +109,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
else:
return self.triton_expert.workspace_shapes(
a,
aq,
M,
N,
K,
......
......@@ -52,8 +52,6 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
......@@ -61,14 +59,12 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer.
# TODO(varun) : workspace1 is could be used as the output tensor. This
# is error-prone. Allow the `workspace_shapes` to return None workspaces
workspace1 = (M, K)
workspace2 = (0, 0)
workspace1 = (0,)
workspace2 = (0,)
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
return (workspace1, workspace2, output)
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int):
# Number of tokens in the input tensor.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment