Unverified Commit cb3918a0 authored by Yuan Luo, committed by GitHub

Optimize moe_sum_reduce_kernel (#9477)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
parent f3b67602
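For context, _moe_sum_reduce_kernel sums each token's top-k expert outputs along the expert axis, scales the sum by routed_scaling_factor, and writes the result back in the input dtype. A minimal PyTorch sketch of that reduction is below; the layout of a contiguous input of shape (token_num, topk_num, hidden_dim) and an output of shape (token_num, hidden_dim) is inferred from the kernel's strides and is an assumption, not something stated in the diff:

```python
import torch

def moe_sum_reduce_reference(
    input: torch.Tensor,   # assumed shape: (token_num, topk_num, hidden_dim)
    output: torch.Tensor,  # assumed shape: (token_num, hidden_dim)
    routed_scaling_factor: float,
) -> None:
    # Accumulate in float32, as the Triton kernel does, then scale and cast back.
    acc = input.to(torch.float32).sum(dim=1) * routed_scaling_factor
    output.copy_(acc.to(output.dtype))
```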
@@ -4,7 +4,6 @@ import triton.language as tl
 from triton.testing import do_bench
 
-# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
 @triton.jit
 def _moe_sum_reduce_kernel(
     input_ptr,
@@ -29,31 +28,35 @@ def _moe_sum_reduce_kernel(
     token_block_id = tl.program_id(0)
     dim_block_id = tl.program_id(1)
 
-    token_start = token_block_id * BLOCK_M
-    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
-    dim_start = dim_block_id * BLOCK_DIM
-    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
-    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
-    for token_index in range(token_start, token_end):
-        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-            tmp = tl.load(
-                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-            )
-            accumulator += tmp
-        accumulator = accumulator * routed_scaling_factor
-        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-        tl.store(
-            store_t_ptr,
-            accumulator.to(input_ptr.dtype.element_ty),
-            mask=offs_dim < dim_end,
-        )
+    offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
+
+    mask_token = offs_token < token_num
+    mask_dim = offs_dim < hidden_dim
+
+    base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
+    accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
+
+    for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+        tile = tl.load(
+            base_ptrs + i * input_stride_1,
+            mask=mask_token[:, None] & mask_dim[None, :],
+            other=0.0,
+        )
+        accumulator += tile.to(tl.float32)
+
+    accumulator *= routed_scaling_factor
+
+    # -------- Write back --------
+    store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
+    tl.store(
+        store_ptrs,
+        accumulator.to(input_ptr.dtype.element_ty),
+        mask=mask_token[:, None] & mask_dim[None, :],
+    )
 
 
+# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
 def moe_sum_reduce(
     input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
 ):
@@ -66,7 +69,7 @@ def moe_sum_reduce(
     BLOCK_M = 1
     BLOCK_DIM = 2048
     NUM_STAGE = 1
-    num_warps = 8
+    num_warps = 16
     grid = (
         triton.cdiv(token_num, BLOCK_M),
...
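The benchmark file already imports do_bench from triton.testing, so the updated wrapper can be timed roughly as in the sketch below; the problem size and dtype are illustrative assumptions, not values from the truncated portion of the diff:

```python
import torch
from triton.testing import do_bench

# Hypothetical problem size, chosen only for illustration.
token_num, topk_num, hidden_dim = 4096, 8, 7168
x = torch.randn(token_num, topk_num, hidden_dim, dtype=torch.bfloat16, device="cuda")
out = torch.empty(token_num, hidden_dim, dtype=torch.bfloat16, device="cuda")

# do_bench returns the measured runtime in milliseconds.
ms = do_bench(lambda: moe_sum_reduce(x, out, 1.0))
print(f"moe_sum_reduce: {ms:.3f} ms")
```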
@@ -735,29 +735,32 @@ def _moe_sum_reduce_kernel(
     token_block_id = tl.program_id(0)
     dim_block_id = tl.program_id(1)
 
-    token_start = token_block_id * BLOCK_M
-    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
-    dim_start = dim_block_id * BLOCK_DIM
-    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
-    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
-    for token_index in range(token_start, token_end):
-        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-            tmp = tl.load(
-                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-            )
-            accumulator += tmp
-        accumulator = accumulator * routed_scaling_factor
-        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-        tl.store(
-            store_t_ptr,
-            accumulator.to(input_ptr.dtype.element_ty),
-            mask=offs_dim < dim_end,
-        )
+    offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
+
+    mask_token = offs_token < token_num
+    mask_dim = offs_dim < hidden_dim
+
+    base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
+    accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
+
+    for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+        tile = tl.load(
+            base_ptrs + i * input_stride_1,
+            mask=mask_token[:, None] & mask_dim[None, :],
+            other=0.0,
+        )
+        accumulator += tile.to(tl.float32)
+
+    accumulator *= routed_scaling_factor
+
+    # -------- Write back --------
+    store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
+    tl.store(
+        store_ptrs,
+        accumulator.to(input_ptr.dtype.element_ty),
+        mask=mask_token[:, None] & mask_dim[None, :],
+    )
 
 
 def moe_sum_reduce_triton(
@@ -772,7 +775,7 @@ def moe_sum_reduce_triton(
     BLOCK_M = 1
     BLOCK_DIM = 2048
     NUM_STAGE = 1
-    num_warps = 8
+    num_warps = 16
     grid = (
         triton.cdiv(token_num, BLOCK_M),
...
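A quick sanity check of the rewritten kernel against an eager PyTorch reduction could look like the sketch below; the shapes, dtype, and tolerances are illustrative assumptions, with the hidden dimension deliberately not a multiple of BLOCK_DIM so the boundary masks are exercised:

```python
import torch

# Hypothetical sizes; hidden_dim is not a multiple of BLOCK_DIM (2048).
token_num, topk_num, hidden_dim = 33, 9, 7000
x = torch.randn(token_num, topk_num, hidden_dim, dtype=torch.float16, device="cuda")
out = torch.empty(token_num, hidden_dim, dtype=torch.float16, device="cuda")
scale = 2.5

moe_sum_reduce_triton(x, out, scale)

# Reference: accumulate in float32, scale, cast back to the output dtype.
ref = (x.float().sum(dim=1) * scale).to(out.dtype)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```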