Unverified commit 55444ed6 authored by Yuan Luo, committed by GitHub

[EP] Add cuda kernel for moe_ep_pre_reorder (#6699)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 20fd53b8
......
@@ -224,6 +224,7 @@ set(SOURCES
    "csrc/moe/moe_topk_softmax_kernels.cu"
    "csrc/moe/fp8_blockwise_moe_kernel.cu"
    "csrc/moe/prepare_moe_input.cu"
    "csrc/moe/ep_moe_reorder_kernel.cu"
    "csrc/speculative/eagle_utils.cu"
    "csrc/speculative/speculative_sampling.cu"
    "csrc/speculative/packbit.cu"
......
import itertools

import torch
import triton
import triton.language as tl
from sgl_kernel import ep_moe_pre_reorder

from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel

batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
configs = [(bs,) for bs in batch_sizes]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["cuda", "triton"],
        line_names=["CUDA Kernel", "Triton Kernel"],
        styles=[("green", "-"), ("orange", "-")],
        ylabel="us",
        plot_name="ep-moe-pre-reorder-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    dtype = torch.float32
    device = torch.device("cuda")
    hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512

    # Allocate fresh tensors for every run to match bench_moe_fused_gate style
    def alloc_tensors():
        input_ = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
        gateup_input = torch.zeros(
            batch_size * topk, hidden_size, dtype=dtype, device=device
        )
        src2dst = torch.randint(
            0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
        )
        topk_ids = torch.randint(
            start_expert_id,
            end_expert_id + 1,
            (batch_size, topk),
            dtype=torch.int32,
            device=device,
        )
        # Per-token scales: both kernels below are launched with
        # use_per_token_if_dynamic=True, which indexes the scales by token id,
        # so the tensor needs one entry per token (not per expert).
        a1_scales = torch.rand(batch_size, dtype=torch.float32, device=device)
        return input_, gateup_input, src2dst, topk_ids, a1_scales

    quantiles = [0.5, 0.2, 0.8]
    if provider == "cuda":

        def run_cuda():
            inp, gout, s2d, tk_ids, scales = alloc_tensors()
            ep_moe_pre_reorder(
                inp,
                gout,
                s2d,
                tk_ids,
                scales,
                start_expert_id,
                end_expert_id,
                topk,
                True,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
    elif provider == "triton":

        def run_triton():
            inp, gout, s2d, tk_ids, scales = alloc_tensors()
            pre_reorder_triton_kernel[(batch_size,)](
                inp.view(-1),
                gout.view(-1),
                s2d.view(-1),
                tk_ids.view(-1),
                scales,
                start_expert_id,
                end_expert_id,
                topk,
                hidden_size,
                block_size,
                True,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
    else:
        raise ValueError(f"Unknown provider: {provider}")

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)
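Beyond timing, a quick parity check between the two providers helps catch indexing or scaling mistakes. The sketch below is not part of this commit; the helper name check_parity is illustrative, and it reuses the benchmark's shapes, imports, and per-token scale layout.

import torch
from sgl_kernel import ep_moe_pre_reorder
from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel


def check_parity(batch_size=64, hidden_size=4096, topk=8,
                 start_expert_id=0, end_expert_id=255, block_size=512):
    device = torch.device("cuda")
    input_ = torch.randn(batch_size, hidden_size, device=device)
    # Unique destinations so the two implementations never race on the same row.
    src2dst = torch.randperm(batch_size * topk, device=device).int().view(batch_size, topk)
    topk_ids = torch.randint(
        start_expert_id, end_expert_id + 1, (batch_size, topk),
        dtype=torch.int32, device=device,
    )
    # Per-token scales, kept away from zero to avoid huge reciprocals.
    a1_scales = torch.rand(batch_size, device=device) + 0.5

    out_cuda = torch.zeros(batch_size * topk, hidden_size, device=device)
    out_triton = torch.zeros_like(out_cuda)

    ep_moe_pre_reorder(
        input_, out_cuda, src2dst, topk_ids, a1_scales,
        start_expert_id, end_expert_id, topk, True,
    )
    pre_reorder_triton_kernel[(batch_size,)](
        input_.view(-1), out_triton.view(-1), src2dst.view(-1), topk_ids.view(-1),
        a1_scales, start_expert_id, end_expert_id, topk, hidden_size, block_size, True,
    )
    # Loose tolerances: the two kernels may round the reciprocal scale differently.
    torch.testing.assert_close(out_cuda, out_triton, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    check_parity()
    print("CUDA and Triton pre-reorder outputs match.")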
......
@@ -150,6 +150,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "n_share_experts_fusion, float routed_scaling_factor) -> "
      "(Tensor[])");
  m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
  m.def(
      "ep_moe_pre_reorder(Tensor input_ptr, Tensor gateup_input_ptr, Tensor src2dst_ptr, Tensor topk_ids_ptr, Tensor "
      "a1_scales_ptr, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
  m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder);
  m.def(
      "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
      "a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
......
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <THC/THCAtomics.cuh>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

// One block per source token. Each block scatters the token's hidden vector to the
// row selected by src2dst for every top-k expert that falls inside
// [start_expert_id, end_expert_id], rescaling by the reciprocal of the per-token or
// per-expert a1 scale when scales are provided.
__global__ void ep_pre_reorder_cuda_kernel(
    const float* __restrict__ input_ptr,
    float* __restrict__ gateup_input_ptr,
    const int* __restrict__ src2dst_ptr,
    const int* __restrict__ topk_ids_ptr,
    const float* __restrict__ a1_scales_ptr,
    int start_expert_id,
    int end_expert_id,
    int topk,
    int hidden_size,
    bool use_per_token_if_dynamic) {
  int token_idx = blockIdx.x;
  int tid = threadIdx.x;

  const float* src_ptr = input_ptr + int64_t(token_idx) * hidden_size;
  const int* token_src2dst = src2dst_ptr + token_idx * topk;
  const int* token_topk_ids = topk_ids_ptr + token_idx * topk;

  for (int k = 0; k < topk; ++k) {
    int expert_id = token_topk_ids[k];
    if (expert_id < start_expert_id || expert_id > end_expert_id) continue;

    float scale = 1.0f;
    if (a1_scales_ptr != nullptr) {
      scale = use_per_token_if_dynamic ? 1.0f / a1_scales_ptr[token_idx]
                                       : 1.0f / a1_scales_ptr[expert_id - start_expert_id];
    }

    int dst_idx = token_src2dst[k];
    float* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size;

    // 16-byte vectorized copy; assumes hidden_size is a multiple of vec_size (4 floats).
    constexpr uint32_t vec_size = 16 / sizeof(float);
    using vec_t = flashinfer::vec_t<float, vec_size>;
    for (int idx = tid; idx < hidden_size / vec_size; idx += blockDim.x) {
      vec_t input_vec, output_vec;
      input_vec.cast_load(src_ptr + idx * vec_size);
#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        float val = static_cast<float>(input_vec[i]);
        output_vec[i] = val * scale;
      }
      output_vec.cast_store(dst_ptr + idx * vec_size);
    }
  }
}

void ep_moe_pre_reorder(
    torch::Tensor input,
    torch::Tensor gateup_input,
    torch::Tensor src2dst,
    torch::Tensor topk_ids,
    torch::Tensor a1_scales,
    int64_t start_expert_id,
    int64_t end_expert_id,
    int64_t topk,
    bool use_per_token_if_dynamic) {
  int total_blocks = input.size(0);  // one block per token
  int block_size = 512;
  dim3 grid(total_blocks);
  dim3 block(block_size);
  int hidden_size = input.size(1);

  // Launch on the caller's current CUDA stream instead of the default stream.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  ep_pre_reorder_cuda_kernel<<<grid, block, 0, stream>>>(
      input.data_ptr<float>(),
      gateup_input.data_ptr<float>(),
      src2dst.data_ptr<int>(),
      topk_ids.data_ptr<int>(),
      a1_scales.defined() ? a1_scales.data_ptr<float>() : nullptr,
      start_expert_id,
      end_expert_id,
      topk,
      hidden_size,
      use_per_token_if_dynamic);
}
......
@@ -240,6 +240,17 @@ void prepare_moe_input(
    const int64_t n,
    const int64_t k);

void ep_moe_pre_reorder(
    torch::Tensor input,
    torch::Tensor gateup_input,
    torch::Tensor src2dst,
    torch::Tensor topk_ids,
    torch::Tensor a1_scales,
    int64_t start_expert_id,
    int64_t end_expert_id,
    int64_t topk,
    bool use_per_token_if_dynamic);

/*
 * From csrc/speculative
 */
......
......
@@ -46,6 +46,7 @@ from sgl_kernel.gemm import (
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.moe import (
    ep_moe_pre_reorder,
    fp8_blockwise_scaled_grouped_mm,
    moe_align_block_size,
    moe_fused_gate,
......
......
@@ -62,6 +62,30 @@ def moe_fused_gate(
    )


def ep_moe_pre_reorder(
    input_tensor,
    gateup_input,
    src2dst,
    topk_ids,
    a1_scales,
    start_expert_id,
    end_expert_id,
    topk,
    use_per_token_if_dynamic,
):
    return torch.ops.sgl_kernel.ep_moe_pre_reorder.default(
        input_tensor,
        gateup_input,
        src2dst,
        topk_ids,
        a1_scales,
        start_expert_id,
        end_expert_id,
        topk,
        use_per_token_if_dynamic,
    )


def fp8_blockwise_scaled_grouped_mm(
    output,
    a_ptrs,
......
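For completeness, a minimal call of the new Python binding might look like the sketch below (again not part of the commit). Shapes and dtypes mirror the benchmark: fp32 activations, int32 index tensors, and per-token scales when use_per_token_if_dynamic is True.

import torch
from sgl_kernel import ep_moe_pre_reorder

batch_size, hidden_size, topk = 16, 4096, 8
start_expert_id, end_expert_id = 0, 255
device = torch.device("cuda")

input_ = torch.randn(batch_size, hidden_size, device=device)                 # fp32 activations
gateup_input = torch.zeros(batch_size * topk, hidden_size, device=device)    # reordered output rows
src2dst = torch.randperm(batch_size * topk, device=device).int().view(batch_size, topk)
topk_ids = torch.randint(
    start_expert_id, end_expert_id + 1, (batch_size, topk), dtype=torch.int32, device=device
)
a1_scales = torch.rand(batch_size, device=device) + 0.5                      # one scale per token

ep_moe_pre_reorder(
    input_,
    gateup_input,
    src2dst,
    topk_ids,
    a1_scales,
    start_expert_id,
    end_expert_id,
    topk,
    True,  # use_per_token_if_dynamic
)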