Commit 131fbe9b authored by 王敏's avatar 王敏
Browse files

[fix]修复moe模型专家并行kernel报错

parent f112086f
...@@ -263,6 +263,79 @@ __global__ void sgl_moe_align_block_size_kernel( ...@@ -263,6 +263,79 @@ __global__ void sgl_moe_align_block_size_kernel(
} }
} }
// taken from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
template <typename scalar_t>
__global__ void sgl_ep_moe_align_block_size_kernel(
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel, int32_t* cumsum,
int32_t start_expert, int32_t end_expert) {
__shared__ int32_t shared_counts[32][8];
__shared__ int32_t local_offsets[256];
const int warp_id = threadIdx.x / 32;
const int lane_id = threadIdx.x % 32;
const int experts_per_warp = 8;
const int my_expert_start = warp_id * experts_per_warp;
for (int i = 0; i < experts_per_warp; ++i) {
if (my_expert_start + i < num_experts) {
shared_counts[warp_id][i] = 0;
}
}
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int expert_id = topk_ids[i];
if (expert_id >= start_expert && expert_id < end_expert) {
expert_id -= start_expert;
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
}
}
__syncthreads();
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
int expert_count = 0;
int warp_idx = (i - 1) / experts_per_warp;
int expert_offset = (i - 1) % experts_per_warp;
expert_count = shared_counts[warp_idx][expert_offset];
cumsum[i] =
cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
}
__syncthreads();
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
local_offsets[threadIdx.x] = cumsum[threadIdx.x];
}
__syncthreads();
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
if (expert_id >= start_expert && expert_id < end_expert) {
expert_id -= start_expert;
int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
sorted_token_ids[rank_post_pad] = i;
}
}
}
template <typename scalar_t, int TOPK> template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel( __global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
...@@ -488,75 +561,55 @@ void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -488,75 +561,55 @@ void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
((num_thread + 1) * num_experts) * sizeof(uint16_t) + ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
(num_experts + 1) * sizeof(int32_t); (num_experts + 1) * sizeof(int32_t);
// bool use_global_memory = false; bool use_sgl_kernel = false;
// bool use_i16 = false; // Use uint16_t for shared memory token counts bool use_i16 = false; // Use uint16_t for shared memory token counts
// if (shared_mem_i32 < device_max_shared_mem) {
// // Do nothing in this case. We're all set to use int32_t token counts if (shared_mem_i32 < device_max_shared_mem) {
// } else if (shared_mem_i16 < device_max_shared_mem && // Do nothing in this case. We're all set to use int32_t token counts
// topk_ids.numel() <= 65535) { } else if (shared_mem_i16 < device_max_shared_mem &&
// // when nelements of topk_ids is smaller than 65535 (max value of uint16), topk_ids.numel() <= 65535) {
// // element value of token_cnts would also smaller than 65535, // when nelements of topk_ids is smaller than 65535 (max value of uint16),
// // so we can use uint16 as dtype of token_cnts // element value of token_cnts would also smaller than 65535,
// use_i16 = true; // so we can use uint16 as dtype of token_cnts
// } else { use_i16 = true;
// use_global_memory = true; } else {
// } use_sgl_kernel = true;
}
// if (use_global_memory) {
// VLLM_DISPATCH_INTEGRAL_TYPES( if (use_sgl_kernel) {
// topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { VLLM_DISPATCH_INTEGRAL_TYPES(
// // calc needed amount of shared mem for `tokens_cnts` and `cumsum` topk_ids.scalar_type(), "sgl_ep_moe_align_block_size_kernel", [&] {
// // tensors // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); // tensors
auto options_int =
// auto options_int = torch::TensorOptions() torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
// .dtype(torch::kInt)
// .device(topk_ids.device());
// torch::Tensor token_cnts_buffer = // torch::Tensor token_cnts_buffer =
// torch::empty({(num_experts + 1) * num_experts}, options_int); // torch::empty({(num_experts + 1) * num_experts}, options_int);
// torch::Tensor cumsum_buffer = torch::Tensor cumsum_buffer =
// torch::empty({num_experts + 1}, options_int); torch::empty({num_experts + 1}, options_int);
// auto kernel = auto kernel = vllm::moe::sgl_ep_moe_align_block_size_kernel<scalar_t>;
// vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>; kernel<<<1, 1024, 0, stream>>>(
// kernel<<<1, num_thread, 0, stream>>>( topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
// topk_ids.data_ptr<scalar_t>(), experts_ids.data_ptr<int32_t>(),
// sorted_token_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// experts_ids.data_ptr<int32_t>(), topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(), start_expert, end_expert);
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, });
// topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(), } else if (use_i16) {
// cumsum_buffer.data_ptr<int32_t>()); VLLM_DISPATCH_INTEGRAL_TYPES(
// }); topk_ids.scalar_type(), "ep_moe_align_block_size_kernel", [&] {
// } else if (use_i16) { auto kernel =
// VLLM_DISPATCH_INTEGRAL_TYPES( vllm::moe::ep_moe_align_block_size_kernel<scalar_t, uint16_t>;
// topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
// // set dynamic shared mem (void*)kernel, shared_mem_i16));
// auto kernel = kernel<<<1, num_thread, shared_mem_i16, stream>>>(
// vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>; topk_ids.data_ptr<scalar_t>(),
// AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( sorted_token_ids.data_ptr<int32_t>(),
// (void*)kernel, shared_mem_i16)); experts_ids.data_ptr<int32_t>(),
// kernel<<<1, num_thread, shared_mem_i16, stream>>>( num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.data_ptr<scalar_t>(), topk_ids.numel(), start_expert, end_expert);
// sorted_token_ids.data_ptr<int32_t>(), });
// experts_ids.data_ptr<int32_t>(), } else {
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.numel());
// });
// } else {
// VLLM_DISPATCH_INTEGRAL_TYPES(
// topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// auto kernel =
// vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
// AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
// (void*)kernel, shared_mem_i32));
// kernel<<<1, num_thread, shared_mem_i32, stream>>>(
// topk_ids.data_ptr<scalar_t>(),
// sorted_token_ids.data_ptr<int32_t>(),
// experts_ids.data_ptr<int32_t>(),
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.numel());
// });
// }
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "ep_moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "ep_moe_align_block_size_kernel", [&] {
auto kernel = auto kernel =
...@@ -570,6 +623,8 @@ void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -570,6 +623,8 @@ void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), start_expert, end_expert); topk_ids.numel(), start_expert, end_expert);
}); });
}
} }
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment