Unverified commit 2998c4bd, authored by Yi Zhang and committed by GitHub

[optimize] fuse renormalize into moe_topk_softmax (#7744)


Co-authored-by: ispobock <ispobaoke@gmail.com>
parent 6840a7bb
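
Note on the change: topk_softmax no longer takes a token_expert_indices output (it was unused by callers), the kernels are templated on the gating dtype (float32, float16, bfloat16), and the optional renormalization of the selected top-k weights is fused into the kernel epilogue instead of running as a separate division pass. A minimal usage sketch of the updated entry point (the import path is an assumption based on the Python bindings changed below):

    import torch
    from sgl_kernel import topk_softmax  # assumed import path

    num_tokens, num_experts, topk = 16, 64, 2
    gating_output = torch.randn(num_tokens, num_experts, dtype=torch.bfloat16, device="cuda")
    topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
    topk_indices = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")

    # renormalize=True makes each row of topk_weights sum to 1 inside the kernel,
    # replacing a separate topk_weights /= topk_weights.sum(-1, keepdim=True) pass.
    topk_softmax(topk_weights, topk_indices, gating_output, renormalize=True)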
@@ -34,14 +34,10 @@ def sglang_topk_softmax(gating_output, topk):
     topk_indices = torch.empty(
         (num_tokens, topk), dtype=torch.int32, device=gating_output.device
     )
-    token_expert_indices = torch.empty(
-        (num_tokens, topk), dtype=torch.int32, device=gating_output.device
-    )

     topk_softmax(
         topk_weights=topk_weights,
         topk_ids=topk_indices,
-        token_expert_indices=token_expert_indices,
         gating_output=gating_output,
     )
...
@@ -169,9 +169,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "pad_sorted_token_ids) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

-  m.def(
-      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
-      "token_expert_indices, Tensor gating_output) -> ()");
+  m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

   m.def(
...
@@ -41,15 +41,29 @@ template <
     /// Alignment requirement in bytes
     int Alignment = sizeof(T) * N>
 class alignas(Alignment) AlignedArray {
-  float data[N];
+  T data[N];
 };

+// ========================== Util functions to convert types ==========================
+template <typename T>
+__device__ float convert_to_float(T x) {
+  if constexpr (std::is_same_v<T, __half>) {
+    return __half2float(x);
+  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+    return __bfloat162float(x);
+  } else if constexpr (std::is_same_v<T, float>) {
+    return x;
+  } else {
+    return static_cast<float>(x);
+  }
+}
+
 // ====================== Softmax things ===============================

 // We have our own implementation of softmax here so we can support transposing the output
 // in the softmax kernel when we extend this module to support expert-choice routing.
-template <int TPB>
+template <typename T, int TPB>
 __launch_bounds__(TPB) __global__
-    void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) {
+    void moeSoftmax(const T* input, const bool* finished, float* output, const int num_cols) {
   using BlockReduce = cub::BlockReduce<float, TPB>;
   __shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -68,7 +82,7 @@ __launch_bounds__(TPB) __global__
   for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
     const int idx = thread_row_offset + ii;
-    threadData = max(static_cast<float>(input[idx]), threadData);
+    threadData = max(convert_to_float<T>(input[idx]), threadData);
   }

   const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
@@ -82,7 +96,7 @@ __launch_bounds__(TPB) __global__
   for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
     const int idx = thread_row_offset + ii;
-    threadData += exp((static_cast<float>(input[idx]) - float_max));
+    threadData += exp((convert_to_float<T>(input[idx]) - float_max));
   }

   const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
@@ -94,7 +108,7 @@ __launch_bounds__(TPB) __global__
   for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
     const int idx = thread_row_offset + ii;
-    const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
+    const float val = exp((convert_to_float<T>(input[idx]) - float_max)) * normalizing_factor;
     output[idx] = val;
   }
 }
@@ -105,11 +119,11 @@ __launch_bounds__(TPB) __global__ void moeTopK(
     const bool* finished,
     float* output,
     int* indices,
-    int* source_rows,
     const int num_experts,
     const int k,
     const int start_expert,
-    const int end_expert) {
+    const int end_expert,
+    const bool renormalize) {
   using cub_kvp = cub::KeyValuePair<int, float>;
   using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
   __shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -117,11 +131,11 @@ __launch_bounds__(TPB) __global__ void moeTopK(
   cub_kvp thread_kvp;
   cub::ArgMax arg_max;

-  const int num_rows = gridDim.x;
   const int block_row = blockIdx.x;

   const bool row_is_active = finished ? !finished[block_row] : true;
   const int thread_read_offset = blockIdx.x * num_experts;
+  float row_sum_for_renormalize = 0;
   for (int k_idx = 0; k_idx < k; ++k_idx) {
     thread_kvp.key = 0;
     thread_kvp.value = -1.f;  // This is OK because inputs are probabilities
@@ -154,10 +168,18 @@ __launch_bounds__(TPB) __global__ void moeTopK(
       output[idx] = result_kvp.value;
       indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
       assert(indices[idx] >= 0);
-      source_rows[idx] = k_idx * num_rows + block_row;
+      row_sum_for_renormalize += result_kvp.value;
     }
     __syncthreads();
   }
+
+  if (renormalize && threadIdx.x == 0) {
+    float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize;
+    for (int k_idx = 0; k_idx < k; ++k_idx) {
+      const int idx = k * block_row + k_idx;
+      output[idx] = output[idx] * row_sum_for_renormalize_inv;
+    }
+  }
 }

 // ====================== TopK softmax things ===============================
@@ -174,17 +196,17 @@ __launch_bounds__(TPB) __global__ void moeTopK(
   2) This implementation assumes k is small, but will work for any k.
 */

-template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
+template <typename T, int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
 __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
-    const float* input,
+    const T* input,
     const bool* finished,
     float* output,
     const int num_rows,
     int* indices,
-    int* source_rows,
     const int k,
     const int start_expert,
-    const int end_expert) {
+    const int end_expert,
+    const bool renormalize) {
   // We begin by enforcing compile time assertions and setting up compile time constants.
   static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
   static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
@@ -192,7 +214,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
   static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

   // Number of bytes each thread pulls in per load
-  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
+  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
   static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
   static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
   static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
@@ -233,28 +255,34 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
   // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
   // row it will read.
-  const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
+  const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;

   // Now, we compute the group each thread belong to in order to determine the first column to start loads.
   const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
   const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
-  const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
+  const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

   // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
   // this can support all powers of 2 up to 16.
   // NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
   // We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
-  using AccessType = AlignedArray<float, ELTS_PER_LDG>;
+  using AccessType = AlignedArray<T, ELTS_PER_LDG>;

   // Finally, we pull in the data from global mem
-  float row_chunk[VPT];
-  AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
+  T row_chunk_temp[VPT];
+  AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk_temp);
   const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
 #pragma unroll
   for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
     row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
   }

+  float row_chunk[VPT];
+#pragma unroll
+  for (int ii = 0; ii < VPT; ++ii) {
+    row_chunk[ii] = convert_to_float<T>(row_chunk_temp[ii]);
+  }
+
   // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
   // convert to float afterwards for the exp + sum reduction.
   float thread_max = row_chunk[0];
@@ -301,6 +329,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
   int start_col = first_elt_read_by_thread;
   static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

+  float row_sum_for_renormalize = 0;
+
   for (int k_idx = 0; k_idx < k; ++k_idx) {
     // First, each thread does the local argmax
     float max_val = row_chunk[0];
@@ -346,7 +376,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
       const int idx = k * thread_row + k_idx;
       output[idx] = max_val;
       indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
-      source_rows[idx] = k_idx * num_rows + thread_row;
+      row_sum_for_renormalize += max_val;
     }

     // Finally, we clear the value in the thread with the current max if there is another iteration to run.
@@ -362,13 +392,23 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
       }
     }
   }
+
+  // Fuse renormalization of topk_weights into this kernel
+  if (renormalize && thread_group_idx == 0) {
+    float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize;
+#pragma unroll
+    for (int k_idx = 0; k_idx < k; ++k_idx) {
+      const int idx = k * thread_row + k_idx;
+      output[idx] = output[idx] * row_sum_for_renormalize_inv;
+    }
+  }
 }

 namespace detail {
 // Constructs some constants needed to partition the work across threads at compile time.
-template <int EXPERTS, int BYTES_PER_LDG>
+template <typename T, int EXPERTS, int BYTES_PER_LDG>
 struct TopkConstants {
-  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
+  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
   static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
   static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
   static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
@@ -377,100 +417,84 @@ struct TopkConstants {
 };
 }  // namespace detail

-template <int EXPERTS, int WARPS_PER_TB>
+template <typename T, int EXPERTS, int WARPS_PER_TB>
 void topkGatingSoftmaxLauncherHelper(
-    const float* input,
+    const T* input,
     const bool* finished,
     float* output,
     int* indices,
-    int* source_row,
     const int num_rows,
     const int k,
     const int start_expert,
     const int end_expert,
+    const bool renormalize,
     cudaStream_t stream) {
   static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
-  static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
-  using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
+  static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
+  using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;
   static constexpr int VPT = Constants::VPT;
   static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
   const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
   const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

   dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
-  topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
-      input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
+  topkGatingSoftmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
+      input, finished, output, num_rows, indices, k, start_expert, end_expert, renormalize);
 }

-#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)             \
-  topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
-      gating_output,                                          \
-      nullptr,                                                \
-      topk_weights,                                           \
-      topk_indices,                                           \
-      token_expert_indices,                                   \
-      num_tokens,                                             \
-      topk,                                                   \
-      0,                                                      \
-      num_experts,                                            \
-      stream);
+#define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB)             \
+  topkGatingSoftmaxLauncherHelper<TYPE, NUM_EXPERTS, WARPS_PER_TB>( \
+      gating_output, nullptr, topk_weights, topk_indices, num_tokens, topk, 0, num_experts, renormalize, stream);

+template <typename T>
 void topkGatingSoftmaxKernelLauncher(
-    const float* gating_output,
+    const T* gating_output,
     float* topk_weights,
     int* topk_indices,
-    int* token_expert_indices,
     float* softmax_workspace,
     const int num_tokens,
     const int num_experts,
     const int topk,
+    const bool renormalize,
     cudaStream_t stream) {
   static constexpr int WARPS_PER_TB = 4;
   switch (num_experts) {
     case 1:
-      LAUNCH_SOFTMAX(1, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(T, 1, WARPS_PER_TB);
       break;
     case 2:
-      LAUNCH_SOFTMAX(2, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(T, 2, WARPS_PER_TB);
       break;
     case 4:
-      LAUNCH_SOFTMAX(4, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(T, 4, WARPS_PER_TB);
       break;
     case 8:
-      LAUNCH_SOFTMAX(8, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(T, 8, WARPS_PER_TB);
       break;
     case 16:
-      LAUNCH_SOFTMAX(16, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(T, 16, WARPS_PER_TB);
       break;
     case 32:
-      LAUNCH_SOFTMAX(32, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(T, 32, WARPS_PER_TB);
       break;
     case 64:
-      LAUNCH_SOFTMAX(64, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(T, 64, WARPS_PER_TB);
       break;
     case 128:
-      LAUNCH_SOFTMAX(128, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(T, 128, WARPS_PER_TB);
       break;
     case 256:
-      LAUNCH_SOFTMAX(256, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(T, 256, WARPS_PER_TB);
       break;
     default: {
       TORCH_CHECK(
           softmax_workspace != nullptr,
           "softmax_workspace must be provided for num_experts that are not a power of 2.");
       static constexpr int TPB = 256;
-      moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(gating_output, nullptr, softmax_workspace, num_experts);
+      moeSoftmax<T, TPB><<<num_tokens, TPB, 0, stream>>>(gating_output, nullptr, softmax_workspace, num_experts);
       moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
-          softmax_workspace,
-          nullptr,
-          topk_weights,
-          topk_indices,
-          token_expert_indices,
-          num_experts,
-          topk,
-          0,
-          num_experts);
+          softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize);
     }
   }
 }
@@ -478,12 +502,35 @@ void topkGatingSoftmaxKernelLauncher(
 void topk_softmax(
     torch::Tensor& topk_weights,          // [num_tokens, topk]
     torch::Tensor& topk_indices,          // [num_tokens, topk]
-    torch::Tensor& token_expert_indices,  // [num_tokens, topk]
-    torch::Tensor& gating_output)         // [num_tokens, num_experts]
+    torch::Tensor& gating_output,         // [num_tokens, num_experts]
+    const bool renormalize)
 {
-  const int num_experts = gating_output.size(-1);
-  const int num_tokens = gating_output.numel() / num_experts;
-  const int topk = topk_weights.size(-1);
+  // Check data type
+  TORCH_CHECK(
+      gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half ||
+          gating_output.scalar_type() == at::ScalarType::BFloat16,
+      "gating_output must be float32, float16, or bfloat16");
+
+  // Check dimensions
+  TORCH_CHECK(gating_output.dim() == 2, "gating_output must be 2D tensor [num_tokens, num_experts]");
+  TORCH_CHECK(topk_weights.dim() == 2, "topk_weights must be 2D tensor [num_tokens, topk]");
+  TORCH_CHECK(topk_indices.dim() == 2, "topk_indices must be 2D tensor [num_tokens, topk]");
+
+  // Check shapes
+  TORCH_CHECK(
+      gating_output.size(0) == topk_weights.size(0),
+      "First dimension of topk_weights must match num_tokens in gating_output");
+  TORCH_CHECK(
+      gating_output.size(0) == topk_indices.size(0),
+      "First dimension of topk_indices must match num_tokens in gating_output");
+  TORCH_CHECK(
+      topk_weights.size(-1) == topk_indices.size(-1),
+      "Second dimension of topk_indices must match topk in topk_weights");
+  TORCH_CHECK(topk_weights.size(-1) <= gating_output.size(-1), "topk must be less than or equal to num_experts");
+
+  const int num_experts = static_cast<int>(gating_output.size(-1));
+  const int num_tokens = static_cast<int>(gating_output.size(0));
+  const int topk = static_cast<int>(topk_weights.size(-1));

   const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
   const bool needs_workspace = !is_pow_2 || num_experts > 256;
@@ -491,15 +538,44 @@ void topk_softmax(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
-  topkGatingSoftmaxKernelLauncher(
-      gating_output.data_ptr<float>(),
-      topk_weights.data_ptr<float>(),
-      topk_indices.data_ptr<int>(),
-      token_expert_indices.data_ptr<int>(),
-      softmax_workspace.data_ptr<float>(),
-      num_tokens,
-      num_experts,
-      topk,
-      stream);
+  torch::Tensor softmax_workspace =
+      torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float));
+
+  const at::ScalarType dtype = gating_output.scalar_type();
+  if (dtype == at::ScalarType::Float) {
+    topkGatingSoftmaxKernelLauncher<float>(
+        gating_output.data_ptr<float>(),
+        topk_weights.data_ptr<float>(),
+        topk_indices.data_ptr<int>(),
+        softmax_workspace.data_ptr<float>(),
+        num_tokens,
+        num_experts,
+        topk,
+        renormalize,
+        stream);
+  } else if (dtype == at::ScalarType::Half) {
+    topkGatingSoftmaxKernelLauncher<__half>(
+        reinterpret_cast<const __half*>(gating_output.data_ptr<at::Half>()),
+        topk_weights.data_ptr<float>(),
+        topk_indices.data_ptr<int>(),
+        softmax_workspace.data_ptr<float>(),
+        num_tokens,
+        num_experts,
+        topk,
+        renormalize,
+        stream);
+  } else if (dtype == at::ScalarType::BFloat16) {
+    topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
+        reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
+        topk_weights.data_ptr<float>(),
+        topk_indices.data_ptr<int>(),
+        softmax_workspace.data_ptr<float>(),
+        num_tokens,
+        num_experts,
+        topk,
+        renormalize,
+        stream);
+  } else {
+    TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype);
+  }
 }
...
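For reference, the fused path implements the following semantics, written here as a plain PyTorch sketch (an illustration only, not the implementation; like the kernel, the softmax is accumulated in float32 regardless of the gating dtype, matching the float32 workspace allocated above):

    import torch

    def topk_softmax_reference(gating_output: torch.Tensor, topk: int, renormalize: bool = False):
        # Softmax over the expert dimension, computed in float32 like the kernel.
        probs = torch.softmax(gating_output.float(), dim=-1)
        topk_weights, topk_indices = probs.topk(topk, dim=-1)
        if renormalize:
            # Fused epilogue: divide each row by the sum of its selected top-k weights.
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_indices.to(torch.int32)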
@@ -63,9 +63,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "pad_sorted_token_ids) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

-  m.def(
-      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
-      "token_expert_indices, Tensor gating_output) -> ()");
+  m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

   /*
...
@@ -222,10 +222,7 @@ void moe_align_block_size(
     bool pad_sorted_token_ids);

 void topk_softmax(
-    torch::Tensor& topk_weights,
-    torch::Tensor& topk_indices,
-    torch::Tensor& token_expert_indices,
-    torch::Tensor& gating_output);
+    torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& gating_output, bool renormalize);

 std::vector<at::Tensor> moe_fused_gate(
     at::Tensor& input,
...
@@ -30,11 +30,11 @@ def moe_align_block_size(
 def topk_softmax(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
-    token_expert_indices: torch.Tensor,
     gating_output: torch.Tensor,
+    renormalize: bool = False,
 ) -> None:
     torch.ops.sgl_kernel.topk_softmax.default(
-        topk_weights, topk_ids, token_expert_indices, gating_output
+        topk_weights, topk_ids, gating_output, renormalize
     )
...
@@ -22,14 +22,10 @@ def test_topk_softmax(num_tokens, num_experts, topk):
     topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
     topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
-    token_expert_indices = torch.empty(
-        (num_tokens, topk), dtype=torch.int32, device="cuda"
-    )

     topk_softmax(
         topk_weights,
         topk_indices,
-        token_expert_indices,
         gating_output,
     )
@@ -47,5 +43,97 @@ def test_topk_softmax(num_tokens, num_experts, topk):
     ), f"Indices mismatch: torch={topk_indices_ref}, SGLang={topk_indices}"

+
+@pytest.mark.parametrize(
+    "num_tokens, num_experts, topk, dtype",
+    list(
+        itertools.product(
+            [1, 16, 128, 512, 1024, 2048],  # num_tokens
+            [4, 8, 16, 32, 64, 128, 256],  # num_experts
+            [1, 2, 4],  # topk
+            [torch.float16, torch.bfloat16, torch.float32],  # dtype
+        )
+    ),
+)
+def test_topk_softmax_dtype_regression(num_tokens, num_experts, topk, dtype):
+    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
+    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
+    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
+
+    topk_softmax(
+        topk_weights,
+        topk_indices,
+        gating_output,
+    )
+
+    topk_weights_ref = torch.empty(
+        (num_tokens, topk), dtype=torch.float32, device="cuda"
+    )
+    topk_indices_ref = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
+
+    topk_softmax(
+        topk_weights_ref,
+        topk_indices_ref,
+        gating_output.float(),
+    )
+
+    assert torch.allclose(
+        topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
+    ), f"Weights mismatch: SGLang old interface={topk_weights_ref} vs SGLang new interface={topk_weights}"
+    assert torch.allclose(
+        topk_indices_ref.int(), topk_indices, atol=0, rtol=0
+    ), f"Indices mismatch: SGLang old interface={topk_indices_ref}, SGLang new interface={topk_indices}"
+
+
+@pytest.mark.parametrize(
+    "num_tokens, num_experts, topk",
+    list(
+        itertools.product(
+            [1, 16, 128, 512, 1024, 2048],  # num_tokens
+            [4, 8, 16, 32, 64, 128, 256],  # num_experts
+            [1, 2, 4],  # topk
+        )
+    ),
+)
+def test_topk_softmax_renormalize(num_tokens, num_experts, topk):
+    gating_output = torch.randn(
+        (num_tokens, num_experts), dtype=torch.bfloat16, device="cuda"
+    )
+    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device="cuda")
+    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
+
+    topk_softmax(
+        topk_weights,
+        topk_indices,
+        gating_output,
+        renormalize=True,
+    )
+
+    topk_weights_ref = torch.empty(
+        (num_tokens, topk), dtype=torch.float32, device="cuda"
+    )
+    topk_indices_ref = torch.empty((num_tokens, topk), dtype=torch.int32, device="cuda")
+
+    topk_softmax(
+        topk_weights_ref,
+        topk_indices_ref,
+        gating_output,
+    )
+    topk_weights_ref = topk_weights_ref / topk_weights_ref.sum(dim=-1, keepdim=True)
+
+    assert torch.allclose(
+        topk_weights_ref, topk_weights, atol=1e-3, rtol=1e-3
+    ), f"Weights mismatch: SGLang w/o fused renormalize={topk_weights_ref} vs SGLang w/ fused renormalize={topk_weights}"
+    assert torch.allclose(
+        topk_indices_ref.int(), topk_indices, atol=0, rtol=0
+    ), f"Indices mismatch: SGLang w/o fused renormalize={topk_indices_ref}, SGLang w/ fused renormalize={topk_indices}"

 if __name__ == "__main__":
     pytest.main([__file__])
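Since the point of the fusion is to save a separate renormalization pass, a quick micro-benchmark can compare the fused call against the two-step pattern (a hypothetical timing harness; the import path and shapes are assumptions, and results will vary by GPU):

    import torch
    from sgl_kernel import topk_softmax  # assumed import path

    def bench_ms(fn, iters=100, warmup=10):
        # Average milliseconds per call, measured with CUDA events.
        for _ in range(warmup):
            fn()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters

    gating = torch.randn(4096, 256, dtype=torch.bfloat16, device="cuda")
    weights = torch.empty(4096, 2, dtype=torch.float32, device="cuda")
    indices = torch.empty(4096, 2, dtype=torch.int32, device="cuda")

    def fused():
        topk_softmax(weights, indices, gating, renormalize=True)

    def unfused():
        topk_softmax(weights, indices, gating)
        weights.div_(weights.sum(dim=-1, keepdim=True))

    print(f"fused: {bench_ms(fused):.4f} ms  unfused: {bench_ms(unfused):.4f} ms")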