Commit 8b546443 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev_wm' into 'v0.7.2-dev'

[fix]修复moe模型专家并行kernel报错

See merge request dcutoolkit/deeplearing/vllm!66
parents f112086f 131fbe9b
......@@ -263,6 +263,79 @@ __global__ void sgl_moe_align_block_size_kernel(
}
}
// taken from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
template <typename scalar_t>
__global__ void sgl_ep_moe_align_block_size_kernel(
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel, int32_t* cumsum,
int32_t start_expert, int32_t end_expert) {
__shared__ int32_t shared_counts[32][8];
__shared__ int32_t local_offsets[256];
const int warp_id = threadIdx.x / 32;
const int lane_id = threadIdx.x % 32;
const int experts_per_warp = 8;
const int my_expert_start = warp_id * experts_per_warp;
for (int i = 0; i < experts_per_warp; ++i) {
if (my_expert_start + i < num_experts) {
shared_counts[warp_id][i] = 0;
}
}
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int expert_id = topk_ids[i];
if (expert_id >= start_expert && expert_id < end_expert) {
expert_id -= start_expert;
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
}
}
__syncthreads();
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
int expert_count = 0;
int warp_idx = (i - 1) / experts_per_warp;
int expert_offset = (i - 1) % experts_per_warp;
expert_count = shared_counts[warp_idx][expert_offset];
cumsum[i] =
cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
}
__syncthreads();
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
local_offsets[threadIdx.x] = cumsum[threadIdx.x];
}
__syncthreads();
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
if (expert_id >= start_expert && expert_id < end_expert) {
expert_id -= start_expert;
int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
sorted_token_ids[rank_post_pad] = i;
}
}
}
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
......@@ -488,75 +561,55 @@ void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
((num_thread + 1) * num_experts) * sizeof(uint16_t) +
(num_experts + 1) * sizeof(int32_t);
// bool use_global_memory = false;
// bool use_i16 = false; // Use uint16_t for shared memory token counts
// if (shared_mem_i32 < device_max_shared_mem) {
// // Do nothing in this case. We're all set to use int32_t token counts
// } else if (shared_mem_i16 < device_max_shared_mem &&
// topk_ids.numel() <= 65535) {
// // when nelements of topk_ids is smaller than 65535 (max value of uint16),
// // element value of token_cnts would also smaller than 65535,
// // so we can use uint16 as dtype of token_cnts
// use_i16 = true;
// } else {
// use_global_memory = true;
// }
// if (use_global_memory) {
// VLLM_DISPATCH_INTEGRAL_TYPES(
// topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// // tensors
// const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
// auto options_int = torch::TensorOptions()
// .dtype(torch::kInt)
// .device(topk_ids.device());
bool use_sgl_kernel = false;
bool use_i16 = false; // Use uint16_t for shared memory token counts
if (shared_mem_i32 < device_max_shared_mem) {
// Do nothing in this case. We're all set to use int32_t token counts
} else if (shared_mem_i16 < device_max_shared_mem &&
topk_ids.numel() <= 65535) {
// when nelements of topk_ids is smaller than 65535 (max value of uint16),
// element value of token_cnts would also smaller than 65535,
// so we can use uint16 as dtype of token_cnts
use_i16 = true;
} else {
use_sgl_kernel = true;
}
if (use_sgl_kernel) {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "sgl_ep_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
// torch::Tensor token_cnts_buffer =
// torch::empty({(num_experts + 1) * num_experts}, options_int);
// torch::Tensor cumsum_buffer =
// torch::empty({num_experts + 1}, options_int);
// auto kernel =
// vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
// kernel<<<1, num_thread, 0, stream>>>(
// topk_ids.data_ptr<scalar_t>(),
// sorted_token_ids.data_ptr<int32_t>(),
// experts_ids.data_ptr<int32_t>(),
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
// cumsum_buffer.data_ptr<int32_t>());
// });
// } else if (use_i16) {
// VLLM_DISPATCH_INTEGRAL_TYPES(
// topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// // set dynamic shared mem
// auto kernel =
// vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
// AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
// (void*)kernel, shared_mem_i16));
// kernel<<<1, num_thread, shared_mem_i16, stream>>>(
// topk_ids.data_ptr<scalar_t>(),
// sorted_token_ids.data_ptr<int32_t>(),
// experts_ids.data_ptr<int32_t>(),
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.numel());
// });
// } else {
// VLLM_DISPATCH_INTEGRAL_TYPES(
// topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// auto kernel =
// vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
// AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
// (void*)kernel, shared_mem_i32));
// kernel<<<1, num_thread, shared_mem_i32, stream>>>(
// topk_ids.data_ptr<scalar_t>(),
// sorted_token_ids.data_ptr<int32_t>(),
// experts_ids.data_ptr<int32_t>(),
// num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
// topk_ids.numel());
// });
// }
torch::Tensor cumsum_buffer =
torch::empty({num_experts + 1}, options_int);
auto kernel = vllm::moe::sgl_ep_moe_align_block_size_kernel<scalar_t>;
kernel<<<1, 1024, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(), start_expert, end_expert);
});
} else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "ep_moe_align_block_size_kernel", [&] {
auto kernel =
vllm::moe::ep_moe_align_block_size_kernel<scalar_t, uint16_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem_i16));
kernel<<<1, num_thread, shared_mem_i16, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), start_expert, end_expert);
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "ep_moe_align_block_size_kernel", [&] {
auto kernel =
......@@ -570,6 +623,8 @@ void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), start_expert, end_expert);
});
}
}
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment