Unverified Commit 3f2e315f authored by tql.99, committed by GitHub

optimize: reduce shuffle and quantization overhead in cutlass_moe sm90 (#8962)


Co-authored-by: Qi Yuhang (戚余航) <qiyuhang@bytedance.com>
parent 6e215118
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch

 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()
@@ -124,6 +124,7 @@ def cutlass_fused_experts_fp8(
     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_group_transpose,
             per_token_group_quant_fp8_hopper_moe_mn_major,
             sglang_per_token_group_quant_fp8,
         )
@@ -152,15 +153,12 @@ def cutlass_fused_experts_fp8(
         k,
     )

-    if is_sm100_supported():
-        a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-        rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
-        rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
-    else:
-        rep_a = shuffle_rows(a, a_map, (m * topk, k))
-        rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
-            rep_a, expert_offsets, problem_sizes1, 128
-        )
+    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+    rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+    rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+
+    if not is_sm100_supported():
+        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)

     w1_scale = w1_scale.contiguous()
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
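Why this hunk reduces overhead: the old sm90 branch shuffled the unquantized bf16 activations with shuffle_rows and then ran a Hopper-specific quant kernel over the shuffled copy; the new path quantizes once up front, shuffles the half-size fp8 tensor plus its scales, and only fixes up the scale layout with per_group_transpose. The fix-up produces an MN-major layout, i.e. column-major within each expert's row range, which is evidently what the sm90 grouped GEMM expects. A toy, self-contained illustration of that layout change for a single expert (sizes are made up):

import torch

scales = torch.arange(6.0).reshape(3, 2)  # 3 tokens, 2 quant groups, row-major
mn_major = scales.t().reshape(-1)         # element (m, kg) lands at kg * num_tokens + m
print(mn_major.tolist())                  # [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]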
@@ -193,12 +191,9 @@ def cutlass_fused_experts_fp8(
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)

-    if is_sm100_supported():
-        intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    else:
-        intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
-            intermediate, expert_offsets, problem_sizes2, 128
-        )
+    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    if not is_sm100_supported():
+        a2_scale = per_group_transpose(a2_scale, expert_offsets)

     w2_scale = w2_scale.contiguous()
     fp8_blockwise_scaled_grouped_mm(
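For the second GEMM there is no row shuffle; the win here is replacing the per-expert MN-major quant kernel with the generic sglang_per_token_group_quant_fp8 plus a transpose of just the scales. A back-of-envelope sketch of the relative sizes involved (illustrative numbers, not measurements from this PR; fp32 scales and bf16 activations are assumptions):

m_topk, n, group = 4096, 7168, 128
intermediate_bytes = m_topk * n * 2          # bf16 tensor the quant kernel reads either way
scale_bytes = m_topk * (n // group) * 4      # fp32 scales that per_group_transpose rewrites
print(scale_bytes / intermediate_bytes)      # 0.015625, i.e. ~1/64 of the activation bytes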
@@ -1356,3 +1356,62 @@ def per_token_group_quant_fp8_hopper_moe_mn_major(
         expert_tokens_alignment,
     )
     return a_q, sfa
+
+
+@triton.jit
+def _per_group_transpose(
+    data_ptr: torch.Tensor,
+    trans_data_ptr: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    k: int,
+    M_ALIGNMENT: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    expert_id = tl.program_id(0)
+    m_id = tl.program_id(1)
+    k_id = tl.program_id(2)
+
+    curr_expert_offset = tl.load(expert_offsets + expert_id)
+    next_expert_offset = tl.load(expert_offsets + expert_id + 1)
+    num_tokens_of_expert = next_expert_offset - curr_expert_offset
+    tl.multiple_of(curr_expert_offset, M_ALIGNMENT)
+    tl.multiple_of(next_expert_offset, M_ALIGNMENT)
+
+    data_start_ptr = data_ptr + curr_expert_offset * k
+    trans_data_start_ptr = trans_data_ptr + curr_expert_offset * k
+    k_coord = k_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    k_mask = k_coord < k
+    for start_m in tl.range(0, num_tokens_of_expert, BLOCK_SIZE_M * tl.num_programs(1)):
+        m_coord = start_m + m_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        m_mask = m_coord < num_tokens_of_expert
+        off = m_coord[:, None] * k + k_coord[None, :]
+        trans_off = m_coord[:, None] + k_coord[None, :] * num_tokens_of_expert
+        mask = m_mask[:, None] & k_mask[None, :]
+        data = tl.load(data_start_ptr + off, mask=mask)
+        tl.store(trans_data_start_ptr + trans_off, data, mask=mask)
+
+
+def per_group_transpose(
+    a: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    M_ALIGNMENT: int = 1,
+) -> torch.Tensor:
+    assert a.dim() == 2
+    assert a.is_contiguous(), "`a` is not contiguous"
+
+    m, k = a.size()
+    trans_a = torch.empty_like(a)
+    num_experts = expert_offsets.size(0) - 1
+    grid = lambda META: (
+        num_experts,
+        triton.cdiv((m + num_experts - 1) // num_experts, META["BLOCK_SIZE_M"]),
+        triton.cdiv(k, META["BLOCK_SIZE_K"]),
+    )
+    _per_group_transpose[grid](
+        a, trans_a, expert_offsets, k, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8
+    )
+    return trans_a
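A minimal smoke test for the new kernel (a sketch, not part of the commit: it assumes a CUDA device, that per_group_transpose is importable from sglang.srt.layers.quantization.fp8_kernel as the diff suggests, and arbitrary sizes; ref_transpose is a hypothetical helper written here only for comparison):

import torch
from sglang.srt.layers.quantization.fp8_kernel import per_group_transpose

def ref_transpose(a: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
    # Plain-PyTorch reference: column-major flatten of each expert's row block,
    # matching trans_off = m + kg * num_tokens_of_expert in the Triton kernel.
    out = torch.empty_like(a)
    for e in range(expert_offsets.numel() - 1):
        lo, hi = int(expert_offsets[e]), int(expert_offsets[e + 1])
        out[lo:hi].view(-1).copy_(a[lo:hi].t().reshape(-1))
    return out

expert_offsets = torch.tensor([0, 3, 8, 8, 13], device="cuda", dtype=torch.int32)
a = torch.randn(13, 4, device="cuda")  # 13 tokens across 4 experts (one empty), k = 4
torch.testing.assert_close(per_group_transpose(a, expert_offsets), ref_transpose(a, expert_offsets))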