Unverified Commit e3b12667 authored by Varun Sundar Rabindranath, committed by GitHub
Browse files

[BugFix] : Fix Batched DeepGemm Experts (#19515)


Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
parent e6aab5de
...@@ -47,15 +47,21 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -47,15 +47,21 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int, N: int,
K: int, K: int,
topk: int, topk: int,
num_experts: int, global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2 assert a.dim() == 2
num_dp = self.world_size // self.dp_size # FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.world_size
num_experts = local_num_experts
max_num_tokens = a.size( max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens 0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) workspace13 = (num_experts, max_num_tokens * num_dispatchers,
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) max(K, N))
output = (num_experts, max_num_tokens * num_dp, K) workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype) return (workspace13, workspace2, output, a.dtype)
def apply( def apply(
...@@ -84,9 +90,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -84,9 +90,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a1q = hidden_states a1q = hidden_states
_, N, K = w1.size() _, N, K = w1.size()
if global_num_experts == -1:
global_num_experts = w1.size(0)
assert w2.size(1) == K assert w2.size(1) == K
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
......
...@@ -81,18 +81,19 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -81,18 +81,19 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int, N: int,
K: int, K: int,
topk: int, topk: int,
num_experts: int, global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton # Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set. # even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None: if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
return self.batched_deep_gemm_experts.workspace_shapes( return self.batched_deep_gemm_experts.workspace_shapes(
a, aq, M, N, K, topk, num_experts) a, aq, M, N, K, topk, global_num_experts, local_num_experts)
else: else:
assert self.batched_triton_experts is not None assert self.batched_triton_experts is not None
return self.batched_triton_experts.workspace_shapes( return self.batched_triton_experts.workspace_shapes(
a, aq, M, N, K, topk, num_experts) a, aq, M, N, K, topk, global_num_experts, local_num_experts)
def apply( def apply(
self, self,
......
...@@ -230,7 +230,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -230,7 +230,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
N: int, N: int,
K: int, K: int,
topk: int, topk: int,
num_experts: int, global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1: tuple[int, ...] = () workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = () workspace2: tuple[int, ...] = ()
......
...@@ -74,15 +74,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -74,15 +74,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return True return True
def workspace_shapes( def workspace_shapes(
self, self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
a: torch.Tensor, topk: int, global_num_experts: int, local_num_experts: int
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
num_experts = global_num_experts
block_m = self.block_shape[0] block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m) M_sum = round_up(M_sum, block_m)
......
...@@ -521,10 +521,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -521,10 +521,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int, N: int,
K: int, K: int,
topk: int, topk: int,
num_experts: int, global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2 assert a.dim() == 2
num_dp = self.world_size // self.dp_size num_dp = self.dp_size
num_experts = local_num_experts
workspace13 = (num_experts, self.max_num_tokens * num_dp, K) workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
workspace2 = (self.max_num_tokens * num_dp, N) workspace2 = (self.max_num_tokens * num_dp, N)
return (workspace13, workspace2, workspace13, a.dtype) return (workspace13, workspace2, workspace13, a.dtype)
...@@ -624,10 +626,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -624,10 +626,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int, N: int,
K: int, K: int,
topk: int, topk: int,
num_experts: int, global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2 assert a.dim() == 2
num_dp = self.world_size // self.dp_size num_dp = self.world_size // self.dp_size
num_experts = local_num_experts
max_num_tokens = a.size( max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens 0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
......
...@@ -1553,7 +1553,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1553,7 +1553,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int, N: int,
K: int, K: int,
topk: int, topk: int,
num_experts: int, global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M, topk, max(N * 2, K)) workspace1 = (M, topk, max(N * 2, K))
workspace2 = (M, topk, N) workspace2 = (M, topk, N)
......
...@@ -194,7 +194,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -194,7 +194,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
N: int, N: int,
K: int, K: int,
topk: int, topk: int,
num_experts: int, global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
""" """
Compute the shapes for the temporary and final outputs of the two gemms Compute the shapes for the temporary and final outputs of the two gemms
...@@ -372,8 +373,9 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -372,8 +373,9 @@ class FusedMoEModularKernel(torch.nn.Module):
a1 = hidden_states a1 = hidden_states
output = a1 if inplace else torch.zeros_like(a1) output = a1 if inplace else torch.zeros_like(a1)
local_num_experts = w1.size(0)
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = w1.size(0) global_num_experts = local_num_experts
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare( _expert_topk_weights) = self.prepare_finalize.prepare(
...@@ -408,16 +410,19 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -408,16 +410,19 @@ class FusedMoEModularKernel(torch.nn.Module):
if num_chunks == 1: if num_chunks == 1:
(workspace13_shape, workspace2_shape, fused_out_shape, (workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.fused_experts.workspace_shapes( workspace_dtype) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts) a1, a1q, M, N, K, top_k, global_num_experts,
local_num_experts)
else: else:
# Use the full M to get the final output shape. # Use the full M to get the final output shape.
_, _, fused_out_shape, _ = ( _, _, fused_out_shape, _ = (
self.fused_experts.workspace_shapes( self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts)) a1, a1q, M, N, K, top_k, global_num_experts,
local_num_experts))
# Use the CHUNK_SIZE to get the workspace shapes. # Use the CHUNK_SIZE to get the workspace shapes.
workspace13_shape, workspace2_shape, _, workspace_dtype = ( workspace13_shape, workspace2_shape, _, workspace_dtype = (
self.fused_experts.workspace_shapes( self.fused_experts.workspace_shapes(
a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts)) a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
local_num_experts))
# We can reuse the memory between cache1 and cache3 because by the # We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1. # time we need cache3, we're done with cache1.
......
...@@ -159,6 +159,12 @@ def moe_align_block_size( ...@@ -159,6 +159,12 @@ def moe_align_block_size(
Aligns the token distribution across experts to be compatible with block Aligns the token distribution across experts to be compatible with block
size for matrix multiplication. size for matrix multiplication.
Note: In the case of expert_parallel, moe_align_block_size initially
considers all experts as valid and aligns all tokens appropriately.
Before the function returns it marks the experts_ids that are not in
the current GPU rank as -1 so the MoE matmuls could skip those blocks.
This requires the num_experts input arg to be the num global experts.
Parameters: Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the - topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token. top-k expert indices for each token.
......
...@@ -48,7 +48,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -48,7 +48,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
N: int, N: int,
K: int, K: int,
topk: int, topk: int,
num_experts: int, global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton # Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
...@@ -56,10 +57,11 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -56,10 +57,11 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
assert self.deep_gemm_expert is not None assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes( return self.deep_gemm_expert.workspace_shapes(
a, aq, M, N, K, topk, num_experts) a, aq, M, N, K, topk, global_num_experts, local_num_experts)
else: else:
return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk, return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
num_experts) global_num_experts,
local_num_experts)
def apply( def apply(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment