Commit 863176e5 authored by SAC_fanth's avatar SAC_fanth
Browse files

增加reduce修改

parent a5d54d38
...@@ -173,6 +173,35 @@ __global__ void moe_sum_kernel( ...@@ -173,6 +173,35 @@ __global__ void moe_sum_kernel(
} }
} }
template <typename scalar_t, int TOPK, int SPLIT_D, int BLOCK_DIM>
__global__ void moe_sum_sharedmem_topk8(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const int d) {
const int token_idx = blockIdx.x / SPLIT_D;
const int sub_block = blockIdx.x % SPLIT_D;
const int d_per_block = (d + SPLIT_D - 1) / SPLIT_D;
const int64_t d_start = sub_block * d_per_block;
const int64_t token_offset = token_idx * TOPK * d;
const int64_t d_end = min(d_start + d_per_block, d);
__shared__ __align__(16) scalar_t sem_input[TOPK][BLOCK_DIM];
for (int64_t idx = d_start + threadIdx.x; idx < d_end; idx += blockDim.x) {
sem_input[0][threadIdx.x] = input[token_offset + 0 * d + idx];
sem_input[1][threadIdx.x] = input[token_offset + 1 * d + idx];
sem_input[2][threadIdx.x] = input[token_offset + 2 * d + idx];
sem_input[3][threadIdx.x] = input[token_offset + 3 * d + idx];
sem_input[4][threadIdx.x] = input[token_offset + 4 * d + idx];
sem_input[5][threadIdx.x] = input[token_offset + 5 * d + idx];
sem_input[6][threadIdx.x] = input[token_offset + 6 * d + idx];
sem_input[7][threadIdx.x] = input[token_offset + 7 * d + idx];
__syncthreads();
scalar_t x = sem_input[0][threadIdx.x] + sem_input[1][threadIdx.x] + sem_input[2][threadIdx.x] +
sem_input[3][threadIdx.x] + sem_input[4][threadIdx.x] + sem_input[5][threadIdx.x] +
sem_input[6][threadIdx.x] + sem_input[7][threadIdx.x];
out[token_idx * d + idx] = x;
}
}
template <typename scalar_t> template <typename scalar_t>
__global__ void moe_align_block_size_small_batch_expert_kernel( __global__ void moe_align_block_size_small_batch_expert_kernel(
const scalar_t* __restrict__ topk_ids, const scalar_t* __restrict__ topk_ids,
...@@ -353,6 +382,67 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] ...@@ -353,6 +382,67 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
}); });
break; break;
default:
at::sum_out(output, input, 1);
break;
}
}
void moe_sum_opt1(torch::Tensor& input, // [num_tokens, topk, hidden_size]
torch::Tensor& output) // [num_tokens, hidden_size]
{
const int hidden_size = input.size(-1);
const auto num_tokens = output.numel() / hidden_size;
const int topk = input.size(1);
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
constexpr int splitD_ = 8;
const int TOPK8_GRID_DIM = num_tokens * splitD_;
constexpr int TOPK8_BLOCK_DIM = 256;
dim3 grid_8(TOPK8_GRID_DIM);
dim3 block_8(TOPK8_BLOCK_DIM);
switch (topk) {
case 2:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 3:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 4:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 8:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_sharedmem_topk8", [&]{
vllm::moe::moe_sum_sharedmem_topk8<scalar_t, 8, splitD_, TOPK8_BLOCK_DIM><<<grid_8, block_8, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
default: default:
at::sum_out(output, input, 1); at::sum_out(output, input, 1);
break; break;
......
...@@ -7,6 +7,7 @@ void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, ...@@ -7,6 +7,7 @@ void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& gating_output); torch::Tensor& gating_output);
void moe_sum(torch::Tensor& input, torch::Tensor& output); void moe_sum(torch::Tensor& input, torch::Tensor& output);
void moe_sum_opt1(torch::Tensor& input, torch::Tensor& output);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids, int64_t block_size, torch::Tensor sorted_token_ids,
......
...@@ -11,8 +11,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -11,8 +11,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Calculate the result of moe by summing up the partial results // Calculate the result of moe by summing up the partial results
// from all selected experts. // from all selected experts.
m.def("moe_sum(Tensor input, Tensor! output) -> ()"); m.def("moe_sum(Tensor input, Tensor! output) -> ()");
m.def("moe_sum_opt1(Tensor input, Tensor! output) -> ()");
m.impl("moe_sum", torch::kCUDA, &moe_sum); m.impl("moe_sum", torch::kCUDA, &moe_sum);
m.impl("moe_sum_opt1", torch::kCUDA, &moe_sum_opt1);
// Aligning the number of tokens to be processed by each expert such // Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size. // that it is divisible by the block size.
m.def( m.def(
......
...@@ -1971,7 +1971,8 @@ def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, ...@@ -1971,7 +1971,8 @@ def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
# moe # moe
def moe_sum(input: torch.Tensor, output: torch.Tensor): def moe_sum(input: torch.Tensor, output: torch.Tensor):
torch.ops._moe_C.moe_sum(input, output) torch.ops._moe_C.moe_sum(input, output)
def moe_sum_opt1(input: torch.Tensor, output: torch.Tensor):
torch.ops._moe_C.moe_sum_opt1(input, output)
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor, block_size: int, sorted_token_ids: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try: try:
from ._version import __version__, __version_tuple__ __version__ = "0.9.2"
__version_tuple__ = (0, 9, 2)
__hcu_version__ = f'0.9.2+das.opt1.rc1.a5d54d3.dtk25041'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e: except Exception as e:
import warnings import warnings
warnings.warn(f"Failed to read commit hash:\n{e}", warnings.warn(f"Failed to read commit hash:\n + str(e)",
RuntimeWarning, RuntimeWarning,
stacklevel=2) stacklevel=2)
__version__ = "dev" __version__ = "dev"
__version_tuple__ = (0, 0, __version__) __version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str): def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version. '''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version. Return True if version_str matches the previous minor version.
...@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str): ...@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'. supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version. Used for --show-hidden-metrics-for-version.
""" '''
# Match anything if this is a dev tree # Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0): if __version_tuple__[0:2] == (0, 0):
return True return True
# Note - this won't do the right thing when we release 1.0! # Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0 # assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int) assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}" return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version(): def _prev_minor_version():
"""For the purpose of testing, return a previous minor version number.""" '''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine" # In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int) assert isinstance(__version_tuple__[1], int)
return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}" return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment