Unverified Commit 805d62ca authored by Varun Sundar Rabindranath's avatar Varun Sundar Rabindranath Committed by GitHub
Browse files

[Misc] DP : Add ExpertTokensMetadata (#20332)


Signed-off-by: default avatarVarun <vsundarr@redhat.com>
Signed-off-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: default avatarVarun <vsundarr@redhat.com>
parent b7d9e941
...@@ -260,8 +260,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -260,8 +260,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
): ):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
import deep_gemm as dg import deep_gemm as dg
assert hidden_states.ndim == 3 assert hidden_states.ndim == 3
assert self.block_shape is not None assert self.block_shape is not None
...@@ -287,7 +290,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -287,7 +290,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
masked_m=expert_num_tokens, masked_m=expert_num_tokens,
expected_m=expected_m) expected_m=expected_m)
assert expert_num_tokens is not None
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
expert_num_tokens) expert_num_tokens)
......
...@@ -129,7 +129,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -129,7 +129,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
): ):
experts = (self.batched_deep_gemm_experts experts = (self.batched_deep_gemm_experts
if self.allow_deep_gemm else self.batched_triton_experts) if self.allow_deep_gemm else self.batched_triton_experts)
...@@ -137,4 +137,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -137,4 +137,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
experts.apply(output, hidden_states, w1, w2, topk_ids, activation, experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
global_num_experts, expert_map, w1_scale, w2_scale, global_num_experts, expert_map, w1_scale, w2_scale,
w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_num_tokens) workspace2, expert_tokens_meta)
...@@ -303,11 +303,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -303,11 +303,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
): ):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens
activation_callable = lambda o, i: self.activation(activation, o, i) activation_callable = lambda o, i: self.activation(activation, o, i)
in_dtype = hidden_states.dtype in_dtype = hidden_states.dtype
run_cutlass_moe_fp8( run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable, output, hidden_states, w1, w2, topk_ids, activation_callable,
......
...@@ -119,7 +119,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -119,7 +119,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
): ):
import deep_gemm as dg import deep_gemm as dg
assert self.block_shape is not None assert self.block_shape is not None
......
...@@ -62,8 +62,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -62,8 +62,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
has_scales = token_scales is not None has_scales = token_scales is not None
(num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens, (num_tokens_per_rank, num_tokens_per_rdma_rank,
is_token_in_rank, event) = self.buffer.get_dispatch_layout( dispatch_expert_num_tokens, is_token_in_rank,
event) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids, topk_idx=rank_topk_ids,
num_experts=num_experts, num_experts=num_experts,
previous_event=None, previous_event=None,
...@@ -83,7 +84,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -83,7 +84,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank, is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=expert_num_tokens, num_tokens_per_expert=dispatch_expert_num_tokens,
topk_idx=rank_topk_ids, topk_idx=rank_topk_ids,
topk_weights=rank_topk_weights, topk_weights=rank_topk_weights,
# expert_alignment rounds the number of tokens per expert # expert_alignment rounds the number of tokens per expert
...@@ -115,7 +116,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -115,7 +116,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts - 1 if self.rank_expert_offset == 0 else 0, num_experts - 1 if self.rank_expert_offset == 0 else 0,
expert_topk_ids + self.rank_expert_offset) expert_topk_ids + self.rank_expert_offset)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, # Makes a GPU-CPU copy.
# TODO (varun): Maybe it is better to re-compute the expert_num_tokens
# on GPU.
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
expert_num_tokens_per_expert_list, device=expert_x.device)
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) expert_topk_weights)
def prepare( def prepare(
...@@ -129,8 +136,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -129,8 +136,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]: Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
...@@ -149,7 +157,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -149,7 +157,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
if a1q_scale is not None and a1q_scale.numel() == 1: if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1) a1q_scale = a1q_scale.view(1, 1)
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch( expert_topk_weights) = self._do_dispatch(
tokens=a1q, tokens=a1q,
token_scales=a1q_scale, token_scales=a1q_scale,
...@@ -159,7 +167,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -159,7 +167,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else: else:
# DeepEP kernels only support dispatching per-token-quant # DeepEP kernels only support dispatching per-token-quant
# quantization. dispatch in bfloat16. # quantization. dispatch in bfloat16.
(expert_x, _, expert_num_tokens, expert_topk_ids, (expert_x, _, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch( expert_topk_weights) = self._do_dispatch(
tokens=a1, tokens=a1,
token_scales=None, token_scales=None,
...@@ -176,7 +184,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -176,7 +184,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
per_act_token_quant=False, per_act_token_quant=False,
block_shape=quant_config.block_shape) block_shape=quant_config.block_shape)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) expert_topk_weights)
def _apply_weights_and_reduce(self, num_tokens: int, def _apply_weights_and_reduce(self, num_tokens: int,
......
...@@ -119,8 +119,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -119,8 +119,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]: Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
hidden_size = a1.size(1) hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
...@@ -158,7 +159,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -158,7 +159,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape) quant_config.per_act_token_quant, quant_config.block_shape)
return (expert_x, expert_x_scale, expert_num_tokens, None, None) expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
return (expert_x, expert_x_scale, expert_tokens_meta, None, None)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
......
...@@ -505,8 +505,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -505,8 +505,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]: Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
assert a1.dim() == 2 assert a1.dim() == 2
assert topk_ids.dim() == 2 assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0) assert topk_ids.size(0) == a1.size(0)
...@@ -587,7 +588,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -587,7 +588,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert b_a1_scale is None or b_a1_scale.ndim == 3 assert b_a1_scale is None or b_a1_scale.ndim == 3
return b_a1, b_a1_scale, tokens_per_expert, None, None expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None)
return b_a1, b_a1_scale, expert_tokens_meta, None, None
def finalize( def finalize(
self, self,
...@@ -694,28 +698,19 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -694,28 +698,19 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
else: else:
return t.to(f32) * group_broadcast(scale, t.shape) return t.to(f32) * group_broadcast(scale, t.shape)
def apply( def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
self, w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
output: torch.Tensor, activation: str, global_num_experts: int,
hidden_states: torch.Tensor, expert_map: Optional[torch.Tensor],
w1: torch.Tensor, w1_scale: Optional[torch.Tensor],
w2: torch.Tensor, w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
topk_ids: torch.Tensor, w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
activation: str, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
global_num_experts: int, workspace2: torch.Tensor,
expert_map: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
):
assert hidden_states.dim() == 3 assert hidden_states.dim() == 3
assert expert_num_tokens is not None assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
num_local_experts = w1.size(0) num_local_experts = w1.size(0)
assert num_local_experts == w1.size(0), ( assert num_local_experts == w1.size(0), (
...@@ -902,26 +897,16 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -902,26 +897,16 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
output = (num_experts, max_num_tokens * num_dp, K) output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output, a.dtype) return (workspace13, workspace2, output, a.dtype)
def apply( def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
self, w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
output: torch.Tensor, activation: str, global_num_experts: int,
hidden_states: torch.Tensor, expert_map: Optional[torch.Tensor],
w1: torch.Tensor, w1_scale: Optional[torch.Tensor],
w2: torch.Tensor, w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
topk_ids: torch.Tensor, w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
activation: str, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
global_num_experts: int, workspace2: torch.Tensor,
expert_map: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
):
# Check constraints. # Check constraints.
if self.use_int4_w4a16: if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), ( assert hidden_states.size(-1) // 2 == w1.size(2), (
...@@ -938,6 +923,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -938,6 +923,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert hidden_states.dtype in [ assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
] ]
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids) hidden_states, w1, w2, topk_ids)
......
...@@ -1630,7 +1630,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1630,7 +1630,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
): ):
# Check constraints. # Check constraints.
if self.use_int4_w4a16: if self.use_int4_w4a16:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum from enum import Enum
from math import prod from math import prod
from typing import Optional, final from typing import Optional, final
...@@ -95,6 +96,26 @@ class FusedMoEActivationFormat(Enum): ...@@ -95,6 +96,26 @@ class FusedMoEActivationFormat(Enum):
BatchedExperts = "batched_experts", BatchedExperts = "batched_experts",
@dataclass
class ExpertTokensMetadata:
"""
Metadata regarding expert-token routing.
"""
expert_num_tokens: torch.Tensor
expert_num_tokens_cpu: Optional[torch.Tensor]
@staticmethod
def make_from_list(expert_num_tokens_list: list[int],
device: str) -> "ExpertTokensMetadata":
expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
device="cpu",
dtype=torch.int32)
return ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens_cpu.to(device,
non_blocking=True),
expert_num_tokens_cpu=expert_num_tokens_cpu)
# TODO: pass FusedMoEParallelConfig in as ctor parameter? # TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPrepareAndFinalize(ABC):
""" """
...@@ -114,8 +135,9 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -114,8 +135,9 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]: Optional[ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
""" """
Perform any quantization (and/or) dispatching needed Perform any quantization (and/or) dispatching needed
for this kernel. for this kernel.
...@@ -134,7 +156,8 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -134,7 +156,8 @@ class FusedMoEPrepareAndFinalize(ABC):
Returns a tuple of: Returns a tuple of:
- quantized + dispatched a. - quantized + dispatched a.
- quantized + dispatched a1_scales. - quantized + dispatched a1_scales.
- Optional tensor as big as number of local experts that contains the - Optional ExpertTokensMetadata containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert. number of tokens assigned to each local expert.
- Optional dispatched expert topk IDs - Optional dispatched expert topk IDs
- Optional dispatched expert topk weight - Optional dispatched expert topk weight
...@@ -318,7 +341,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -318,7 +341,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata],
): ):
""" """
This function computes the intermediate result of a Mixture of Experts This function computes the intermediate result of a Mixture of Experts
...@@ -351,8 +374,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -351,8 +374,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
must be large enough to hold output of either MoE gemm. must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation - workspace2 (torch.Tensor): A scratch tensor used for the activation
function. function.
- expert_num_tokens: An optional tensor containing the number of tokens - expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
assigned to each expert when using batched experts format input. ExpertTokensMetadata object containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -458,7 +483,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -458,7 +483,7 @@ class FusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare( _expert_topk_weights) = self.prepare_finalize.prepare(
a1, a1,
a1_scale, a1_scale,
...@@ -542,7 +567,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -542,7 +567,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale, a2_scale=a2_scale,
workspace13=workspace13, workspace13=workspace13,
workspace2=workspace2, workspace2=workspace2,
expert_num_tokens=expert_num_tokens, expert_tokens_meta=expert_tokens_meta,
) )
else: else:
# The leading output dimension may not be equal to M, so # The leading output dimension may not be equal to M, so
...@@ -589,7 +614,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -589,7 +614,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=curr_a2_scale, a2_scale=curr_a2_scale,
workspace13=workspace13, workspace13=workspace13,
workspace2=workspace2, workspace2=workspace2,
expert_num_tokens=expert_num_tokens, expert_tokens_meta=expert_tokens_meta,
) )
self.prepare_finalize.finalize(output, fused_out, topk_weights, self.prepare_finalize.finalize(output, fused_out, topk_weights,
......
...@@ -94,8 +94,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -94,8 +94,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]: Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
num_tokens = a1.size(0) # M num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K hidden_dim = a1.size(-1) # K
...@@ -200,7 +201,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -200,7 +201,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
assert expert_x_scale.ndim == 3 assert expert_x_scale.ndim == 3
return expert_x, expert_x_scale, expert_num_tokens, None, None expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def finalize( def finalize(
self, self,
......
...@@ -38,8 +38,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -38,8 +38,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]: Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
......
...@@ -110,7 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -110,7 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
): ):
use_deep_gemm = (self.allow_deep_gemm use_deep_gemm = (self.allow_deep_gemm
and _valid_deep_gemm(hidden_states, w1, w2)) and _valid_deep_gemm(hidden_states, w1, w2))
...@@ -135,5 +135,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -135,5 +135,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale, a2_scale,
workspace13, workspace13,
workspace2, workspace2,
expert_num_tokens, expert_tokens_meta,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment