#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

const static size_t NUM_MAX_EXPERTS = 64;

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)          \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)   \
  AT_DISPATCH_SWITCH(                                   \
      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                            int32_t* sorted_token_ids,
                                            int32_t* expert_ids,
                                            int32_t* total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size,
                                            size_t numel) {
  // Each thread counts a contiguous shard of the topk_ids array.
  const size_t tokens_per_thread = (numel + blockDim.x - 1) / blockDim.x;
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
  __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[threadIdx.x + 1][i] = 0;
  }

  // Step 1: tokens_cnts[t + 1][e] counts how many tokens in thread t's shard
  // are routed to expert e.
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
  }

  __syncthreads();

  // Step 2: for each expert (one per thread), prefix-sum the per-thread
  // counts so that tokens_cnts[t][e] becomes the number of tokens for expert
  // e held by threads 0..t-1.
  tokens_cnts[0][threadIdx.x] = 0;
  for (int i = 1; i <= blockDim.x; ++i) {
    tokens_cnts[i][threadIdx.x] += tokens_cnts[i - 1][threadIdx.x];
  }

  __syncthreads();

  // Step 3: thread 0 computes the padded cumulative offsets, rounding each
  // expert's token count up to a multiple of block_size.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  (tokens_cnts[blockDim.x][i - 1] + block_size - 1) /
                      block_size * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  // Step 4: each thread (one per expert) records which expert owns each
  // block_size-sized block in the padded layout.
  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
       i += block_size) {
    expert_ids[i / block_size] = threadIdx.x;
  }

  // Step 5: scatter each token of this thread's shard to its slot in the
  // expert-sorted, padded output. cumsum[expert_id] is the expert's start
  // offset; tokens_cnts[threadIdx.x][expert_id] is this token's running rank
  // among the expert's tokens counted by the preceding threads.
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad =
        tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[threadIdx.x][expert_id];
  }
}

void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
                          int block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // The shared-memory counters are statically sized for NUM_MAX_EXPERTS
  // experts, and the kernel launches one thread per expert.
  assert(num_experts <= NUM_MAX_EXPERTS);
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
        moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(),
            num_experts, block_size, topk_ids.numel());
      });
}
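
// Example usage (a minimal sketch, compiled out by default; the function name
// and the MOE_ALIGN_EXAMPLE guard are hypothetical, not part of this file's
// API). It shows how a caller might size the output buffers: the worst case
// is one partially filled block per expert, i.e. max_num_tokens_padded =
// topk_ids.numel() + num_experts * (block_size - 1), mirroring the sizing in
// vLLM's Python fused-MoE wrapper.
#ifdef MOE_ALIGN_EXAMPLE
void moe_align_block_size_example() {
  // 4 tokens routed to their top-2 experts out of 8 (illustrative values).
  const int num_experts = 8;
  const int block_size = 16;
  torch::Tensor topk_ids =
      torch::randint(/*low=*/0, /*high=*/num_experts, {4, 2},
                     torch::dtype(torch::kInt32).device(torch::kCUDA));

  const int64_t max_num_tokens_padded =
      topk_ids.numel() + num_experts * (block_size - 1);
  const int64_t max_num_blocks =
      (max_num_tokens_padded + block_size - 1) / block_size;
  auto opts = torch::dtype(torch::kInt32).device(topk_ids.device());

  // The kernel only writes the slots it uses, so padding slots keep the fill
  // value; pre-filling with topk_ids.numel() gives an out-of-range sentinel.
  torch::Tensor sorted_token_ids =
      torch::full({max_num_tokens_padded}, topk_ids.numel(), opts);
  torch::Tensor expert_ids = torch::empty({max_num_blocks}, opts);
  torch::Tensor num_tokens_post_pad = torch::empty({1}, opts);

  moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids,
                       expert_ids, num_tokens_post_pad);
}
#endif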