Unverified Commit ffec8154 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Perf] Optimize additional `fill(0)` in cutlass moe, 2.9% E2E throughput...


[Perf] Optimize additional `fill(0)` in cutlass moe, 2.9% E2E throughput improvement, 10.8% TTFT improvement (#31754)
Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent d386ab14
...@@ -173,7 +173,7 @@ def run_cutlass_moe_fp8( ...@@ -173,7 +173,7 @@ def run_cutlass_moe_fp8(
num_expert = global_num_experts if expert_map is None else expert_map.size(0) num_expert = global_num_experts if expert_map is None else expert_map.size(0)
# permuted a1q reuses workspace2 # permuted a1q reuses workspace2
a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute( a1q, a1q_scale, expert_first_token_offset, inv_perm, _ = moe_permute(
a1q, a1q,
a1q_scale, a1q_scale,
topk_ids, topk_ids,
...@@ -182,7 +182,7 @@ def run_cutlass_moe_fp8( ...@@ -182,7 +182,7 @@ def run_cutlass_moe_fp8(
expert_map, expert_map,
permuted_hidden_states=a1q_perm, permuted_hidden_states=a1q_perm,
) )
expert_offsets = expert_offsets[:-1] expert_offsets = expert_first_token_offset[:-1]
ops.get_cutlass_moe_mm_problem_sizes( ops.get_cutlass_moe_mm_problem_sizes(
local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K
...@@ -215,9 +215,6 @@ def run_cutlass_moe_fp8( ...@@ -215,9 +215,6 @@ def run_cutlass_moe_fp8(
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
) )
if expert_map is not None:
mm2_out.fill_(0)
ops.cutlass_moe_mm( ops.cutlass_moe_mm(
mm2_out, mm2_out,
a2q, a2q,
...@@ -243,6 +240,9 @@ def run_cutlass_moe_fp8( ...@@ -243,6 +240,9 @@ def run_cutlass_moe_fp8(
permuted_hidden_states=mm2_out, permuted_hidden_states=mm2_out,
topk_weights=topk_weights, topk_weights=topk_weights,
inv_permuted_idx=inv_perm, inv_permuted_idx=inv_perm,
expert_first_token_offset=(
expert_first_token_offset if expert_map is not None else None
),
) )
...@@ -988,7 +988,7 @@ def run_cutlass_moe_w4a8_fp8( ...@@ -988,7 +988,7 @@ def run_cutlass_moe_w4a8_fp8(
num_expert = global_num_experts if expert_map is None else expert_map.size(0) num_expert = global_num_experts if expert_map is None else expert_map.size(0)
# permuted a1q reuses workspace2 # permuted a1q reuses workspace2
a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute( a1q, a1q_scale, expert_first_token_offset, inv_perm, _ = moe_permute(
a1q, a1q,
a1q_scale, a1q_scale,
topk_ids, topk_ids,
...@@ -997,7 +997,7 @@ def run_cutlass_moe_w4a8_fp8( ...@@ -997,7 +997,7 @@ def run_cutlass_moe_w4a8_fp8(
expert_map, expert_map,
permuted_hidden_states=a1q_perm, permuted_hidden_states=a1q_perm,
) )
expert_offsets = expert_offsets[:-1] expert_offsets = expert_first_token_offset[:-1]
# For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape) # For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
ops.get_cutlass_moe_mm_problem_sizes( ops.get_cutlass_moe_mm_problem_sizes(
...@@ -1032,9 +1032,6 @@ def run_cutlass_moe_w4a8_fp8( ...@@ -1032,9 +1032,6 @@ def run_cutlass_moe_w4a8_fp8(
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
) )
if expert_map is not None:
mm2_out.fill_(0)
ops.cutlass_w4a8_moe_mm( ops.cutlass_w4a8_moe_mm(
mm2_out, mm2_out,
a2q, a2q,
...@@ -1058,6 +1055,9 @@ def run_cutlass_moe_w4a8_fp8( ...@@ -1058,6 +1055,9 @@ def run_cutlass_moe_w4a8_fp8(
permuted_hidden_states=mm2_out, permuted_hidden_states=mm2_out,
topk_weights=topk_weights, topk_weights=topk_weights,
inv_permuted_idx=inv_perm, inv_permuted_idx=inv_perm,
expert_first_token_offset=(
expert_first_token_offset if expert_map is not None else None
),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment