Unverified Commit 0ada960a authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

[Kernel] Support bias type in grouped_topk kernel (#31781)


Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent c907d221
...@@ -457,8 +457,8 @@ __device__ inline T apply_scoring(T val) { ...@@ -457,8 +457,8 @@ __device__ inline T apply_scoring(T val) {
} }
} }
template <typename T, ScoringFunc SF> template <typename T, typename BiasT, ScoringFunc SF>
__device__ void topk_with_k2(T* output, T const* input, T const* bias, __device__ void topk_with_k2(T* output, T const* input, BiasT const* bias,
cg::thread_block_tile<32> const& tile, cg::thread_block_tile<32> const& tile,
int32_t const lane_id, int32_t const lane_id,
int const num_experts_per_group) { int const num_experts_per_group) {
...@@ -469,7 +469,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, ...@@ -469,7 +469,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
if (num_experts_per_group > WARP_SIZE) { if (num_experts_per_group > WARP_SIZE) {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = apply_scoring<SF>(input[i]); T value = apply_scoring<SF>(input[i]);
value = value + bias[i]; value = value + static_cast<T>(bias[i]);
if (value > largest) { if (value > largest) {
second_largest = largest; second_largest = largest;
...@@ -481,7 +481,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, ...@@ -481,7 +481,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
} else { } else {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = apply_scoring<SF>(input[i]); T value = apply_scoring<SF>(input[i]);
value = value + bias[i]; value = value + static_cast<T>(bias[i]);
largest = value; largest = value;
} }
} }
...@@ -503,8 +503,8 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, ...@@ -503,8 +503,8 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
} }
} }
template <typename T, ScoringFunc SF> template <typename T, typename BiasT, ScoringFunc SF>
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, __global__ void topk_with_k2_kernel(T* output, T* input, BiasT const* bias,
int64_t const num_tokens, int64_t const num_tokens,
int64_t const num_cases, int64_t const num_cases,
int64_t const n_group, int64_t const n_group,
...@@ -517,7 +517,7 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, ...@@ -517,7 +517,7 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
input += case_id * num_experts_per_group; input += case_id * num_experts_per_group;
// bias is per expert group, offset to current group // bias is per expert group, offset to current group
int32_t group_id = case_id % n_group; int32_t group_id = case_id % n_group;
T const* group_bias = bias + group_id * num_experts_per_group; BiasT const* group_bias = bias + group_id * num_experts_per_group;
output += case_id; output += case_id;
cg::thread_block block = cg::this_thread_block(); cg::thread_block block = cg::this_thread_block();
...@@ -526,18 +526,19 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, ...@@ -526,18 +526,19 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;"); asm volatile("griddepcontrol.wait;");
#endif #endif
topk_with_k2<T, SF>(output, input, group_bias, tile, lane_id, topk_with_k2<T, BiasT, SF>(output, input, group_bias, tile, lane_id,
num_experts_per_group); num_experts_per_group);
} }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;"); asm volatile("griddepcontrol.launch_dependents;");
#endif #endif
} }
template <typename T, typename IdxT, ScoringFunc SF, int NGroup = -1> template <typename T, typename BiasT, typename IdxT, ScoringFunc SF,
int NGroup = -1>
__global__ void group_idx_and_topk_idx_kernel( __global__ void group_idx_and_topk_idx_kernel(
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices, T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
T const* bias, int64_t const num_tokens, int64_t const n_group, BiasT const* bias, int64_t const num_tokens, int64_t const n_group,
int64_t const topk_group, int64_t const topk, int64_t const num_experts, int64_t const topk_group, int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool renormalize, int64_t const num_experts_per_group, bool renormalize,
double routed_scaling_factor) { double routed_scaling_factor) {
...@@ -623,7 +624,7 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -623,7 +624,7 @@ __global__ void group_idx_and_topk_idx_kernel(
T input = scores[offset + i]; T input = scores[offset + i];
if (is_finite(input)) { if (is_finite(input)) {
T score = apply_scoring<SF>(input); T score = apply_scoring<SF>(input);
candidates = score + bias[offset + i]; candidates = score + static_cast<T>(bias[offset + i]);
} }
} }
queue.add(candidates, offset + i); queue.add(candidates, offset + i);
...@@ -698,10 +699,10 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -698,10 +699,10 @@ __global__ void group_idx_and_topk_idx_kernel(
#endif #endif
} }
template <typename T, typename IdxT, ScoringFunc SF> template <typename T, typename BiasT, typename IdxT, ScoringFunc SF>
inline void launch_group_idx_and_topk_kernel( inline void launch_group_idx_and_topk_kernel(
cudaLaunchConfig_t const& config, T* scores, T* group_scores, cudaLaunchConfig_t const& config, T* scores, T* group_scores,
float* topk_values, IdxT* topk_indices, T const* bias, float* topk_values, IdxT* topk_indices, BiasT const* bias,
int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
int64_t const topk, int64_t const num_experts, int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool const renormalize, int64_t const num_experts_per_group, bool const renormalize,
...@@ -715,36 +716,36 @@ inline void launch_group_idx_and_topk_kernel( ...@@ -715,36 +716,36 @@ inline void launch_group_idx_and_topk_kernel(
switch (n_group) { switch (n_group) {
case 4: { case 4: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>); launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 4>);
break; break;
} }
case 8: { case 8: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>); launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 8>);
break; break;
} }
case 16: { case 16: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>); launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 16>);
break; break;
} }
case 32: { case 32: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>); launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 32>);
break; break;
} }
default: { default: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF>); launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF>);
break; break;
} }
} }
} }
template <typename T, typename IdxT> template <typename T, typename BiasT, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
IdxT* topk_indices, T const* bias, int64_t const num_tokens, IdxT* topk_indices, BiasT const* bias,
int64_t const num_experts, int64_t const n_group, int64_t const num_tokens, int64_t const num_experts,
int64_t const topk_group, int64_t const topk, int64_t const n_group, int64_t const topk_group,
bool const renormalize, double const routed_scaling_factor, int64_t const topk, bool const renormalize,
int const scoring_func, bool enable_pdl = false, double const routed_scaling_factor, int const scoring_func,
cudaStream_t const stream = 0) { bool enable_pdl = false, cudaStream_t const stream = 0) {
int64_t num_cases = num_tokens * n_group; int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
cudaLaunchConfig_t config; cudaLaunchConfig_t config;
...@@ -765,12 +766,12 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, ...@@ -765,12 +766,12 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
}; };
switch (sf) { switch (sf) {
case SCORING_NONE: { case SCORING_NONE: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_NONE>; auto* kernel_instance1 = &topk_with_k2_kernel<T, BiasT, SCORING_NONE>;
launch_topk_with_k2(kernel_instance1); launch_topk_with_k2(kernel_instance1);
break; break;
} }
case SCORING_SIGMOID: { case SCORING_SIGMOID: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_SIGMOID>; auto* kernel_instance1 = &topk_with_k2_kernel<T, BiasT, SCORING_SIGMOID>;
launch_topk_with_k2(kernel_instance1); launch_topk_with_k2(kernel_instance1);
break; break;
} }
...@@ -794,14 +795,14 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, ...@@ -794,14 +795,14 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
config.attrs = attrs; config.attrs = attrs;
switch (sf) { switch (sf) {
case SCORING_NONE: { case SCORING_NONE: {
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>( launch_group_idx_and_topk_kernel<T, BiasT, IdxT, SCORING_NONE>(
config, scores, group_scores, topk_values, topk_indices, bias, config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts, num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor); num_experts_per_group, renormalize, routed_scaling_factor);
break; break;
} }
case SCORING_SIGMOID: { case SCORING_SIGMOID: {
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>( launch_group_idx_and_topk_kernel<T, BiasT, IdxT, SCORING_SIGMOID>(
config, scores, group_scores, topk_values, topk_indices, bias, config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts, num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor); num_experts_per_group, renormalize, routed_scaling_factor);
...@@ -812,17 +813,23 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, ...@@ -812,17 +813,23 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
} }
} }
#define INSTANTIATE_NOAUX_TC(T, IdxT) \ #define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \
template void invokeNoAuxTc<T, IdxT>( \ template void invokeNoAuxTc<T, BiasT, IdxT>( \
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \ T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
T const* bias, int64_t const num_tokens, int64_t const num_experts, \ BiasT const* bias, int64_t const num_tokens, int64_t const num_experts, \
int64_t const n_group, int64_t const topk_group, int64_t const topk, \ int64_t const n_group, int64_t const topk_group, int64_t const topk, \
bool const renormalize, double const routed_scaling_factor, \ bool const renormalize, double const routed_scaling_factor, \
int const scoring_func, bool enable_pdl, cudaStream_t const stream); int const scoring_func, bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float, int32_t); INSTANTIATE_NOAUX_TC(float, float, int32_t);
INSTANTIATE_NOAUX_TC(half, int32_t); INSTANTIATE_NOAUX_TC(float, half, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t); INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t);
INSTANTIATE_NOAUX_TC(half, float, int32_t);
INSTANTIATE_NOAUX_TC(half, half, int32_t);
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t);
} // end namespace moe } // end namespace moe
} // namespace vllm } // namespace vllm
...@@ -831,6 +838,7 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk( ...@@ -831,6 +838,7 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
int64_t topk, bool renormalize, double routed_scaling_factor, int64_t topk, bool renormalize, double routed_scaling_factor,
torch::Tensor const& bias, int64_t scoring_func = 0) { torch::Tensor const& bias, int64_t scoring_func = 0) {
auto data_type = scores.scalar_type(); auto data_type = scores.scalar_type();
auto bias_type = bias.scalar_type();
auto input_size = scores.sizes(); auto input_size = scores.sizes();
int64_t num_tokens = input_size[0]; int64_t num_tokens = input_size[0];
int64_t num_experts = input_size[1]; int64_t num_experts = input_size[1];
...@@ -854,39 +862,62 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk( ...@@ -854,39 +862,62 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device()); auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
#define LAUNCH_KERNEL(T, IdxT) \
do { \
switch (bias_type) { \
case torch::kFloat16: \
vllm::moe::invokeNoAuxTc<T, half, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
case torch::kFloat32: \
vllm::moe::invokeNoAuxTc<T, float, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
case torch::kBFloat16: \
vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \
num_tokens, num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
default: \
throw std::invalid_argument( \
"Invalid bias dtype, only supports float16, float32, and " \
"bfloat16"); \
break; \
} \
} while (0)
switch (data_type) { switch (data_type) {
case torch::kFloat16: case torch::kFloat16:
// Handle Float16 // Handle Float16
vllm::moe::invokeNoAuxTc<half, int32_t>( LAUNCH_KERNEL(half, int32_t);
reinterpret_cast<half*>(scores.mutable_data_ptr()),
reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break; break;
case torch::kFloat32: case torch::kFloat32:
// Handle Float32 // Handle Float32
vllm::moe::invokeNoAuxTc<float, int32_t>( LAUNCH_KERNEL(float, int32_t);
reinterpret_cast<float*>(scores.mutable_data_ptr()),
reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break; break;
case torch::kBFloat16: case torch::kBFloat16:
// Handle BFloat16 // Handle BFloat16
vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>( LAUNCH_KERNEL(__nv_bfloat16, int32_t);
reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break; break;
default: default:
// Handle other data types // Handle other data types
...@@ -894,5 +925,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk( ...@@ -894,5 +925,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
"Invalid dtype, only supports float16, float32, and bfloat16"); "Invalid dtype, only supports float16, float32, and bfloat16");
break; break;
} }
#undef LAUNCH_KERNEL
return {topk_values, topk_indices}; return {topk_values, topk_indices};
} }
...@@ -34,7 +34,8 @@ from vllm.utils.torch_utils import set_random_seed ...@@ -34,7 +34,8 @@ from vllm.utils.torch_utils import set_random_seed
@pytest.mark.parametrize("topk_group", [2]) @pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) @pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) @pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("bias_dtype", [torch.float32])
def test_grouped_topk( def test_grouped_topk(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
n_token: int, n_token: int,
...@@ -46,7 +47,8 @@ def test_grouped_topk( ...@@ -46,7 +47,8 @@ def test_grouped_topk(
topk_group: int, topk_group: int,
scoring_func: str, scoring_func: str,
routed_scaling_factor: float, routed_scaling_factor: float,
dtype: torch.dtype, input_dtype: torch.dtype,
bias_dtype: torch.dtype,
): ):
vllm_config = VllmConfig( vllm_config = VllmConfig(
compilation_config=CompilationConfig(custom_ops=["all", "+grouped_topk"]) compilation_config=CompilationConfig(custom_ops=["all", "+grouped_topk"])
...@@ -54,11 +56,9 @@ def test_grouped_topk( ...@@ -54,11 +56,9 @@ def test_grouped_topk(
get_cached_compilation_config.cache_clear() get_cached_compilation_config.cache_clear()
set_random_seed(0) set_random_seed(0)
hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda") hidden_states = torch.randn((n_token, n_hidden), dtype=input_dtype, device="cuda")
gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda") gating_output = torch.randn((n_token, n_expert), dtype=input_dtype, device="cuda")
e_score_correction_bias = torch.randn( e_score_correction_bias = torch.randn((n_expert,), dtype=bias_dtype, device="cuda")
(n_expert,), dtype=torch.float32, device="cuda"
)
with set_current_vllm_config(vllm_config), monkeypatch.context() as m: with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
......
...@@ -1627,7 +1627,7 @@ def fused_grouped_topk( ...@@ -1627,7 +1627,7 @@ def fused_grouped_topk(
topk, topk,
renormalize, renormalize,
routed_scaling_factor, routed_scaling_factor,
e_score_correction_bias.to(gating_output.dtype), e_score_correction_bias,
1, # scoring_func=1 for sigmoid 1, # scoring_func=1 for sigmoid
) )
elif scoring_func == "softmax": elif scoring_func == "softmax":
...@@ -1641,7 +1641,7 @@ def fused_grouped_topk( ...@@ -1641,7 +1641,7 @@ def fused_grouped_topk(
topk, topk,
renormalize, renormalize,
routed_scaling_factor, routed_scaling_factor,
e_score_correction_bias.to(gating_output.dtype), e_score_correction_bias,
0, # scoring_func=0 (no activation, scores already computed) 0, # scoring_func=0 (no activation, scores already computed)
) )
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment