Commit 7d183246 authored by wujl5's avatar wujl5
Browse files

optimized sum kernel

parent 2f6f5bb3
...@@ -363,6 +363,35 @@ __global__ void moe_sum_kernel( ...@@ -363,6 +363,35 @@ __global__ void moe_sum_kernel(
} }
} }
template <typename scalar_t, int TOPK, int SPLIT_D, int BLOCK_DIM>
__global__ void moe_sum_sharedmem_topk8(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const int d) {
const int token_idx = blockIdx.x / SPLIT_D;
const int sub_block = blockIdx.x % SPLIT_D;
const int d_per_block = (d + SPLIT_D - 1) / SPLIT_D;
const int64_t d_start = sub_block * d_per_block;
const int64_t token_offset = token_idx * TOPK * d;
const int64_t d_end = min(d_start + d_per_block, d);
__shared__ __align__(16) scalar_t sem_input[TOPK][BLOCK_DIM];
for (int64_t idx = d_start + threadIdx.x; idx < d_end; idx += blockDim.x) {
sem_input[0][threadIdx.x] = input[token_offset + 0 * d + idx];
sem_input[1][threadIdx.x] = input[token_offset + 1 * d + idx];
sem_input[2][threadIdx.x] = input[token_offset + 2 * d + idx];
sem_input[3][threadIdx.x] = input[token_offset + 3 * d + idx];
sem_input[4][threadIdx.x] = input[token_offset + 4 * d + idx];
sem_input[5][threadIdx.x] = input[token_offset + 5 * d + idx];
sem_input[6][threadIdx.x] = input[token_offset + 6 * d + idx];
sem_input[7][threadIdx.x] = input[token_offset + 7 * d + idx];
__syncthreads();
scalar_t x = sem_input[0][threadIdx.x] + sem_input[1][threadIdx.x] + sem_input[2][threadIdx.x] +
sem_input[3][threadIdx.x] + sem_input[4][threadIdx.x] + sem_input[5][threadIdx.x] +
sem_input[6][threadIdx.x] + sem_input[7][threadIdx.x];
out[token_idx * d + idx] = x;
}
}
} // namespace moe } // namespace moe
} // namespace vllm } // namespace vllm
...@@ -504,6 +533,12 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] ...@@ -504,6 +533,12 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
constexpr int splitD_ = 8;
const int TOPK8_GRID_DIM = num_tokens * splitD_;
constexpr int TOPK8_BLOCK_DIM = 256;
dim3 grid_8(TOPK8_GRID_DIM);
dim3 block_8(TOPK8_BLOCK_DIM);
switch (topk) { switch (topk) {
case 2: case 2:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
...@@ -530,15 +565,15 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] ...@@ -530,15 +565,15 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
break; break;
case 8: case 8:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_sharedmem_topk8", [&]{
vllm::moe::moe_sum_kernel<scalar_t, 8><<<grid, block, 0, stream>>>( vllm::moe::moe_sum_sharedmem_topk8<scalar_t, 8, splitD_, TOPK8_BLOCK_DIM><<<grid_8, block_8, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size); hidden_size);
}); });
break; break;
default: default:
at::sum_out(output, input, 1); at::sum_out(output, input, 1);
break; break;
} }
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment