Unverified commit f28125d8, authored by Wentao Ye, committed by GitHub
Browse files

[Perf] Optimize grouped topk kernel, 1.2%~2% E2E Throughput improvement (#32058)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 46f8c6b7
...@@ -31,8 +31,6 @@ namespace moe { ...@@ -31,8 +31,6 @@ namespace moe {
constexpr unsigned FULL_WARP_MASK = 0xffffffff; constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32; constexpr int32_t WARP_SIZE = 32;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
namespace warp_topk { namespace warp_topk {
...@@ -65,14 +63,6 @@ __forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index, ...@@ -65,14 +63,6 @@ __forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
return res; return res;
} }
template <typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
return max(cache_topk,
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}
template <int size, bool ascending, bool reverse, typename T, typename idxT, template <int size, bool ascending, bool reverse, typename T, typename idxT,
bool is_stable> bool is_stable>
struct BitonicMerge { struct BitonicMerge {
...@@ -267,6 +257,15 @@ class WarpSort { ...@@ -267,6 +257,15 @@ class WarpSort {
} }
} }
// Accessors for per-lane selected value/index.
// NOTE: For the common case `capacity == WARP_SIZE`, `max_arr_len_ == 1`
// and callers should use `i == 0`.
__device__ __forceinline__ idxT get_idx(int i = 0) const {
return idx_arr_[i];
}
__device__ __forceinline__ T get_val(int i = 0) const { return val_arr_[i]; }
protected: protected:
static constexpr int max_arr_len_ = capacity / WARP_SIZE; static constexpr int max_arr_len_ = capacity / WARP_SIZE;
...@@ -285,6 +284,7 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> { ...@@ -285,6 +284,7 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
__device__ WarpSelect(idxT k, T dummy) __device__ WarpSelect(idxT k, T dummy)
: WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy), : WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
k_th_(dummy), k_th_(dummy),
k_th_idx_(0),
k_th_lane_((k - 1) % WARP_SIZE) { k_th_lane_((k - 1) % WARP_SIZE) {
extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[]; extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[];
...@@ -346,9 +346,6 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> { ...@@ -346,9 +346,6 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0; idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
merge_buf_(val, idx); merge_buf_(val, idx);
} }
// after done(), smem is used for merging results among warps
__syncthreads();
} }
private: private:
...@@ -503,255 +500,186 @@ __device__ void topk_with_k2(T* output, T const* input, BiasT const* bias, ...@@ -503,255 +500,186 @@ __device__ void topk_with_k2(T* output, T const* input, BiasT const* bias,
} }
} }
template <typename T, typename BiasT, ScoringFunc SF> template <typename T, typename BiasT, typename IdxT, ScoringFunc SF>
__global__ void topk_with_k2_kernel(T* output, T* input, BiasT const* bias, __global__ void grouped_topk_fused_kernel(
int64_t const num_tokens, T* scores, float* topk_values, IdxT* topk_indices, BiasT const* bias,
int64_t const num_cases, int64_t const num_tokens, int64_t const num_experts, int64_t const n_group,
int64_t const n_group, int64_t const topk_group, int64_t const topk, bool renormalize,
int64_t const num_experts_per_group) { double routed_scaling_factor) {
int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t const token_id = static_cast<int32_t>(blockIdx.x);
int32_t lane_id = threadIdx.x % WARP_SIZE; if (token_id >= num_tokens) {
return;
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; }
if (case_id < num_cases) {
input += case_id * num_experts_per_group;
// bias is per expert group, offset to current group
int32_t group_id = case_id % n_group;
BiasT const* group_bias = bias + group_id * num_experts_per_group;
output += case_id;
cg::thread_block block = cg::this_thread_block(); int32_t const warp_id = threadIdx.x / WARP_SIZE;
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); int32_t const lane_id = threadIdx.x % WARP_SIZE;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) int32_t const n_group_i32 = static_cast<int32_t>(n_group);
asm volatile("griddepcontrol.wait;"); int32_t const topk_group_i32 = static_cast<int32_t>(topk_group);
#endif int32_t const topk_i32 = static_cast<int32_t>(topk);
topk_with_k2<T, BiasT, SF>(output, input, group_bias, tile, lane_id, int32_t const num_experts_i32 = static_cast<int32_t>(num_experts);
num_experts_per_group);
int32_t const num_warps = blockDim.x / WARP_SIZE;
if (warp_id >= n_group_i32 || num_warps < n_group_i32) {
return;
} }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF, int32_t const num_experts_per_group = num_experts_i32 / n_group_i32;
int NGroup = -1>
__global__ void group_idx_and_topk_idx_kernel( T* scores_token = scores + static_cast<int64_t>(token_id) * num_experts;
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
BiasT const* bias, int64_t const num_tokens, int64_t const n_group,
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool renormalize,
double routed_scaling_factor) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id =
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
scores += case_id * num_experts;
group_scores += case_id * n_group;
topk_values += case_id * topk;
topk_indices += case_id * topk;
constexpr bool kUseStaticNGroup = (NGroup > 0);
// use int32 to avoid implicit conversion
int32_t const n_group_i32 =
kUseStaticNGroup ? NGroup : static_cast<int32_t>(n_group);
int32_t align_num_experts_per_group =
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
cg::thread_block block = cg::this_thread_block(); cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to extern __shared__ char smem_buf[];
// store the target topk idx // warpSelect internal staging buffer layout
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf); size_t const val_bytes =
T* s_topk_value = static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + size_t const val_bytes_aligned =
warp_id * topk; warp_topk::round_up_to_multiple_of<256>(val_bytes);
s_topk_idx += warp_id * topk; size_t const idx_bytes =
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
T value = neg_inf<T>(); size_t const internal_bytes = val_bytes_aligned + idx_bytes;
T topk_group_value = neg_inf<T>();
int32_t num_equalto_topkth_group; // user-managed shared memory starts after warpSelect internal staging.
uintptr_t ptr_u = reinterpret_cast<uintptr_t>(smem_buf + internal_bytes);
ptr_u = (ptr_u + 15) & ~static_cast<uintptr_t>(15); // align to 16B
T* s_group_scores = reinterpret_cast<T*>(ptr_u);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before
// acqbulk because it's ptr arithmetic // acqbulk because it's ptr arithmetic
#endif #endif
if (case_id < num_tokens) { // phase 1: per-group scan
// calculate group_idx int32_t const group_offset = warp_id * num_experts_per_group;
int32_t target_num_min = topk_with_k2<T, BiasT, SF>(s_group_scores + warp_id,
WARP_SIZE - n_group_i32 + static_cast<int32_t>(topk_group); scores_token + group_offset, bias + group_offset,
// The check is necessary to avoid abnormal input tile, lane_id, num_experts_per_group);
if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) {
value = group_scores[lane_id]; __syncthreads();
// phase 2: warp0 selects groups + merges candidates to final topk
if (warp_id != 0) {
return;
} }
int count_equal_to_top_value = WARP_SIZE - n_group_i32; topk_values += static_cast<int64_t>(token_id) * topk;
int pre_count_equal_to_top_value = 0; topk_indices += static_cast<int64_t>(token_id) * topk;
// Use loop to find the largset top_group
while (count_equal_to_top_value < target_num_min) { // select topk_group groups by group score
topk_group_value = cg::reduce(tile, value, cg::greater<T>()); warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
if (value == topk_group_value) { /* is_stable */ true>
value = neg_inf<T>(); group_sel(static_cast<int32_t>(topk_group_i32), neg_inf<T>());
// all lanes must participate in WarpSelect::add().
T gscore = (lane_id < n_group_i32) ? s_group_scores[lane_id] : neg_inf<T>();
group_sel.add(gscore, lane_id);
group_sel.done();
// proceed only if the k-th selected group score is not -inf
bool proceed = false;
if (topk_group_i32 > 0) {
int const kth_lane = topk_group_i32 - 1;
// broadcast the k-th selected group score to all lanes
T kth_val = __shfl_sync(FULL_WARP_MASK, group_sel.get_val(0), kth_lane);
proceed = (kth_val != neg_inf<T>());
} }
pre_count_equal_to_top_value = count_equal_to_top_value;
count_equal_to_top_value = if (!proceed) {
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>()))); for (int i = lane_id; i < topk_i32; i += WARP_SIZE) {
topk_indices[i] = static_cast<IdxT>(i);
topk_values[i] = 1.0f / static_cast<float>(topk_i32);
} }
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
return;
} }
__syncthreads();
// merge per-group topk candidates for selected groups, then select topk
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t, warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
/* is_stable */ true> /* is_stable */ true>
queue((int32_t)topk, neg_inf<T>()); expert_sel(static_cast<int32_t>(topk_i32), neg_inf<T>());
int count_equalto_topkth_group = 0; // selected group ids reside in lanes [0, topk_group)
bool if_proceed_next_topk = topk_group_value != neg_inf<T>(); int32_t sel_gid_lane = (lane_id < topk_group_i32) ? group_sel.get_idx(0) : 0;
if (case_id < num_tokens && if_proceed_next_topk) {
auto process_group = [&](int i_group) { // add candidates from selected groups to expert_sel
if ((group_scores[i_group] > topk_group_value) || for (int32_t g = 0; g < topk_group_i32; ++g) {
((group_scores[i_group] == topk_group_value) && int32_t gid = __shfl_sync(FULL_WARP_MASK, sel_gid_lane, g);
(count_equalto_topkth_group < num_equalto_topkth_group))) { int32_t const offset = gid * num_experts_per_group;
int32_t offset = i_group * num_experts_per_group; int32_t const align_num_experts_per_group =
for (int32_t i = lane_id; i < align_num_experts_per_group; warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
i += WARP_SIZE) { for (int32_t i = lane_id; i < align_num_experts_per_group; i += WARP_SIZE) {
T candidates = neg_inf<T>(); // all lanes must call `add()` the same number of times.
T cand = neg_inf<T>();
int32_t idx = 0;
if (i < num_experts_per_group) { if (i < num_experts_per_group) {
// apply scoring function (if any) and add bias idx = offset + i;
T input = scores[offset + i]; T input = scores_token[idx];
if (is_finite(input)) { if (is_finite(input)) {
T score = apply_scoring<SF>(input); T score = apply_scoring<SF>(input);
candidates = score + static_cast<T>(bias[offset + i]); cand = score + static_cast<T>(bias[idx]);
}
} }
queue.add(candidates, offset + i);
} }
if (group_scores[i_group] == topk_group_value) { expert_sel.add(cand, idx);
count_equalto_topkth_group++;
} }
} }
}; expert_sel.done();
if constexpr (kUseStaticNGroup) { // compute unbiased routing weights + optional renorm.
#pragma unroll float lane_unbiased = 0.0f;
for (int i_group = 0; i_group < NGroup; ++i_group) { IdxT lane_idx = 0;
process_group(i_group); if (lane_id < topk_i32) {
} lane_idx = static_cast<IdxT>(expert_sel.get_idx(0));
} else { T in = scores_token[static_cast<int32_t>(lane_idx)];
for (int i_group = 0; i_group < n_group_i32; ++i_group) { lane_unbiased = cuda_cast<float, T>(apply_scoring<SF>(in));
process_group(i_group);
}
}
queue.done();
// Get the topk_idx
queue.dumpIdx(s_topk_idx);
} }
// Load the valid score value float topk_sum = 1e-20f;
// Calculate the summation
float topk_sum = 1e-20;
if (case_id < num_tokens && if_proceed_next_topk) {
for (int i = lane_id;
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
i += WARP_SIZE) {
T value = cuda_cast<T, float>(0.0f);
if (i < topk) {
// Load the score value (without bias) for normalization
T input = scores[s_topk_idx[i]];
value = apply_scoring<SF>(input);
s_topk_value[i] = value;
}
if (renormalize) { if (renormalize) {
topk_sum += topk_sum += cg::reduce(tile, lane_unbiased, cg::plus<float>());
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
} }
}
}
__syncthreads();
if (case_id < num_tokens) { float scale = static_cast<float>(routed_scaling_factor);
if (if_proceed_next_topk) {
float scale = routed_scaling_factor;
if (renormalize) { if (renormalize) {
scale /= topk_sum; scale /= topk_sum;
} }
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float base = cuda_cast<float, T>(s_topk_value[i]); if (lane_id < topk_i32) {
float value = base * scale; topk_indices[lane_id] = lane_idx;
topk_indices[i] = s_topk_idx[i]; topk_values[lane_id] = lane_unbiased * scale;
topk_values[i] = value;
}
} else {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
topk_indices[i] = i;
topk_values[i] = 1.0f / topk;
}
}
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
// default result.
} }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;"); asm volatile("griddepcontrol.launch_dependents;");
#endif #endif
} }
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF>
inline void launch_group_idx_and_topk_kernel(
cudaLaunchConfig_t const& config, T* scores, T* group_scores,
float* topk_values, IdxT* topk_indices, BiasT const* bias,
int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool const renormalize,
double const routed_scaling_factor) {
auto launch = [&](auto* kernel_instance2) {
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
topk_values, topk_indices, bias, num_tokens, n_group,
topk_group, topk, num_experts, num_experts_per_group,
renormalize, routed_scaling_factor);
};
switch (n_group) {
case 4: {
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 4>);
break;
}
case 8: {
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 8>);
break;
}
case 16: {
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 16>);
break;
}
case 32: {
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 32>);
break;
}
default: {
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF>);
break;
}
}
}
template <typename T, typename BiasT, typename IdxT> template <typename T, typename BiasT, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices,
IdxT* topk_indices, BiasT const* bias, BiasT const* bias, int64_t const num_tokens,
int64_t const num_tokens, int64_t const num_experts, int64_t const num_experts, int64_t const n_group,
int64_t const n_group, int64_t const topk_group, int64_t const topk_group, int64_t const topk,
int64_t const topk, bool const renormalize, bool const renormalize, double const routed_scaling_factor,
double const routed_scaling_factor, int const scoring_func, int const scoring_func, bool enable_pdl = false,
bool enable_pdl = false, cudaStream_t const stream = 0) { cudaStream_t const stream = 0) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
cudaLaunchConfig_t config; cudaLaunchConfig_t config;
config.gridDim = topk_with_k2_num_blocks; // One block per token; one warp per group.
config.blockDim = BLOCK_SIZE; config.gridDim = static_cast<uint32_t>(num_tokens);
config.dynamicSmemBytes = 0; config.blockDim = static_cast<uint32_t>(n_group) * WARP_SIZE;
// Dynamic shared memory: WarpSelect staging + per-group topk buffers.
int32_t const num_warps = static_cast<int32_t>(n_group);
size_t const val_bytes =
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
size_t const val_bytes_aligned =
warp_topk::round_up_to_multiple_of<256>(val_bytes);
size_t const idx_bytes =
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
size_t const internal_bytes = val_bytes_aligned + idx_bytes;
size_t const extra_bytes = 16 + static_cast<size_t>(n_group) * sizeof(T);
config.dynamicSmemBytes = internal_bytes + extra_bytes;
config.stream = stream; config.stream = stream;
cudaLaunchAttribute attrs[1]; cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
...@@ -759,64 +687,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, ...@@ -759,64 +687,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
config.numAttrs = 1; config.numAttrs = 1;
config.attrs = attrs; config.attrs = attrs;
auto const sf = static_cast<ScoringFunc>(scoring_func); auto const sf = static_cast<ScoringFunc>(scoring_func);
int64_t const num_experts_per_group = num_experts / n_group;
auto launch_topk_with_k2 = [&](auto* kernel_instance1) {
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
num_tokens, num_cases, n_group, num_experts_per_group);
};
switch (sf) { switch (sf) {
case SCORING_NONE: { case SCORING_NONE: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, BiasT, SCORING_NONE>; auto* kernel_instance =
launch_topk_with_k2(kernel_instance1); &grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_NONE>;
break; cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
topk_indices, bias, num_tokens, num_experts, n_group,
topk_group, topk, renormalize, routed_scaling_factor);
return;
} }
case SCORING_SIGMOID: { case SCORING_SIGMOID: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, BiasT, SCORING_SIGMOID>; auto* kernel_instance =
launch_topk_with_k2(kernel_instance1); &grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_SIGMOID>;
break; cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
topk_indices, bias, num_tokens, num_experts, n_group,
topk_group, topk, renormalize, routed_scaling_factor);
return;
} }
default: default:
// should be guarded by higher level checks. // should be guarded by higher level checks.
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc"); TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
} }
int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
size_t dynamic_smem_in_bytes =
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);
config.gridDim = topk_with_k_group_num_blocks;
config.blockDim = BLOCK_SIZE;
config.dynamicSmemBytes = dynamic_smem_in_bytes;
config.stream = stream;
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;
switch (sf) {
case SCORING_NONE: {
launch_group_idx_and_topk_kernel<T, BiasT, IdxT, SCORING_NONE>(
config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor);
break;
}
case SCORING_SIGMOID: {
launch_group_idx_and_topk_kernel<T, BiasT, IdxT, SCORING_SIGMOID>(
config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor);
break;
}
default:
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
}
} }
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \ #define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \
template void invokeNoAuxTc<T, BiasT, IdxT>( \ template void invokeNoAuxTc<T, BiasT, IdxT>( \
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \ T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \
BiasT const* bias, int64_t const num_tokens, int64_t const num_experts, \ int64_t const num_tokens, int64_t const num_experts, \
int64_t const n_group, int64_t const topk_group, int64_t const topk, \ int64_t const n_group, int64_t const topk_group, int64_t const topk, \
bool const renormalize, double const routed_scaling_factor, \ bool const renormalize, double const routed_scaling_factor, \
int const scoring_func, bool enable_pdl, cudaStream_t const stream); int const scoring_func, bool enable_pdl, cudaStream_t const stream);
...@@ -843,17 +740,21 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk( ...@@ -843,17 +740,21 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
int64_t num_tokens = input_size[0]; int64_t num_tokens = input_size[0];
int64_t num_experts = input_size[1]; int64_t num_experts = input_size[1];
TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor"); TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor");
TORCH_CHECK(n_group > 0, "n_group must be positive");
TORCH_CHECK(topk > 0, "topk must be positive");
TORCH_CHECK(topk_group > 0, "topk_group must be positive");
TORCH_CHECK(topk_group <= n_group, "topk_group must be <= n_group");
TORCH_CHECK(num_experts % n_group == 0, TORCH_CHECK(num_experts % n_group == 0,
"num_experts should be divisible by n_group"); "num_experts should be divisible by n_group");
TORCH_CHECK(n_group <= 32, TORCH_CHECK(n_group <= 32,
"n_group should be smaller than or equal to 32 for now"); "n_group should be smaller than or equal to 32 for now");
TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");
TORCH_CHECK(topk <= topk_group * (num_experts / n_group),
"topk must be <= topk_group * (num_experts / n_group)");
TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE || TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE ||
scoring_func == vllm::moe::SCORING_SIGMOID, scoring_func == vllm::moe::SCORING_SIGMOID,
"scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)"); "scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)");
torch::Tensor group_scores = torch::empty(
{num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA));
// Always output float32 for topk_values (eliminates Python-side conversion) // Always output float32 for topk_values (eliminates Python-side conversion)
torch::Tensor topk_values = torch::empty( torch::Tensor topk_values = torch::empty(
{num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); {num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
...@@ -868,7 +769,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk( ...@@ -868,7 +769,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
case torch::kFloat16: \ case torch::kFloat16: \
vllm::moe::invokeNoAuxTc<T, half, IdxT>( \ vllm::moe::invokeNoAuxTc<T, half, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \ reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \ reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \ reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens, \ reinterpret_cast<half const*>(bias.data_ptr()), num_tokens, \
...@@ -879,7 +779,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk( ...@@ -879,7 +779,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
case torch::kFloat32: \ case torch::kFloat32: \
vllm::moe::invokeNoAuxTc<T, float, IdxT>( \ vllm::moe::invokeNoAuxTc<T, float, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \ reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \ reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \ reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \ reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \
...@@ -890,7 +789,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk( ...@@ -890,7 +789,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
case torch::kBFloat16: \ case torch::kBFloat16: \
vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>( \ vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \ reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \ reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \ reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \ reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \
......
...@@ -454,6 +454,9 @@ def dummy_hf_overrides( ...@@ -454,6 +454,9 @@ def dummy_hf_overrides(
# Ensure at least 2 expert per group # Ensure at least 2 expert per group
# Since `grouped_topk` assumes top-2 # Since `grouped_topk` assumes top-2
n_group = getattr(text_config, "n_group", None) n_group = getattr(text_config, "n_group", None)
# Kimi uses `num_expert_group` instead of `n_group`.
if n_group is None:
n_group = getattr(text_config, "num_expert_group", None)
num_experts = n_group * 2 if n_group is not None else 2 num_experts = n_group * 2 if n_group is not None else 2
# we use three layers for Gemma-3n to check # we use three layers for Gemma-3n to check
...@@ -487,6 +490,8 @@ def dummy_hf_overrides( ...@@ -487,6 +490,8 @@ def dummy_hf_overrides(
{ {
"num_experts": num_experts, "num_experts": num_experts,
"num_experts_per_tok": 2, "num_experts_per_tok": 2,
# Kimi uses `num_experts_per_token`.
"num_experts_per_token": 2,
"num_local_experts": num_experts, "num_local_experts": num_experts,
# Otherwise there will not be any expert layers # Otherwise there will not be any expert layers
"first_k_dense_replace": 0, "first_k_dense_replace": 0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment