Commit ca034fd5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-fth' into 'v0.9.2-dev'

V0.9.2 dev fth

See merge request dcutoolkit/deeplearing/vllm!199
parents ed392378 59a345ac
......@@ -173,6 +173,35 @@ __global__ void moe_sum_kernel(
}
}
template <typename scalar_t, int TOPK, int SPLIT_D, int BLOCK_DIM>
__global__ void moe_sum_sharedmem_topk8(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const int d) {
const int token_idx = blockIdx.x / SPLIT_D;
const int sub_block = blockIdx.x % SPLIT_D;
const int d_per_block = (d + SPLIT_D - 1) / SPLIT_D;
const int64_t d_start = sub_block * d_per_block;
const int64_t token_offset = token_idx * TOPK * d;
const int64_t d_end = min(d_start + d_per_block, d);
__shared__ __align__(16) scalar_t sem_input[TOPK][BLOCK_DIM];
for (int64_t idx = d_start + threadIdx.x; idx < d_end; idx += blockDim.x) {
sem_input[0][threadIdx.x] = input[token_offset + 0 * d + idx];
sem_input[1][threadIdx.x] = input[token_offset + 1 * d + idx];
sem_input[2][threadIdx.x] = input[token_offset + 2 * d + idx];
sem_input[3][threadIdx.x] = input[token_offset + 3 * d + idx];
sem_input[4][threadIdx.x] = input[token_offset + 4 * d + idx];
sem_input[5][threadIdx.x] = input[token_offset + 5 * d + idx];
sem_input[6][threadIdx.x] = input[token_offset + 6 * d + idx];
sem_input[7][threadIdx.x] = input[token_offset + 7 * d + idx];
__syncthreads();
scalar_t x = sem_input[0][threadIdx.x] + sem_input[1][threadIdx.x] + sem_input[2][threadIdx.x] +
sem_input[3][threadIdx.x] + sem_input[4][threadIdx.x] + sem_input[5][threadIdx.x] +
sem_input[6][threadIdx.x] + sem_input[7][threadIdx.x];
out[token_idx * d + idx] = x;
}
}
template <typename scalar_t>
__global__ void moe_align_block_size_small_batch_expert_kernel(
const scalar_t* __restrict__ topk_ids,
......@@ -358,3 +387,64 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
break;
}
}
void moe_sum_opt1(torch::Tensor& input, // [num_tokens, topk, hidden_size]
torch::Tensor& output) // [num_tokens, hidden_size]
{
const int hidden_size = input.size(-1);
const auto num_tokens = output.numel() / hidden_size;
const int topk = input.size(1);
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
constexpr int splitD_ = 8;
const int TOPK8_GRID_DIM = num_tokens * splitD_;
constexpr int TOPK8_BLOCK_DIM = 256;
dim3 grid_8(TOPK8_GRID_DIM);
dim3 block_8(TOPK8_BLOCK_DIM);
switch (topk) {
case 2:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 3:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 4:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 8:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_sharedmem_topk8", [&]{
vllm::moe::moe_sum_sharedmem_topk8<scalar_t, 8, splitD_, TOPK8_BLOCK_DIM><<<grid_8, block_8, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
default:
at::sum_out(output, input, 1);
break;
}
}
\ No newline at end of file
......@@ -7,6 +7,7 @@ void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& gating_output);
void moe_sum(torch::Tensor& input, torch::Tensor& output);
void moe_sum_opt1(torch::Tensor& input, torch::Tensor& output);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
......
......@@ -11,8 +11,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
m.def("moe_sum_opt1(Tensor input, Tensor! output) -> ()");
m.impl("moe_sum", torch::kCUDA, &moe_sum);
m.impl("moe_sum_opt1", torch::kCUDA, &moe_sum_opt1);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m.def(
......
......@@ -1971,7 +1971,8 @@ def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
# moe
def moe_sum(input: torch.Tensor, output: torch.Tensor):
torch.ops._moe_C.moe_sum(input, output)
def moe_sum_opt1(input: torch.Tensor, output: torch.Tensor):
torch.ops._moe_C.moe_sum_opt1(input, output)
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment