Unverified commit 817d4370, authored by Shi Shuai, committed by GitHub

feat: support ep size < 32 for sgl kernel (#4348)

parent c550e52f
@@ -196,6 +196,8 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
         expert_ids_triton,
         num_tokens_post_pad_triton,
     )
+
+    try:
         ops.moe_align_block_size(
             topk_ids,
             num_experts,
@@ -204,6 +206,11 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
             expert_ids_vllm,
             num_tokens_post_pad_vllm,
         )
+        print(f"✅ VLLM implementation works with {num_experts} experts!")
+        vllm_works = True
+    except RuntimeError as e:
+        print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
+        vllm_works = False
 
     if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
         num_tokens_post_pad_cuda, num_tokens_post_pad_triton
@@ -216,10 +223,15 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
         print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
         print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
 
-    if torch.allclose(expert_ids_cuda, expert_ids_vllm) and torch.allclose(
-        num_tokens_post_pad_cuda, num_tokens_post_pad_vllm
+    if (
+        vllm_works
+        and torch.allclose(expert_ids_cuda, expert_ids_vllm)
+        and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_vllm)
     ):
         print("✅ SGL and VLLM implementations match")
     else:
-        print("❌ SGL and VLLM implementations do not match")
-        print("SGL expert_ids:", expert_ids_cuda)
+        if not vllm_works:
+            print("⚠️ VLLM comparison skipped due to failure")
+        else:
+            print("❌ SGL and VLLM implementations do not match")
+            print("SGL expert_ids:", expert_ids_cuda)
@@ -228,8 +240,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
             print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
 
+# Test range
 num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
-num_experts_range = [32, 64, 128, 256]
+num_experts_range = [8, 32, 64, 128, 256]
 topk_range = [2, 4, 8]
 configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
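Note the pattern here: wrapping the vLLM kernel call in try/except turns a hard failure into a recorded flag, so the SGL/Triton comparison always runs and the vLLM comparison is gated on `vllm_works`. A minimal, self-contained sketch of that pattern follows; `run_candidate` and `reference` are hypothetical stand-ins for the vLLM kernel call and the SGL output, not names from this repo:

import torch

def compare_with_optional_impl(run_candidate, reference, num_experts):
    # Run the optional kernel; treat a RuntimeError as "configuration
    # unsupported" rather than aborting the whole diff check.
    try:
        candidate = run_candidate()
        available = True
    except RuntimeError as e:
        print(f"❌ candidate failed with {num_experts} experts: {e}")
        candidate, available = None, False

    # Short-circuiting on `available` keeps allclose from seeing None.
    if available and torch.allclose(candidate, reference):
        return "match"
    return "skipped" if not available else "mismatch"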
@@ -316,6 +329,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
             quantiles=quantiles,
         )
     else:  # vllm
+        try:
             ms, min_ms, max_ms = triton.testing.do_bench(
                 lambda: ops.moe_align_block_size(
                     topk_ids,
@@ -327,6 +341,10 @@ def benchmark(num_tokens, num_experts, topk, provider):
                 ),
                 quantiles=quantiles,
             )
+        except RuntimeError as e:
+            print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}")
+            # Return extreme values to indicate failure in the chart
+            return float("inf"), float("inf"), float("inf")
 
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
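Returning `float("inf")` on failure keeps the benchmark sweep alive while making unsupported configurations obvious in the printed table. A sketch of that degrade-gracefully wrapper, with illustrative names that are not taken from the repo:

import math
import triton

def bench_or_inf(fn, quantiles=(0.5, 0.2, 0.8)):
    # Benchmark fn, but map a kernel-side RuntimeError to infinite timings
    # so a sweep over many configs never crashes midway.
    try:
        ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=list(quantiles))
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms  # microseconds
    except RuntimeError:
        return math.inf, math.inf, math.inf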
@@ -343,7 +361,7 @@ if __name__ == "__main__":
         "--num_experts",
         type=int,
         default=256,
-        choices=[8, 64, 128, 256],
+        choices=[8, 16, 32, 64, 128, 256],
         help="Number of experts for benchmark",
     )
     parser.add_argument(
@@ -353,8 +371,15 @@ if __name__ == "__main__":
         choices=[2, 4, 8],
         help="Top-k value for benchmark",
     )
+    parser.add_argument(
+        "--skip_full_benchmark",
+        action="store_true",
+        help="Only run the calculate_diff function, skip full benchmarking",
+    )
 
     args = parser.parse_args()
     calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
-    benchmark.run(print_data=True)
+
+    if not args.skip_full_benchmark:
+        print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
+        benchmark.run(print_data=True)
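With the new flag, a quick correctness-only check at a small expert count looks like `python bench_moe_align_block_size.py --num_experts 8 --skip_full_benchmark` (the script filename is assumed here, not shown in the diff); omitting the flag additionally runs the full `benchmark.run` sweep.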
@@ -47,6 +47,7 @@ __global__ void moe_align_block_size_kernel(
     int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad,
     int32_t num_experts,
+    int32_t padded_num_experts,
     int32_t experts_per_warp,
     int32_t block_size,
     size_t numel,
@@ -57,7 +58,7 @@ __global__ void moe_align_block_size_kernel(
   const int my_expert_start = warp_id * experts_per_warp;
 
   for (int i = 0; i < experts_per_warp; ++i) {
-    if (my_expert_start + i < num_experts) {
+    if (my_expert_start + i < padded_num_experts) {
       shared_counts[warp_id * experts_per_warp + i] = 0;
     }
   }
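The kernel now zeroes its shared-memory counters up to `padded_num_experts` rather than `num_experts`, because the counters are laid out one warp-sized group at a time. A small Python sketch of the rounding rule (mirroring the host code below; `WARP_SIZE = 32` is the warp width assumed on NVIDIA hardware):

WARP_SIZE = 32  # warp width assumed by the kernel

def padded_expert_count(num_experts: int) -> int:
    # Round up to the next multiple of WARP_SIZE so every warp owns a full,
    # zero-initialized group of counters even when num_experts < 32.
    return ((num_experts + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE

assert padded_expert_count(8) == 32     # ep size < 32 now rounds up
assert padded_expert_count(256) == 256  # aligned counts are unchanged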
@@ -108,23 +109,44 @@ void moe_align_block_size(
     torch::Tensor token_cnts_buffer,
     torch::Tensor cumsum_buffer) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  TORCH_CHECK(num_experts % WARP_SIZE == 0);
-  int experts_per_warp = num_experts / WARP_SIZE;
+
+  int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+  int experts_per_warp;
+  int threads;
+
+  if (num_experts <= 8) {
+    experts_per_warp = 8;
+    threads = 256;
+  } else if (num_experts <= 16) {
+    experts_per_warp = 16;
+    threads = 512;
+  } else {
+    experts_per_warp = WARP_SIZE;
+    threads = 1024;
+  }
+
+  threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
 
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
     auto align_kernel = moe_align_block_size_kernel<scalar_t>;
-    size_t shared_mem_size = 32 * experts_per_warp * sizeof(int32_t);
-    align_kernel<<<1, 1024, shared_mem_size, stream>>>(
+
+    size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
+    size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
+
+    align_kernel<<<1, threads, shared_mem_size, stream>>>(
         topk_ids.data_ptr<scalar_t>(),
         sorted_token_ids.data_ptr<int32_t>(),
         experts_ids.data_ptr<int32_t>(),
         num_tokens_post_pad.data_ptr<int32_t>(),
         num_experts,
+        padded_num_experts,
         experts_per_warp,
         block_size,
         topk_ids.numel(),
         cumsum_buffer.data_ptr<int32_t>());
 
-    const int block_threads = 256;
+    const int block_threads = std::min(256, (int)threads);
     const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
     const int max_blocks = 65535;
     const int actual_blocks = std::min(num_blocks, max_blocks);
...
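For intuition, here is a Python mirror of the new host-side launch selection. This is not repository code; it assumes `CEILDIV` is plain ceiling division and `sizeof(int32_t)` is 4 bytes:

WARP_SIZE = 32

def ceildiv(a: int, b: int) -> int:
    return (a + b - 1) // b

def launch_config(num_experts: int):
    # Mirror of the C++ selection: fewer experts get a narrower per-warp
    # slot count and a smaller thread block.
    padded = ceildiv(num_experts, WARP_SIZE) * WARP_SIZE
    if num_experts <= 8:
        experts_per_warp, threads = 8, 256
    elif num_experts <= 16:
        experts_per_warp, threads = 16, 512
    else:
        experts_per_warp, threads = WARP_SIZE, 1024
    num_warps = ceildiv(padded, experts_per_warp)
    shared_mem_bytes = num_warps * experts_per_warp * 4  # sizeof(int32_t)
    return threads, shared_mem_bytes

# For 8 experts this yields 256 threads and 4 warps * 8 slots * 4 B = 128 B
# of dynamic shared memory, instead of the fixed 1024-thread / 32-warp
# layout the removed TORCH_CHECK used to require.
print(launch_config(8))    # (256, 128)
print(launch_config(256))  # (1024, 1024)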