Unverified commit 63227acc, authored by Xin Yang, committed by GitHub
Browse files

[Kernel] Add topk_sigmoid kernel (#31246)


Signed-off-by: Xin Yang <xyangx@amazon.com>
parent e675dda6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import torch
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Benchmark sweep: every (num_tokens, num_experts, topk) combination below is
# measured once per provider.
num_tokens_range = [1 << shift for shift in range(0, 8, 2)]  # 1, 4, 16, 64
num_experts_range = [16, 32, 64, 128, 256, 512]
topk_range = [3, 4]
configs = [*itertools.product(num_tokens_range, num_experts_range, topk_range)]
def torch_topk(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    scoring_func: str = "softmax",
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference top-k expert routing built from plain PyTorch ops.

    Used as the "torch" baseline that the fused kernel is benchmarked against.

    Args:
        gating_output: Router logits; experts lie on the last dimension.
        topk: Number of experts to keep per row.
        renormalize: If true, rescale the kept weights so that they sum to 1
            along the last dimension.
        scoring_func: ``"softmax"`` or ``"sigmoid"``.

    Returns:
        ``(topk_weights, topk_ids)`` as produced by ``torch.topk`` over the
        float32 scores.

    Raises:
        ValueError: If ``scoring_func`` is not one of the supported names.
    """
    logits = gating_output.float()
    if scoring_func == "softmax":
        scores = torch.softmax(logits, dim=-1)
    elif scoring_func == "sigmoid":
        scores = torch.sigmoid(logits)
    else:
        # Fail loudly instead of silently benchmarking sigmoid for a typo.
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
def get_benchmark(scoring_func):
    """Build the Triton perf-report benchmark for one scoring function.

    Args:
        scoring_func: Scoring function name forwarded to both providers
            (``torch_topk`` and ``fused_topk``); also used in the plot name.

    Returns:
        The ``triton.testing.perf_report``-decorated benchmark; call its
        ``run(...)`` method to sweep every entry of ``configs`` for the
        "torch" and "vllm" providers.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens", "num_experts", "topk"],
            x_vals=[list(cfg) for cfg in configs],
            line_arg="provider",
            line_vals=["torch", "vllm"],
            line_names=["Torch", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"fused-topk-perf-{scoring_func}",
            args={},
        )
    )
    def benchmark(num_tokens, num_experts, topk, provider):
        dtype = torch.bfloat16
        hidden_size = 1024
        renormalize = True

        hidden_states = torch.randn(
            (num_tokens, hidden_size), dtype=dtype, device="cuda"
        )
        gating_output = torch.randn(
            (num_tokens, num_experts), dtype=dtype, device="cuda"
        )

        # Requested in the order median, lower (20%), upper (80%); the
        # unpacking below relies on this order.
        quantiles = [0.5, 0.2, 0.8]

        if provider == "torch":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: torch_topk(
                    gating_output=gating_output,
                    topk=topk,
                    renormalize=renormalize,
                    scoring_func=scoring_func,
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fused_topk(
                    hidden_states=hidden_states,
                    gating_output=gating_output,
                    topk=topk,
                    renormalize=renormalize,
                    scoring_func=scoring_func,
                ),
                quantiles=quantiles,
            )

        # Scale by 1000 to match the "us" ylabel. perf_report unpacks the
        # result as (value, min, max); since this is a latency (not a
        # throughput), the lower quantile must stay in the "min" slot.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark
if __name__ == "__main__":
    # Command-line options: which scoring function to benchmark and where the
    # perf report should be written.
    cli = FlexibleArgumentParser(description="Benchmark the MoE topk kernel.")
    cli.add_argument("--scoring-func", default="softmax", type=str)
    cli.add_argument("--save-path", default="./configs/fused_topk/", type=str)
    options = cli.parse_args()

    # Build the benchmark for the requested scoring function and run the sweep.
    perf_benchmark = get_benchmark(options.scoring_func)
    perf_benchmark.run(print_data=True, save_path=options.save_path)
...@@ -4,7 +4,13 @@ ...@@ -4,7 +4,13 @@
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices, torch::Tensor& token_expert_indices,
torch::Tensor& gating_output, bool renormalize); torch::Tensor& gating_output, bool renormalize,
std::optional<torch::Tensor> bias);
void topk_sigmoid(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output, bool renormalize,
std::optional<torch::Tensor> bias);
void moe_sum(torch::Tensor& input, torch::Tensor& output); void moe_sum(torch::Tensor& input, torch::Tensor& output);
......
...@@ -62,6 +62,12 @@ __device__ __forceinline__ float toFloat(T value) { ...@@ -62,6 +62,12 @@ __device__ __forceinline__ float toFloat(T value) {
} }
} }
// Scoring function applied to the gating logits before top-k selection.
// It is passed as a non-type template parameter (`ScoringFunc SF`) to the
// gating kernel and its launchers, which branch on it with `if constexpr`.
enum ScoringFunc {
SCORING_SOFTMAX = 0, // scores = softmax of the row of gating logits
SCORING_SIGMOID = 1 // scores = element-wise sigmoid, 1 / (1 + exp(-x))
};
// ====================== Softmax things =============================== // ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output // We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing. // in the softmax kernel when we extend this module to support expert-choice routing.
...@@ -125,6 +131,27 @@ __launch_bounds__(TPB) __global__ ...@@ -125,6 +131,27 @@ __launch_bounds__(TPB) __global__
} }
} }
// Element-wise sigmoid over a row-major [rows, num_cols] input.
//
// Block `blockIdx.x` handles row `blockIdx.x`; its threads walk the columns
// starting at `threadIdx.x` in steps of TPB. Inputs are converted to float via
// toFloat() and the float result is written to `output` at the same flat index.
// `finished` may be nullptr; when it is provided, rows whose flag is set are
// skipped and their output is left unmodified.
template <int TPB, typename InputType>
__launch_bounds__(TPB) __global__
    void moeSigmoid(const InputType* input, const bool* finished, float* output, const int num_cols)
{
    // Rows flagged as finished are left untouched.
    const bool row_is_finished = (finished != nullptr) && finished[blockIdx.x];
    if (row_is_finished)
    {
        return;
    }

    // Flat offset of the first element of this block's row.
    const int row_base = blockIdx.x * num_cols;

    int col = threadIdx.x;
    while (col < num_cols)
    {
        const int flat_idx = row_base + col;
        const float x = toFloat(input[flat_idx]);
        output[flat_idx] = 1.0f / (1.0f + __expf(-x));
        col += TPB;
    }
}
template <int TPB, typename IndType> template <int TPB, typename IndType>
__launch_bounds__(TPB) __global__ void moeTopK( __launch_bounds__(TPB) __global__ void moeTopK(
const float* inputs_after_softmax, const float* inputs_after_softmax,
...@@ -136,7 +163,8 @@ __launch_bounds__(TPB) __global__ void moeTopK( ...@@ -136,7 +163,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const int k, const int k,
const int start_expert, const int start_expert,
const int end_expert, const int end_expert,
const bool renormalize) const bool renormalize,
const float* bias)
{ {
using cub_kvp = cub::KeyValuePair<int, float>; using cub_kvp = cub::KeyValuePair<int, float>;
...@@ -162,7 +190,13 @@ __launch_bounds__(TPB) __global__ void moeTopK( ...@@ -162,7 +190,13 @@ __launch_bounds__(TPB) __global__ void moeTopK(
{ {
const int idx = thread_read_offset + expert; const int idx = thread_read_offset + expert;
inp_kvp.key = expert; inp_kvp.key = expert;
// Apply correction bias if provided
if (bias != nullptr) {
inp_kvp.value = inputs_after_softmax[idx] + bias[expert];
} else {
inp_kvp.value = inputs_after_softmax[idx]; inp_kvp.value = inputs_after_softmax[idx];
}
for (int prior_k = 0; prior_k < k_idx; ++prior_k) for (int prior_k = 0; prior_k < k_idx; ++prior_k)
{ {
...@@ -186,12 +220,13 @@ __launch_bounds__(TPB) __global__ void moeTopK( ...@@ -186,12 +220,13 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const bool should_process_row = row_is_active && node_uses_expert; const bool should_process_row = row_is_active && node_uses_expert;
const int idx = k * block_row + k_idx; const int idx = k * block_row + k_idx;
output[idx] = result_kvp.value; // Return the unbiased scores for output weights
output[idx] = inputs_after_softmax[thread_read_offset + expert];
indices[idx] = should_process_row ? (expert - start_expert) : num_experts; indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
assert(indices[idx] >= 0); assert(indices[idx] >= 0);
source_rows[idx] = k_idx * num_rows + block_row; source_rows[idx] = k_idx * num_rows + block_row;
if (renormalize) { if (renormalize) {
selected_sum += result_kvp.value; selected_sum += inputs_after_softmax[thread_read_offset + expert];
} }
} }
__syncthreads(); __syncthreads();
...@@ -225,10 +260,12 @@ __launch_bounds__(TPB) __global__ void moeTopK( ...@@ -225,10 +260,12 @@ __launch_bounds__(TPB) __global__ void moeTopK(
2) This implementation assumes k is small, but will work for any k. 2) This implementation assumes k is small, but will work for any k.
*/ */
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType, typename InputType = float> template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType,
typename InputType = float, ScoringFunc SF>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices, void topkGating(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize) int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize,
const float* bias)
{ {
static_assert(std::is_same_v<InputType, float> || std::is_same_v<InputType, __nv_bfloat16> || static_assert(std::is_same_v<InputType, float> || std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>, std::is_same_v<InputType, __half>,
...@@ -353,12 +390,11 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ ...@@ -353,12 +390,11 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
} }
} }
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just if constexpr (SF == SCORING_SOFTMAX) {
// convert to float afterwards for the exp + sum reduction. // First, we perform a max reduce within the thread.
float thread_max = row_chunk[0]; float thread_max = row_chunk[0];
#pragma unroll #pragma unroll
for (int ii = 1; ii < VPT; ++ii) for (int ii = 1; ii < VPT; ++ii) {
{
thread_max = max(thread_max, row_chunk[ii]); thread_max = max(thread_max, row_chunk[ii]);
} }
...@@ -398,16 +434,45 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ ...@@ -398,16 +434,45 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
{ {
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
} }
} else if constexpr (SF == SCORING_SIGMOID) {
#pragma unroll
for (int ii = 0; ii < VPT; ++ii)
{
row_chunk[ii] = 1.0f / (1.0f + __expf(-row_chunk[ii]));
}
}
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along // If bias is not null, use biased value for selection
float row_chunk_for_choice[VPT];
// Apply correction bias
if (bias != nullptr) {
#pragma unroll
for (int ldg = 0; ldg < LDG_PER_THREAD; ++ldg) {
#pragma unroll
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
const int expert = first_elt_read_by_thread + ldg * COLS_PER_GROUP_LDG + ii;
float bias_val = expert < NUM_EXPERTS ? bias[expert] : 0.0f;
row_chunk_for_choice[ldg * ELTS_PER_LDG + ii] = row_chunk[ldg * ELTS_PER_LDG + ii] + bias_val;
}
}
} else {
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
row_chunk_for_choice[ii] = row_chunk[ii];
}
}
// Now, row_chunk contains the softmax / sigmoid of the row chunk. Now, I want to find the topk elements in each row, along
// with the max index. // with the max index.
int start_col = first_elt_read_by_thread; int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
float selected_sum = 0.f; float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) for (int k_idx = 0; k_idx < k; ++k_idx)
{ {
// First, each thread does the local argmax // First, each thread does the local argmax
float max_val_for_choice = row_chunk_for_choice[0];
float max_val = row_chunk[0]; float max_val = row_chunk[0];
int expert = start_col; int expert = start_col;
#pragma unroll #pragma unroll
...@@ -416,12 +481,14 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ ...@@ -416,12 +481,14 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
#pragma unroll #pragma unroll
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
{ {
float val_for_choice = row_chunk_for_choice[ldg * ELTS_PER_LDG + ii];
float val = row_chunk[ldg * ELTS_PER_LDG + ii]; float val = row_chunk[ldg * ELTS_PER_LDG + ii];
// No check on the experts here since columns with the smallest index are processed first and only // No check on the experts here since columns with the smallest index are processed first and only
// updated if > (not >=) // updated if > (not >=)
if (val > max_val) if (val_for_choice > max_val_for_choice)
{ {
max_val_for_choice = val_for_choice;
max_val = val; max_val = val;
expert = col + ii; expert = col + ii;
} }
...@@ -434,12 +501,14 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ ...@@ -434,12 +501,14 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
#pragma unroll #pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{ {
float other_max_for_choice = VLLM_SHFL_XOR_SYNC_WIDTH(max_val_for_choice, mask, THREADS_PER_ROW);
float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this way // We want lower indices to "win" in every thread so we break ties this way
if (other_max > max_val || (other_max == max_val && other_expert < expert)) if (other_max_for_choice > max_val_for_choice || (other_max_for_choice == max_val_for_choice && other_expert < expert))
{ {
max_val_for_choice = other_max_for_choice;
max_val = other_max; max_val = other_max;
expert = other_expert; expert = other_expert;
} }
...@@ -474,7 +543,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ ...@@ -474,7 +543,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
{ {
const int offset_for_expert = expert % ELTS_PER_LDG; const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be between 0 and 1. // Safe to set to any negative value since row_chunk values must be between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; row_chunk_for_choice[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
} }
} }
} }
...@@ -508,10 +577,10 @@ struct TopkConstants ...@@ -508,10 +577,10 @@ struct TopkConstants
}; };
} // namespace detail } // namespace detail
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType, typename InputType> template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType, typename InputType, ScoringFunc SF>
void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices, void topkGatingLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize, int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize,
cudaStream_t stream) const float* bias, cudaStream_t stream)
{ {
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS); static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM, InputType>; using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM, InputType>;
...@@ -521,43 +590,51 @@ void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finishe ...@@ -521,43 +590,51 @@ void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finishe
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM, IndType, InputType><<<num_blocks, block_dim, 0, stream>>>( topkGating<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM, IndType, InputType, SF><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize); input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize, bias);
} }
#ifndef USE_ROCM #ifndef USE_ROCM
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ #define LAUNCH_TOPK(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
static_assert(WARP_SIZE == 32, \ static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \ "Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \ topkGatingLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES, \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ IndType, InputType, SF>( \
num_tokens, topk, 0, num_experts, renormalize, stream); gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
bias, stream);
#else #else
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ #define LAUNCH_TOPK(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \ if (WARP_SIZE == 64) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \ topkGatingLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES, \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ IndType, InputType, SF>( \
num_tokens, topk, 0, num_experts, renormalize, stream); \ gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
bias, stream); \
} else if (WARP_SIZE == 32) { \ } else if (WARP_SIZE == 32) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \ topkGatingLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES, \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ IndType, InputType, SF>( \
num_tokens, topk, 0, num_experts, renormalize, stream); \ gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \
bias, stream); \
} else { \ } else { \
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \ assert(false && \
"Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
} }
#endif #endif
template <typename IndType, typename InputType> template <typename IndType, typename InputType, ScoringFunc SF>
void topkGatingSoftmaxKernelLauncher( void topkGatingKernelLauncher(
const InputType* gating_output, const InputType* gating_output,
float* topk_weights, float* topk_weights,
IndType* topk_indices, IndType* topk_indices,
int* token_expert_indices, int* token_expert_indices,
float* softmax_workspace, float* workspace,
const int num_tokens, const int num_tokens,
const int num_experts, const int num_experts,
const int topk, const int topk,
const bool renormalize, const bool renormalize,
const float* bias,
cudaStream_t stream) { cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4; static constexpr int WARPS_PER_TB = 4;
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16; static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
...@@ -569,64 +646,71 @@ void topkGatingSoftmaxKernelLauncher( ...@@ -569,64 +646,71 @@ void topkGatingSoftmaxKernelLauncher(
#endif #endif
switch (num_experts) { switch (num_experts) {
case 1: case 1:
LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); LAUNCH_TOPK(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break; break;
case 2: case 2:
LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); LAUNCH_TOPK(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break; break;
case 4: case 4:
LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); LAUNCH_TOPK(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break; break;
case 8: case 8:
LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); LAUNCH_TOPK(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break; break;
case 16: case 16:
LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); LAUNCH_TOPK(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break; break;
case 32: case 32:
LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); LAUNCH_TOPK(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break; break;
case 64: case 64:
LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); LAUNCH_TOPK(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break; break;
case 128: case 128:
LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); LAUNCH_TOPK(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break; break;
case 256: case 256:
LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); LAUNCH_TOPK(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break; break;
case 512: case 512:
LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); LAUNCH_TOPK(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break; break;
// (CUDA only) support multiples of 64 when num_experts is not power of 2. // (CUDA only) support multiples of 64 when num_experts is not power of 2.
// ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of num_experts, // ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of num_experts,
// alternatively we can test 4 bytes loading and enable it in future. // alternatively we can test 4 bytes loading and enable it in future.
#ifndef USE_ROCM #ifndef USE_ROCM
case 192: case 192:
LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); LAUNCH_TOPK(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break; break;
case 320: case 320:
LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); LAUNCH_TOPK(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break; break;
case 384: case 384:
LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); LAUNCH_TOPK(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break; break;
case 448: case 448:
LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); LAUNCH_TOPK(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break; break;
case 576: case 576:
LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); LAUNCH_TOPK(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break; break;
#endif #endif
default: { default: {
TORCH_CHECK(softmax_workspace != nullptr, TORCH_CHECK(workspace != nullptr,
"softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64."); "workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
static constexpr int TPB = 256; static constexpr int TPB = 256;
if constexpr (SF == SCORING_SOFTMAX) {
moeSoftmax<TPB, InputType><<<num_tokens, TPB, 0, stream>>>( moeSoftmax<TPB, InputType><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, softmax_workspace, num_experts); gating_output, nullptr, workspace, num_experts);
} else if constexpr (SF == SCORING_SIGMOID) {
moeSigmoid<TPB, InputType><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, workspace, num_experts);
} else {
TORCH_CHECK(false, "Unsupported scoring func");
}
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>( moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices, workspace, nullptr, topk_weights, topk_indices, token_expert_indices,
num_experts, topk, 0, num_experts, renormalize); num_experts, topk, 0, num_experts, renormalize, bias);
} }
} }
} }
...@@ -635,40 +719,55 @@ void topkGatingSoftmaxKernelLauncher( ...@@ -635,40 +719,55 @@ void topkGatingSoftmaxKernelLauncher(
} // namespace vllm } // namespace vllm
template<typename ComputeType> template<typename ComputeType, vllm::moe::ScoringFunc SF>
void dispatch_topk_softmax_launch( void dispatch_topk_launch(
torch::Tensor& gating_output, torch::Tensor& gating_output,
torch::Tensor& topk_weights, torch::Tensor& topk_weights,
torch::Tensor& topk_indices, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices, torch::Tensor& token_expert_indices,
torch::Tensor& softmax_workspace, torch::Tensor& softmax_workspace,
int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream) int num_tokens, int num_experts, int topk, bool renormalize,
{ std::optional<torch::Tensor> bias,
cudaStream_t stream)
{
const float* bias_ptr = nullptr;
if (bias.has_value()) {
const torch::Tensor& bias_tensor = bias.value();
TORCH_CHECK(bias_tensor.scalar_type() == at::ScalarType::Float, "bias tensor must be float32");
TORCH_CHECK(bias_tensor.dim() == 1, "bias tensor must be 1D");
TORCH_CHECK(bias_tensor.size(0) == num_experts, "bias size mismatch, expected: ", num_experts);
TORCH_CHECK(bias_tensor.is_contiguous(), "bias tensor must be contiguous");
bias_ptr = bias_tensor.data_ptr<float>();
}
if (topk_indices.scalar_type() == at::ScalarType::Int) { if (topk_indices.scalar_type() == at::ScalarType::Int) {
vllm::moe::topkGatingSoftmaxKernelLauncher<int, ComputeType>( vllm::moe::topkGatingKernelLauncher<int, ComputeType, SF>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()), reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(), topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(), topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(), token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(), softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream); num_tokens, num_experts, topk, renormalize,
bias_ptr, stream);
} else if (topk_indices.scalar_type() == at::ScalarType::UInt32) { } else if (topk_indices.scalar_type() == at::ScalarType::UInt32) {
vllm::moe::topkGatingSoftmaxKernelLauncher<uint32_t, ComputeType>( vllm::moe::topkGatingKernelLauncher<uint32_t, ComputeType, SF>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()), reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(), topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(), topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(), token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(), softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream); num_tokens, num_experts, topk, renormalize,
bias_ptr, stream);
} else { } else {
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
vllm::moe::topkGatingSoftmaxKernelLauncher<int64_t, ComputeType>( vllm::moe::topkGatingKernelLauncher<int64_t, ComputeType, SF>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()), reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(), topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(), topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(), token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(), softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream); num_tokens, num_experts, topk, renormalize,
bias_ptr, stream);
} }
} }
...@@ -677,7 +776,8 @@ void topk_softmax( ...@@ -677,7 +776,8 @@ void topk_softmax(
torch::Tensor& topk_indices, // [num_tokens, topk] torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& token_expert_indices, // [num_tokens, topk] torch::Tensor& token_expert_indices, // [num_tokens, topk]
torch::Tensor& gating_output, // [num_tokens, num_experts] torch::Tensor& gating_output, // [num_tokens, num_experts]
bool renormalize) bool renormalize,
std::optional<torch::Tensor> bias)
{ {
const int num_experts = gating_output.size(-1); const int num_experts = gating_output.size(-1);
const auto num_tokens = gating_output.numel() / num_experts; const auto num_tokens = gating_output.numel() / num_experts;
...@@ -693,14 +793,55 @@ void topk_softmax( ...@@ -693,14 +793,55 @@ void topk_softmax(
torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options); torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options);
if (gating_output.scalar_type() == at::ScalarType::Float) { if (gating_output.scalar_type() == at::ScalarType::Float) {
dispatch_topk_softmax_launch<float>(gating_output, topk_weights, topk_indices, dispatch_topk_launch<float, vllm::moe::SCORING_SOFTMAX>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize,
bias, stream);
} else if (gating_output.scalar_type() == at::ScalarType::Half) {
dispatch_topk_launch<__half, vllm::moe::SCORING_SOFTMAX>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize,
bias, stream);
} else if (gating_output.scalar_type() == at::ScalarType::BFloat16) {
dispatch_topk_launch<__nv_bfloat16, vllm::moe::SCORING_SOFTMAX>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize,
bias, stream);
} else {
TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type());
}
}
void topk_sigmoid(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& token_expert_indices, // [num_tokens, topk]
torch::Tensor& gating_output, // [num_tokens, num_experts]
bool renormalize,
std::optional<torch::Tensor> bias)
{
const int num_experts = gating_output.size(-1);
const auto num_tokens = gating_output.numel() / num_experts;
const int topk = topk_weights.size(-1);
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
const bool needs_workspace = !is_pow_2 || num_experts > 256;
const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float);
torch::Tensor workspace = torch::empty({workspace_size}, workspace_options);
if (gating_output.scalar_type() == at::ScalarType::Float) {
dispatch_topk_launch<float, vllm::moe::SCORING_SIGMOID>(gating_output, topk_weights, topk_indices,
token_expert_indices, workspace, num_tokens, num_experts, topk, renormalize,
bias, stream);
} else if (gating_output.scalar_type() == at::ScalarType::Half) { } else if (gating_output.scalar_type() == at::ScalarType::Half) {
dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices, dispatch_topk_launch<__half, vllm::moe::SCORING_SIGMOID>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); token_expert_indices, workspace, num_tokens, num_experts, topk, renormalize,
bias, stream);
} else if (gating_output.scalar_type() == at::ScalarType::BFloat16) { } else if (gating_output.scalar_type() == at::ScalarType::BFloat16) {
dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices, dispatch_topk_launch<__nv_bfloat16, vllm::moe::SCORING_SIGMOID>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); token_expert_indices, workspace, num_tokens, num_experts, topk, renormalize,
bias, stream);
} else { } else {
TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type()); TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type());
} }
......
...@@ -5,9 +5,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -5,9 +5,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs. // Apply topk softmax to the gating outputs.
m.def( m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize) -> ()"); "token_expert_indices, Tensor gating_output, bool renormalize, Tensor? "
"bias) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax); m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
// Apply topk sigmoid to the gating outputs.
m.def(
"topk_sigmoid(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize, Tensor? "
"bias) -> ()");
m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid);
// Calculate the result of moe by summing up the partial results // Calculate the result of moe by summing up the partial results
// from all selected experts. // from all selected experts.
m.def("moe_sum(Tensor input, Tensor! output) -> ()"); m.def("moe_sum(Tensor input, Tensor! output) -> ()");
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE fused topk kernel
Run `pytest tests/kernels/moe/test_fused_topk.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
from vllm.platforms import current_platform
def torch_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
e_score_correction_bias: torch.Tensor = None,
scoring_func: str = "softmax",
):
if scoring_func == "softmax":
scores = torch.softmax(gating_output.float(), dim=-1)
else:
assert scoring_func == "sigmoid"
scores = torch.sigmoid(gating_output.float())
if e_score_correction_bias is not None:
num_experts = gating_output.shape[-1]
scores_for_choice = scores.view(
-1, num_experts
) + e_score_correction_bias.unsqueeze(0)
_, topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1)
topk_weights = scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 56])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [6, 16])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk(
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    topk: int,
    renormalize: bool,
    scoring_func: str,
    dtype: torch.dtype,
):
    """Compare ``fused_topk`` with the ``torch_topk`` reference on random input."""
    torch.manual_seed(0)
    device = "cuda"
    # Keep the draw order (hidden_states, then gating_output) so the seeded
    # random stream yields the same inputs.
    hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device)
    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device=device)

    # Arguments shared by the reference and the implementation under test.
    routing_kwargs = dict(
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
        scoring_func=scoring_func,
    )
    ref_weights, ref_ids = torch_topk(**routing_kwargs)
    # The third return value of fused_topk is not checked here.
    out_weights, out_ids, _ = fused_topk(hidden_states=hidden_states, **routing_kwargs)

    torch.testing.assert_close(
        ref_weights.to(torch.float32), out_weights, atol=1e-2, rtol=1e-2
    )
    # Expert ids must match exactly.
    torch.testing.assert_close(ref_ids.to(torch.int32), out_ids, atol=0, rtol=0)
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_tokens", [1, 33, 56])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [6, 16])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk_bias(
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    topk: int,
    renormalize: bool,
    scoring_func: str,
    dtype: torch.dtype,
):
    """Compare ``fused_topk_bias`` with the biased ``torch_topk`` reference."""
    torch.manual_seed(0)
    device = "cuda"
    # Keep the draw order (hidden_states, gating_output, bias) so the seeded
    # random stream yields the same inputs.
    hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device=device)
    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device=device)
    # The correction bias is always float32, independent of `dtype`.
    e_score_correction_bias = torch.randn(
        (num_experts,), dtype=torch.float32, device=device
    )

    # Arguments shared by the reference and the implementation under test.
    routing_kwargs = dict(
        gating_output=gating_output,
        e_score_correction_bias=e_score_correction_bias,
        topk=topk,
        renormalize=renormalize,
        scoring_func=scoring_func,
    )
    ref_weights, ref_ids = torch_topk(**routing_kwargs)
    out_weights, out_ids = fused_topk_bias(
        hidden_states=hidden_states, **routing_kwargs
    )

    torch.testing.assert_close(
        ref_weights.to(torch.float32), out_weights, atol=1e-2, rtol=1e-2
    )
    # Expert ids must match exactly.
    torch.testing.assert_close(ref_ids.to(torch.int32), out_ids, atol=0, rtol=0)
...@@ -18,7 +18,9 @@ from vllm.model_executor.layers.activation import ( ...@@ -18,7 +18,9 @@ from vllm.model_executor.layers.activation import (
SiluAndMul, SiluAndMul,
) )
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import ( from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
dispatch_topk_func, dispatch_topk_sigmoid_func,
dispatch_topk_softmax_func,
vllm_topk_sigmoid,
vllm_topk_softmax, vllm_topk_softmax,
) )
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
...@@ -133,8 +135,8 @@ def test_enabled_ops_invalid(env: str): ...@@ -133,8 +135,8 @@ def test_enabled_ops_invalid(env: str):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
) )
def test_topk_dispatch(use_rocm_aiter: bool): def test_topk_softmax_dispatch(use_rocm_aiter: bool):
topk_func = dispatch_topk_func(use_rocm_aiter) topk_func = dispatch_topk_softmax_func(use_rocm_aiter)
if current_platform.is_rocm() and use_rocm_aiter: if current_platform.is_rocm() and use_rocm_aiter:
assert topk_func == rocm_aiter_ops.topk_softmax assert topk_func == rocm_aiter_ops.topk_softmax
...@@ -142,6 +144,18 @@ def test_topk_dispatch(use_rocm_aiter: bool): ...@@ -142,6 +144,18 @@ def test_topk_dispatch(use_rocm_aiter: bool):
assert topk_func == vllm_topk_softmax assert topk_func == vllm_topk_softmax
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
    """The sigmoid dispatcher picks the aiter op only on ROCm with aiter enabled."""
    selected = dispatch_topk_sigmoid_func(use_rocm_aiter)

    expected = (
        rocm_aiter_ops.topk_sigmoid
        if current_platform.is_rocm() and use_rocm_aiter
        else vllm_topk_sigmoid
    )
    assert selected == expected
@pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", [True, False]) @pytest.mark.parametrize("use_rocm_aiter", [True, False])
......
...@@ -200,6 +200,24 @@ def _rocm_aiter_topk_softmax_fake( ...@@ -200,6 +200,24 @@ def _rocm_aiter_topk_softmax_fake(
pass pass
def _rocm_aiter_topk_sigmoid_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Forward the three tensors to ``aiter.topk_sigmoid``.

    ``aiter`` is imported inside the function, so the package is only needed
    when this implementation is actually called.
    """
    from aiter import topk_sigmoid as aiter_topk_sigmoid

    aiter_topk_sigmoid(topk_weights, topk_indices, gating_output)
def _rocm_aiter_topk_sigmoid_fake(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
gating_output: torch.Tensor,
) -> None:
pass
def _rocm_aiter_biased_grouped_topk_impl( def _rocm_aiter_biased_grouped_topk_impl(
gating_output: torch.Tensor, gating_output: torch.Tensor,
correction_bias: torch.Tensor, correction_bias: torch.Tensor,
...@@ -985,6 +1003,14 @@ class rocm_aiter_ops: ...@@ -985,6 +1003,14 @@ class rocm_aiter_ops:
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op(
op_name="rocm_aiter_topk_sigmoid",
op_func=_rocm_aiter_topk_sigmoid_impl,
mutates_args=["topk_weights", "topk_indices"],
fake_impl=_rocm_aiter_topk_sigmoid_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_biased_grouped_topk", op_name="rocm_aiter_biased_grouped_topk",
op_func=_rocm_aiter_biased_grouped_topk_impl, op_func=_rocm_aiter_biased_grouped_topk_impl,
...@@ -1272,6 +1298,19 @@ class rocm_aiter_ops: ...@@ -1272,6 +1298,19 @@ class rocm_aiter_ops:
) )
return topk_weights, topk_indices return topk_weights, topk_indices
@staticmethod
def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> tuple[torch.Tensor, ...]:
    """Run the ``rocm_aiter_topk_sigmoid`` custom op and return its outputs.

    Args:
        topk_weights: Output buffer for the selected expert weights.
        topk_indices: Output buffer for the selected expert ids.
        token_expert_indices: Accepted for signature parity with the other
            top-k functions; not forwarded to the op.
        gating_output: Router logits.
        renormalize: If True, the selected weights are divided by their
            per-row sum after the op has run.

    Returns:
        The ``(topk_weights, topk_indices)`` buffers that were passed in.
    """
    torch.ops.vllm.rocm_aiter_topk_sigmoid(
        topk_weights, topk_indices, gating_output
    )
    if renormalize:
        # The wrapped op takes no renormalize flag, so honour the request
        # here. Done in place because callers keep using the same buffer.
        # NOTE(review): presumably aiter's kernel returns raw sigmoid scores;
        # if it already normalises, this division is a no-op up to rounding.
        topk_weights.div_(topk_weights.sum(dim=-1, keepdim=True))
    return topk_weights, topk_indices
@staticmethod @staticmethod
def biased_grouped_topk( def biased_grouped_topk(
gating_output: torch.Tensor, gating_output: torch.Tensor,
......
...@@ -2177,9 +2177,33 @@ def topk_softmax( ...@@ -2177,9 +2177,33 @@ def topk_softmax(
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool = False, renormalize: bool = False,
e_score_correction_bias: torch.Tensor | None = None,
) -> None: ) -> None:
torch.ops._moe_C.topk_softmax( torch.ops._moe_C.topk_softmax(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)
def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> None:
    """Invoke the ``_moe_C.topk_sigmoid`` custom op.

    All arguments are forwarded positionally and unchanged. Nothing is
    returned; per the op schema the first three tensors are mutated.
    """
    sigmoid_op = torch.ops._moe_C.topk_sigmoid
    sigmoid_op(
        topk_weights,
        topk_ids,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
......
...@@ -106,14 +106,14 @@ def _quant_flags_to_group_shape( ...@@ -106,14 +106,14 @@ def _quant_flags_to_group_shape(
class RoutingMethodType(IntEnum): class RoutingMethodType(IntEnum):
# Default: Softmax -> TopK # Default: Softmax -> TopK
Default = (0,) Default = (0,)
# Renormalize: TopK -> Softmax # Renormalize: TopK -> Softmax/Sigmoid
Renormalize = (1,) Renormalize = (1,)
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
# -> Top8 experts from the Top4 groups # -> Top8 experts from the Top4 groups
DeepSeekV3 = (2,) DeepSeekV3 = (2,)
# Llama4: Top1 -> Sigmoid # Llama4: Top1 -> Sigmoid
Llama4 = (3,) Llama4 = (3,)
# RenormalizeNaive: Softmax -> TopK -> Renormalize # RenormalizeNaive: Softmax/Sigmoid -> TopK -> Renormalize
RenormalizeNaive = (4,) RenormalizeNaive = (4,)
# TopK: TopK (no softmax) # TopK: TopK (no softmax)
TopK = (5,) TopK = (5,)
......
...@@ -4,6 +4,8 @@ from collections.abc import Callable ...@@ -4,6 +4,8 @@ from collections.abc import Callable
import torch import torch
import vllm._custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
...@@ -12,15 +14,106 @@ from vllm.model_executor.layers.fused_moe.config import RoutingMethodType ...@@ -12,15 +14,106 @@ from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Call ``ops.topk_softmax`` and hand back the weight and index buffers.

    Every argument is forwarded positionally and unchanged; the returned
    tuple holds the same ``topk_weights`` and ``topk_indices`` objects that
    were passed in.
    """
    kernel_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
    ops.topk_softmax(*kernel_args)
    return topk_weights, topk_indices
def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Call ``ops.topk_sigmoid`` and hand back the weight and index buffers.

    Every argument is forwarded positionally and unchanged; the returned
    tuple holds the same ``topk_weights`` and ``topk_indices`` objects that
    were passed in.
    """
    kernel_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
    ops.topk_sigmoid(*kernel_args)
    return topk_weights, topk_indices
def fused_topk_bias( def fused_topk_bias(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor, e_score_correction_bias: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
scoring_func: str = "softmax",
indices_type: torch.dtype | None = None,
): ):
if not rocm_aiter_ops.is_fused_moe_enabled():
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch"
)
M, _ = hidden_states.size()
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device,
)
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
if scoring_func == "softmax":
topk_weights, topk_ids = vllm_topk_softmax(
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)
return topk_weights, topk_ids
elif scoring_func == "sigmoid":
topk_weights, topk_ids = vllm_topk_sigmoid(
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)
return topk_weights, topk_ids
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
n_routed_experts = gating_output.shape[-1] n_routed_experts = gating_output.shape[-1]
if scoring_func == "softmax":
scores = gating_output.softmax(dim=-1) scores = gating_output.softmax(dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
scores_for_choice = scores.view( scores_for_choice = scores.view(
-1, n_routed_experts -1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0) ) + e_score_correction_bias.unsqueeze(0)
...@@ -43,6 +136,7 @@ class FusedTopKBiasRouter(BaseRouter): ...@@ -43,6 +136,7 @@ class FusedTopKBiasRouter(BaseRouter):
global_num_experts: int, global_num_experts: int,
eplb_state: EplbLayerState, eplb_state: EplbLayerState,
e_score_correction_bias: torch.Tensor, e_score_correction_bias: torch.Tensor,
scoring_func: str,
renormalize: bool = True, renormalize: bool = True,
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
enable_eplb: bool = False, enable_eplb: bool = False,
...@@ -57,6 +151,7 @@ class FusedTopKBiasRouter(BaseRouter): ...@@ -57,6 +151,7 @@ class FusedTopKBiasRouter(BaseRouter):
) )
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
self.renormalize = renormalize self.renormalize = renormalize
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor self.routed_scaling_factor = routed_scaling_factor
@property @property
...@@ -80,6 +175,7 @@ class FusedTopKBiasRouter(BaseRouter): ...@@ -80,6 +175,7 @@ class FusedTopKBiasRouter(BaseRouter):
e_score_correction_bias=self.e_score_correction_bias.data, e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k, topk=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
scoring_func=self.scoring_func,
) )
if self.routed_scaling_factor != 1.0: if self.routed_scaling_factor != 1.0:
......
...@@ -16,7 +16,7 @@ def vllm_topk_softmax( ...@@ -16,7 +16,7 @@ def vllm_topk_softmax(
topk_indices: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool, renormalize: bool = False,
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
ops.topk_softmax( ops.topk_softmax(
topk_weights, topk_weights,
...@@ -29,7 +29,25 @@ def vllm_topk_softmax( ...@@ -29,7 +29,25 @@ def vllm_topk_softmax(
return topk_weights, topk_indices return topk_weights, topk_indices
def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> tuple[torch.Tensor, ...]:
    """Call ``ops.topk_sigmoid`` (no bias argument) and return the buffers.

    The five arguments are forwarded positionally; the returned tuple holds
    the same ``topk_weights`` and ``topk_indices`` objects that were passed in.
    """
    kernel_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )
    ops.topk_sigmoid(*kernel_args)
    return topk_weights, topk_indices
def dispatch_topk_softmax_func(
use_rocm_aiter: bool = False, use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]: ) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter: if use_rocm_aiter:
...@@ -37,12 +55,21 @@ def dispatch_topk_func( ...@@ -37,12 +55,21 @@ def dispatch_topk_func(
return vllm_topk_softmax return vllm_topk_softmax
def dispatch_topk_sigmoid_func(
    use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
    """Select the top-k sigmoid implementation.

    Returns ``rocm_aiter_ops.topk_sigmoid`` when ``use_rocm_aiter`` is truthy,
    otherwise ``vllm_topk_sigmoid``.
    """
    return rocm_aiter_ops.topk_sigmoid if use_rocm_aiter else vllm_topk_sigmoid
def fused_topk( def fused_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
indices_type: torch.dtype | None = None, indices_type: torch.dtype | None = None,
scoring_func: str = "softmax",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
...@@ -61,12 +88,26 @@ def fused_topk( ...@@ -61,12 +88,26 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device M, topk, dtype=torch.int32, device=hidden_states.device
) )
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()) if scoring_func == "softmax":
topk_func = dispatch_topk_softmax_func(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices
elif scoring_func == "sigmoid":
topk_func = dispatch_topk_sigmoid_func(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
topk_weights, topk_ids = topk_func( topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
) )
return topk_weights, topk_ids, token_expert_indices return topk_weights, topk_ids, token_expert_indices
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
class FusedTopKRouter(BaseRouter): class FusedTopKRouter(BaseRouter):
...@@ -82,7 +123,6 @@ class FusedTopKRouter(BaseRouter): ...@@ -82,7 +123,6 @@ class FusedTopKRouter(BaseRouter):
enable_eplb: bool = False, enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None, indices_type_getter: Callable[[], torch.dtype | None] | None = None,
): ):
assert scoring_func == "softmax", "FusedTopKRouter only supports softmax."
super().__init__( super().__init__(
top_k=top_k, top_k=top_k,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -91,6 +131,7 @@ class FusedTopKRouter(BaseRouter): ...@@ -91,6 +131,7 @@ class FusedTopKRouter(BaseRouter):
indices_type_getter=indices_type_getter, indices_type_getter=indices_type_getter,
) )
self.renormalize = renormalize self.renormalize = renormalize
self.scoring_func = scoring_func
@property @property
def routing_method_type(self) -> RoutingMethodType: def routing_method_type(self) -> RoutingMethodType:
...@@ -113,6 +154,7 @@ class FusedTopKRouter(BaseRouter): ...@@ -113,6 +154,7 @@ class FusedTopKRouter(BaseRouter):
topk=self.top_k, topk=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
indices_type=indices_type, indices_type=indices_type,
scoring_func=self.scoring_func,
) )
return topk_weights, topk_ids return topk_weights, topk_ids
...@@ -143,17 +143,13 @@ def create_fused_moe_router( ...@@ -143,17 +143,13 @@ def create_fused_moe_router(
router.capture = capture router.capture = capture
return router return router
if scoring_func != "softmax":
raise ValueError(
"Only softmax scoring function is supported for non-grouped topk."
)
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
router = FusedTopKBiasRouter( router = FusedTopKBiasRouter(
top_k=top_k, top_k=top_k,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
eplb_state=eplb_state, eplb_state=eplb_state,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
scoring_func=scoring_func,
renormalize=renormalize, renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
......
...@@ -100,9 +100,6 @@ class MiniMaxM2MoE(nn.Module): ...@@ -100,9 +100,6 @@ class MiniMaxM2MoE(nn.Module):
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
scoring_func=config.scoring_func, scoring_func=config.scoring_func,
use_grouped_topk=True,
num_expert_group=1,
topk_group=1,
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment