Unverified commit 257015d5, authored by milesial, committed by GitHub
Browse files

[MoE] Triton MoE Perf regression - restore low latency path (#39016)

parent b4784001
......@@ -1551,6 +1551,55 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
return torch_vllm_outplace_fused_experts
def _prepare_expert_assignment(
topk_ids: torch.Tensor,
config: dict[str, Any],
num_tokens: int,
top_k_num: int,
global_num_experts: int,
expert_map: torch.Tensor | None,
*,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
block_shape: list[int] | None = None,
ignore_invalid_experts: bool = False,
) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]:
"""Prepare expert assignments for the aligned and low-latency Triton paths."""
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
# activates only a small fraction of total experts
# Skips moe_align_block_size and activates the `sorted_token_ids is None`
# path of the fused_moe_kernel kernel
naive_block_assignment = (
expert_map is None
and num_tokens * top_k_num * 4 <= global_num_experts
and not (
(use_int8_w8a16 or use_int4_w4a16)
and block_shape is not None
and block_shape[1] > 0
)
)
if naive_block_assignment:
return (
None,
topk_ids.view(-1),
torch.full(
(1,),
topk_ids.numel() * config["BLOCK_SIZE_M"],
dtype=torch.int32,
device=topk_ids.device,
),
)
return moe_align_block_size(
topk_ids,
config["BLOCK_SIZE_M"],
global_num_experts,
expert_map,
ignore_invalid_experts=ignore_invalid_experts,
)
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
def fused_experts(
......@@ -1791,36 +1840,18 @@ def fused_experts_impl(
ocp_mx_scheme=ocp_mx_scheme,
)
# SPARSITY_FACTOR is a heuristic margin ensuring num_tokens * top_k
# activates only a small fraction of total experts
SPARSITY_FACTOR = 4
# block quantized code path is not implemented yet.
naive_block_assignment = (
expert_map is None
and num_tokens * top_k_num * SPARSITY_FACTOR <= global_num_experts
and not (
(use_int8_w8a16 or use_int4_w4a16)
and block_shape is not None
and block_shape[1] > 0
)
)
if not naive_block_assignment:
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
sorted_token_ids, expert_ids, num_tokens_post_padded = _prepare_expert_assignment(
topk_ids,
config["BLOCK_SIZE_M"],
config,
num_tokens,
top_k_num,
global_num_experts,
expert_map,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape,
ignore_invalid_experts=True,
)
else:
max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
expert_ids = topk_ids.view(-1)
num_tokens_post_padded = torch.empty(
(1), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_padded.fill_(max_num_tokens_padded)
sorted_token_ids = None
dispatch_fused_moe_kernel(
qhidden_states,
......@@ -2073,8 +2104,18 @@ class TritonExperts(mk.FusedMoEExpertsModular):
)
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
sorted_token_ids, expert_ids, num_tokens_post_padded = (
_prepare_expert_assignment(
topk_ids,
config,
num_tokens,
top_k_num,
global_num_experts,
expert_map,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
block_shape=self.block_shape,
)
)
invoke_fused_moe_triton_kernel(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment