Unverified commit 506c4928, authored by xutizhou and committed by GitHub

feat: integrate deepgemm into EPMoE (#6821)
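The new DeepGEMM path scatters each routed token into a padded per-expert buffer and runs both grouped GEMMs in masked mode. As a rough illustration of the layout produced by moe_ep_deepgemm_preprocess in this diff (a minimal sketch with made-up numbers, not part of the change itself):

# Hypothetical example of the padded per-expert layout used by the masked grouped GEMMs.
num_tokens, top_k, num_experts = 7, 2, 4
m_max = (num_tokens + 255) // 256 * 256                              # -> 256 rows reserved per expert
expected_m = (num_tokens * top_k + num_experts - 1) // num_experts   # -> 4, average tokens per expert
# gateup_input has shape (num_local_experts, m_max, hidden_size);
# masked_m[e] records how many of the m_max rows of expert e hold real tokens,
# and the remaining rows are padding that the masked grouped GEMM skips.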


Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: TianQiLin666666 <1834987979@qq.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
parent 30ceccc7
@@ -478,11 +478,13 @@ def post_reorder_triton_kernel(
end_expert_id,
topk,
hidden_size,
dst_start,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty
src_idx_int32 = tl.program_id(0)
src_idx = src_idx_int32.to(tl.int64)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
topk_weights_ptr = topk_weights_ptr + src_idx * topk
@@ -501,7 +503,9 @@ def post_reorder_triton_kernel(
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
computed = True
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_idx = dst_idx - dst_start
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
@@ -1086,3 +1090,156 @@ def tma_align_input_scale(input_scale: torch.Tensor):
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
return output.t()[:m]
@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m):
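# One program per expert: masked_m[e] = seg_indptr[e + 1] - seg_indptr[e], the number of routed tokens for expert e.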
expert_id = tl.program_id(0)
start = tl.load(seg_indptr + expert_id)
end = tl.load(seg_indptr + expert_id + 1)
tl.store(masked_m + expert_id, (end - start))
@triton.jit
def deepgemm_compute_src2dst_triton_kernel(
topk_ids,
reorder_ids,
seg_indptr,
src2dst,
m_max,
num_toks,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
expert_dst_start = tl.load(seg_indptr + expert_id)
expert_dst_offset = dst_id - expert_dst_start
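# Each expert owns a contiguous block of m_max destination rows; a token lands at its offset within its expert's segment.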
dst_id = expert_id * m_max + expert_dst_offset
tl.store(src2dst + src_id, dst_id, mask=mask)
@triton.jit
def fill_gateup_input_triton_kernel(
input_ptr,
scale_ptr,
gateup_input_ptr,
gateup_input_scale_ptr,
src2dst_ptr,
topk_ids_ptr,
start_expert_id,
end_expert_id,
topk,
m_max,
hidden_size,
scale_size,
BLOCK_SIZE: tl.constexpr,
):
src_idx_int32 = tl.program_id(0)
src_idx = src_idx_int32.to(tl.int64)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size
scale_src_ptr = scale_ptr + src_idx * scale_size
vec = tl.arange(0, BLOCK_SIZE)
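# Copy the token's hidden vector and its per-group quant scales into every local expert slot it was routed to.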
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
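# src2dst stores global slot ids (expert_id * m_max + offset); shift by start_expert_id * m_max to index this rank's local buffer.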
dst_idx = dst_idx - start_expert_id * m_max
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + vec
mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask)
tl.store(dst_ptr + offset, in_data, mask=mask)
scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
offset = start_offset + vec
mask = offset < scale_size
in_scale = tl.load(scale_src_ptr + offset, mask=mask)
tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
def moe_ep_deepgemm_preprocess(
topk_ids: torch.Tensor,
num_experts: int,
hidden_states: torch.Tensor,
top_k: int,
start_expert_id,
end_expert_id,
block_shape,
output_dtype: torch.dtype = torch.float8_e4m3fn,
):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
compute_seg_indptr_triton_kernel[(num_experts,)](
reorder_topk_ids, seg_indptr, topk_ids.numel()
)
grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
# For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
m_max = (hidden_states.size(0) + 255) // 256 * 256
expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
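# expected_m is the average token count per expert, passed to DeepGEMM as a scheduling hint for the masked grouped GEMM.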
gateup_input = torch.empty(
(int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
device=hidden_states.device,
dtype=output_dtype,
)
deepgemm_compute_src2dst_triton_kernel[grid](
topk_ids,
reorder_ids,
seg_indptr,
src2dst,
m_max,
topk_ids.numel(),
BLOCK_SIZE=256,
)
if block_shape is None:
block_shape = [128, 128]
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
gateup_input_scale = torch.empty(
(gateup_input.size(0), gateup_input.size(1), scale.size(1)),
device=hidden_states.device,
dtype=scale.dtype,
)
fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
scale,
gateup_input,
gateup_input_scale,
src2dst,
topk_ids,
start_expert_id,
end_expert_id,
top_k,
m_max,
hidden_states.size(1),
scale.size(1),
BLOCK_SIZE=1024,
)
return (
m_max,
masked_m[start_expert_id : (end_expert_id + 1)],
expected_m,
src2dst,
gateup_input,
gateup_input_scale,
)
@@ -16,6 +16,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
ep_scatter,
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
moe_ep_deepgemm_preprocess,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
@@ -178,6 +179,7 @@ class EPMoE(torch.nn.Module):
assert (
num_fused_shared_experts == 0
), "num_fused_shared_experts is not supported in EP"
self.num_fused_shared_experts = num_fused_shared_experts
self.num_experts_per_partition = self.num_experts // self.tp_size
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -227,13 +229,182 @@ class EPMoE(torch.nn.Module):
self.grouped_gemm_runner = None
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, router_logits)
else:
return self.forward_normal(hidden_states, router_logits)
def forward_deepgemm(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
):
assert self.quant_method is not None
assert self.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
)
if not self.use_block_quant:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
scale_block_size = 128
w13_weight_scale_n = 2 * (
(self.intermediate_size + scale_block_size - 1) // scale_block_size
)
w13_weight_scale_k = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w13_weight_scale = (
self.w13_weight_scale.unsqueeze(1)
.repeat_interleave(w13_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w13_weight_scale_k, dim=2)
)
self.w13_weight_fp8 = (
self.w13_weight,
w13_weight_scale,
)
w2_weight_scale_n = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w2_weight_scale_k = (
self.intermediate_size + scale_block_size - 1
) // scale_block_size
w2_weight_scale = (
self.w2_weight_scale.unsqueeze(1)
.repeat_interleave(w2_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w2_weight_scale_k, dim=2)
)
self.w2_weight_fp8 = (
self.w2_weight,
w2_weight_scale,
)
# PreReorder
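# Quantize the activations and scatter them into the padded per-expert layout; also returns masked_m and expected_m for the masked GEMMs.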
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
self.num_experts,
hidden_states,
self.top_k,
self.start_expert_id,
self.end_expert_id,
self.block_shape,
)
)
dispose_tensor(hidden_states)
# GroupGemm-0
gateup_input_fp8 = (
gateup_input,
deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
)
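# DeepGEMM expects the per-group activation scales in a column-major, TMA-aligned layout; get_col_major_tma_aligned_tensor repacks them accordingly.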
num_groups, m, k = gateup_input_fp8[0].size()
n = self.w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
)
del gateup_input
del gateup_input_fp8
# Act
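# The w13 GEMM output packs the gate and up projections along the last dim; silu_and_mul halves it and re-quantizes to fp8 in groups of scale_block_size.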
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=hidden_states_device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=hidden_states_device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
)
del gateup_output
# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
)
down_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
)
del down_input
del down_input_fp8
# PostReorder
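# down_output rows are indexed by local slot id; m_max * start_expert_id is passed as dst_start so the kernel maps the global slot ids in src2dst back to local rows.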
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states_shape[1],
m_max * self.start_expert_id,
BLOCK_SIZE=512,
)
return output
def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device,
@@ -249,6 +420,7 @@ class EPMoE(torch.nn.Module):
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
@@ -440,6 +612,7 @@ class EPMoE(torch.nn.Module):
self.end_expert_id,
self.top_k,
hidden_states_shape[1],
0,
BLOCK_SIZE=512,
)
return output
...
@@ -182,6 +182,7 @@ def ep_moe(
end_expert_id,
top_k,
hidden_states.size(1),
0,
BLOCK_SIZE=512,
)
return output
...
@@ -77,6 +77,7 @@ def benchmark(batch_size, provider):
end_expert_id,
topk,
hidden_size,
0,
block_size,
)
...
@@ -85,6 +85,7 @@ def run_triton_kernel(
end_expert_id,
topk,
hidden_size,
0,
block_size,
)
return output
...