Unverified Commit c0569dbc authored by Varun Sundar Rabindranath, committed by GitHub
Browse files

[Misc] ModularKernel : Perform WeightAndReduce inside TritonExperts & DeepGemmExperts (#20725)


Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
parent 8bb43b9c
...@@ -260,6 +260,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -260,6 +260,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
...@@ -273,6 +274,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -273,6 +274,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
): ):
assert expert_tokens_meta is not None assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens expert_num_tokens = expert_tokens_meta.expert_num_tokens
......
...@@ -129,30 +129,22 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -129,30 +129,22 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return self.batched_triton_experts.workspace_shapes( return self.batched_triton_experts.workspace_shapes(
a, aq, M, N, K, topk, global_num_experts, local_num_experts) a, aq, M, N, K, topk, global_num_experts, local_num_experts)
def apply( def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
self, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
output: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
): apply_router_weight_on_input: bool):
experts = (self.batched_deep_gemm_experts experts = (self.batched_deep_gemm_experts
if self.allow_deep_gemm else self.batched_triton_experts) if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None assert experts is not None
experts.apply(output, hidden_states, w1, w2, topk_ids, activation, experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
global_num_experts, expert_map, w1_scale, w2_scale, activation, global_num_experts, expert_map, w1_scale,
w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta) workspace2, expert_tokens_meta,
apply_router_weight_on_input)
...@@ -291,26 +291,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -291,26 +291,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
return (workspace1, workspace2, output, return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype) self.out_dtype if self.out_dtype is not None else a.dtype)
def apply( def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
self, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
output: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
): apply_router_weight_on_input: bool):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
......
...@@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( ...@@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate) TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
...@@ -90,8 +90,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -90,8 +90,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return True return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceNoOP()
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
...@@ -104,9 +103,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -104,9 +103,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
block_m = self.block_shape[0] block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m) M_sum = round_up(M_sum, block_m)
workspace1 = (M_sum, max(N * 2, K)) workspace1 = (M_sum, max(N // 2, K))
workspace2 = (M_sum, max(N, K)) workspace2 = (M_sum, max(N, K))
output = (M, topk, K) output = (M, K)
return (workspace1, workspace2, output, a.dtype) return (workspace1, workspace2, output, a.dtype)
def apply( def apply(
...@@ -115,6 +114,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -115,6 +114,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
...@@ -128,11 +128,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -128,11 +128,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
): ):
assert self.block_shape is not None assert self.block_shape is not None
a1q = hidden_states a1q = hidden_states
_, N, K = w1.size() _, N, K = w1.size()
M, _ = output.size()
num_topk = topk_ids.size(1)
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = w1.size(0) global_num_experts = w1.size(0)
...@@ -159,11 +162,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -159,11 +162,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: M_sum is different than the pre-permuted shape of a1q. # Note: M_sum is different than the pre-permuted shape of a1q.
M_sum = a1q.size(0) M_sum = a1q.size(0)
mm1_out = _resize_cache(workspace13, (M_sum, N)) mm1_out = _resize_cache(workspace2, (M_sum, N))
act_out = _resize_cache(workspace2, (M_sum, N // 2)) act_out = _resize_cache(workspace13, (M_sum, N // 2))
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
(M_sum, N // 2)) (M_sum, N // 2))
mm2_out = _resize_cache(workspace2, (M_sum, K)) mm2_out = _resize_cache(workspace13, (M_sum, K))
perm_out = _resize_cache(workspace2, (M * num_topk, K))
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale), m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
mm1_out, expert_ids) mm1_out, expert_ids)
...@@ -179,7 +183,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -179,7 +183,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale), m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
mm2_out, expert_ids) mm2_out, expert_ids)
torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K))) torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
TopKWeightAndReduceContiguous().apply(
output=output,
fused_expert_output=perm_out,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input)
def deep_gemm_moe_fp8( def deep_gemm_moe_fp8(
......
...@@ -696,15 +696,16 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -696,15 +696,16 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
return t.to(f32) * group_broadcast(scale, t.shape) return t.to(f32) * group_broadcast(scale, t.shape)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
activation: str, global_num_experts: int, topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]): expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool):
assert hidden_states.dim() == 3 assert hidden_states.dim() == 3
assert expert_tokens_meta is not None assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens expert_num_tokens = expert_tokens_meta.expert_num_tokens
...@@ -899,15 +900,16 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -899,15 +900,16 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (workspace13, workspace2, output, a.dtype) return (workspace13, workspace2, output, a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
activation: str, global_num_experts: int, topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]): expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool):
# Check constraints. # Check constraints.
if self.use_int4_w4a16: if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), ( assert hidden_states.size(-1) // 2 == w1.size(2), (
......
...@@ -26,7 +26,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ...@@ -26,7 +26,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate) TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input) _resize_cache, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
...@@ -1606,8 +1606,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1606,8 +1606,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
return True return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceNoOP()
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, self,
...@@ -1620,9 +1619,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1620,9 +1619,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M, topk, max(N * 2, K)) workspace1 = (M, topk, max(N // 2, K))
workspace2 = (M, topk, N) workspace2 = (M, topk, max(N, K))
output = (M, topk, K) output = (M, K)
return (workspace1, workspace2, output, a.dtype) return (workspace1, workspace2, output, a.dtype)
def apply( def apply(
...@@ -1631,6 +1630,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1631,6 +1630,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
...@@ -1644,6 +1644,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1644,6 +1644,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
): ):
# Check constraints. # Check constraints.
if self.use_int4_w4a16: if self.use_int4_w4a16:
...@@ -1696,28 +1697,30 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1696,28 +1697,30 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
raise ValueError( raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}") f"Unsupported compute_type: {hidden_states.dtype}")
# We can reuse the memory between these because by the time we need # Note that the output tensor might be in workspace1
# cache3, we're done with cache1 intermediate_cache1 = _resize_cache(workspace2,
intermediate_cache1 = _resize_cache(workspace13,
(num_tokens, top_k_num, N)) (num_tokens, top_k_num, N))
intermediate_cache2 = _resize_cache(workspace2, intermediate_cache2 = _resize_cache(workspace13,
(num_tokens * top_k_num, N // 2)) (num_tokens * top_k_num, N // 2))
intermediate_cache3 = _resize_cache(workspace2,
(num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = ( sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map)) global_num_experts, expert_map))
invoke_fused_moe_kernel(hidden_states, invoke_fused_moe_kernel(
hidden_states,
w1, w1,
intermediate_cache1, intermediate_cache1,
a1q_scale, a1q_scale,
w1_scale, w1_scale,
w1_zp, w1_zp,
None, None, # topk_weights
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
False, False, # mul_routed_weights
top_k_num, top_k_num,
config, config,
compute_type=compute_type, compute_type=compute_type,
...@@ -1739,15 +1742,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1739,15 +1742,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
invoke_fused_moe_kernel(qintermediate_cache2, invoke_fused_moe_kernel(qintermediate_cache2,
w2, w2,
output, intermediate_cache3,
a2q_scale, a2q_scale,
w2_scale, w2_scale,
w2_zp, w2_zp,
None, topk_weights,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
False, not apply_router_weight_on_input,
1, 1,
config, config,
compute_type=compute_type, compute_type=compute_type,
...@@ -1758,6 +1761,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1758,6 +1761,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=self.per_act_token_quant, per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape) block_shape=self.block_shape)
ops.moe_sum(intermediate_cache3, output)
def modular_triton_fused_moe( def modular_triton_fused_moe(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
......
...@@ -360,6 +360,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -360,6 +360,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
...@@ -373,6 +374,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -373,6 +374,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata], expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
): ):
""" """
This function computes the intermediate result of a Mixture of Experts This function computes the intermediate result of a Mixture of Experts
...@@ -384,6 +386,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -384,6 +386,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
layer. layer.
- w1 (torch.Tensor): The first set of expert weights. - w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights. - w2 (torch.Tensor): The second set of expert weights.
- topk_weights: A map of row to expert weights. Some implementations
choose to do weight application.
- topk_ids (torch.Tensor): A map of row to expert id. - topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first - activation (str): The activation function to apply after the first
MoE layer. MoE layer.
...@@ -409,6 +413,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -409,6 +413,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
ExpertTokensMetadata object containing gpu/cpu tensors ExpertTokensMetadata object containing gpu/cpu tensors
as big as the number of local experts with the information about the as big as the number of local experts with the information about the
number of tokens assigned to each local expert. number of tokens assigned to each local expert.
- apply_router_weight_on_input: True if router weights are already
applied on the input. This is relevant if the implementation
chooses to do weight application.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -452,17 +459,21 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -452,17 +459,21 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.__class__.__name__}." f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}") f"{fused_experts.activation_formats[0]}")
def _do_fused_experts( def _do_fused_experts(self, fused_out: Optional[torch.Tensor],
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor, a1: torch.Tensor, a1q: torch.Tensor,
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
local_num_experts: int, expert_map: Optional[torch.Tensor], activation: str, global_num_experts: int,
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], local_num_experts: int,
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata] expert_tokens_meta: Optional[ExpertTokensMetadata],
) -> torch.Tensor: apply_router_weight_on_input: bool) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
...@@ -485,10 +496,12 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -485,10 +496,12 @@ class FusedMoEModularKernel(torch.nn.Module):
# reuse workspace13 for the output # reuse workspace13 for the output
fused_out = _resize_cache(workspace13, fused_out_shape) fused_out = _resize_cache(workspace13, fused_out_shape)
self.fused_experts.apply(fused_out, self.fused_experts.apply(
fused_out,
a1q, a1q,
w1, w1,
w2, w2,
topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -501,20 +514,31 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -501,20 +514,31 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale, a2_scale=a2_scale,
workspace13=workspace13, workspace13=workspace13,
workspace2=workspace2, workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta) expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
return fused_out return fused_out
def _maybe_chunk_fused_experts( def _maybe_chunk_fused_experts(
self, a1: torch.Tensor, a1q: torch.Tensor, w1: torch.Tensor, self,
w2: torch.Tensor, topk_ids: torch.Tensor, activation: str, a1: torch.Tensor,
global_num_experts: int, local_num_experts: int, a1q: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata] expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
) -> torch.Tensor: ) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
...@@ -529,6 +553,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -529,6 +553,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q=a1q, a1q=a1q,
w1=w1, w1=w1,
w2=w2, w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -540,7 +565,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -540,7 +565,8 @@ class FusedMoEModularKernel(torch.nn.Module):
w2_zp=w2_zp, w2_zp=w2_zp,
a1q_scale=a1q_scale, a1q_scale=a1q_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta) expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
# Chunking required case # Chunking required case
assert num_chunks > 1 assert num_chunks > 1
...@@ -557,11 +583,12 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -557,11 +583,12 @@ class FusedMoEModularKernel(torch.nn.Module):
def slice_input_tensors( def slice_input_tensors(
chunk_idx: int chunk_idx: int
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], torch.Tensor]: Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
s = chunk_idx * CHUNK_SIZE s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M) e = min(s + CHUNK_SIZE, M)
return (a1q[s:e], _chunk_scales(a1q_scale, s, e), return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
_chunk_scales(a2_scale, s, e), topk_ids[s:e]) _chunk_scales(a2_scale, s,
e), topk_ids[s:e], topk_weights[s:e])
def slice_output_tensor(chunk_idx: int) -> torch.Tensor: def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
assert fused_out.size(0) % M == 0, ( assert fused_out.size(0) % M == 0, (
...@@ -594,7 +621,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -594,7 +621,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens_cpu=c_expert_num_tokens_cpu) expert_num_tokens_cpu=c_expert_num_tokens_cpu)
for chunk_idx in range(num_chunks): for chunk_idx in range(num_chunks):
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids = ( c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
slice_input_tensors(chunk_idx)) slice_input_tensors(chunk_idx))
c_expert_tokens_meta = None c_expert_tokens_meta = None
...@@ -603,11 +630,13 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -603,11 +630,13 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta, c_topk_ids, local_num_experts, expert_tokens_meta, c_topk_ids, local_num_experts,
expert_map) expert_map)
self._do_fused_experts(fused_out=slice_output_tensor(chunk_idx), self._do_fused_experts(
fused_out=slice_output_tensor(chunk_idx),
a1=a1, a1=a1,
a1q=c_a1q, a1q=c_a1q,
w1=w1, w1=w1,
w2=w2, w2=w2,
topk_weights=c_topk_weights,
topk_ids=c_topk_ids, topk_ids=c_topk_ids,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -619,7 +648,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -619,7 +648,8 @@ class FusedMoEModularKernel(torch.nn.Module):
w2_zp=w2_zp, w2_zp=w2_zp,
a1q_scale=c_a1q_scale, a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale, a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta) expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
return fused_out return fused_out
...@@ -719,6 +749,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -719,6 +749,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q=a1q, a1q=a1q,
w1=w1, w1=w1,
w2=w2, w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -730,7 +761,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -730,7 +761,8 @@ class FusedMoEModularKernel(torch.nn.Module):
w2_zp=w2_zp, w2_zp=w2_zp,
a1q_scale=a1q_scale, a1q_scale=a1q_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta) expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
self.prepare_finalize.finalize( self.prepare_finalize.finalize(
output, fused_out, topk_weights, topk_ids, output, fused_out, topk_weights, topk_ids,
......
...@@ -48,12 +48,19 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): ...@@ -48,12 +48,19 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor: apply_router_weight_on_input: bool) -> torch.Tensor:
# Relax this if an explicit copy is necessary. Note that, # Weight application and reduction operations are already done.
# if a copy is employed we have to make sure that the if output is None:
# tensors don't overlap
assert output is None
return fused_expert_output return fused_expert_output
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
# tensor.
assert output.size() == fused_expert_output.size(), (
"output shape is expected to match the fused_expert_output shape. "
f"But got output={output.size()}, "
f"used_expert_output={fused_expert_output.size()}")
output.copy_(fused_expert_output, non_blocking=True)
return output
class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce): class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
""" """
......
...@@ -122,6 +122,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -122,6 +122,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation: str,
global_num_experts: int, global_num_experts: int,
...@@ -135,6 +136,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -135,6 +136,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
): ):
use_deep_gemm = (self.allow_deep_gemm use_deep_gemm = (self.allow_deep_gemm
and (_valid_deep_gemm(hidden_states, w1, w2) and (_valid_deep_gemm(hidden_states, w1, w2)
...@@ -148,6 +150,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -148,6 +150,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states, hidden_states,
w1, w1,
w2, w2,
topk_weights,
topk_ids, topk_ids,
activation, activation,
global_num_experts, global_num_experts,
...@@ -161,4 +164,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -161,4 +164,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13, workspace13,
workspace2, workspace2,
expert_tokens_meta, expert_tokens_meta,
apply_router_weight_on_input,
) )
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment