grouped_topk_kernels.cu 31 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * Adapted from
 * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
 * Copyright (c) 2025, The vLLM team.
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
 * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
24
#include <cuda/std/limits>
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

namespace vllm {
namespace moe {

// Participation mask with every lane of a 32-wide warp set; passed to the
// *_sync warp intrinsics below.
constexpr unsigned FULL_WARP_MASK = 0xffffffffu;
// Number of lanes per warp.
constexpr int32_t WARP_SIZE = 32;
// Threads per block for both kernels launched from invokeNoAuxTc.
constexpr int32_t BLOCK_SIZE = 512;
// Each warp processes one work item, so one block covers this many items.
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

namespace warp_topk {

// Rounds `len` up to the next multiple of `size` (0 stays 0).
// Usable from both host and device code and in constant expressions.
template <int size, typename T>
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
  return (len == 0) ? T(0) : static_cast<T>(((len - 1) / size + 1) * size);
}

// True iff `v` is a (positive) power of two: it is non-zero and clearing its
// lowest set bit (v & (v - 1)) leaves nothing behind.
template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
  return v != 0 && (v & (v - 1)) == 0;
}

// Strict ordering predicate used by the sorting network.
// With greater == true a strictly larger `val` wins; with greater == false a
// strictly smaller one does. Equal values (and NaN comparisons) never win.
template <bool greater, typename T>
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
  if constexpr (greater) {
    return val > baseline;
  } else {
    return val < baseline;
  }
}

// Stable variant of is_better_than: values are compared as in the two-argument
// overload, and exact ties are broken in favor of the smaller index. The
// tie-break is the same for both orderings (greater true or false).
template <bool greater, typename T, typename idxT>
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
                                               idxT baseline_index) {
  if (val == baseline) {
    return index < baseline_index;
  }
  if constexpr (greater) {
    return val > baseline;
  } else {
    return val < baseline;
  }
}

// Returns the dynamic shared-memory size (bytes) needed by a block of
// `num_of_warp` warps that each select the top `k` entries.
//
// Two uses share the same allocation, so the larger requirement wins:
//   * cache_topk:   num_of_warp * k (value, index) pairs;
//   * merge buffer: n values (start of buffer) followed by n indices at an
//                   offset rounded up to a multiple of 256 bytes, where
//                   n = max(num_of_warp / 2 * k, num_of_warp * WARP_SIZE).
// All arithmetic is done in int64_t so that the int64 `k` is never silently
// narrowed and signed/unsigned operands are not mixed in max().
template <typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
  int64_t const warps = num_of_warp;
  int64_t const val_bytes = static_cast<int64_t>(sizeof(T));
  int64_t const idx_bytes = static_cast<int64_t>(sizeof(idxT));

  int64_t const cache_topk = (val_bytes + idx_bytes) * warps * k;
  int64_t const n = std::max<int64_t>(warps / 2 * k, warps * WARP_SIZE);
  int64_t const merge_buf =
      round_up_to_multiple_of<256>(n * val_bytes) + n * idx_bytes;
  return static_cast<int>(std::max<int64_t>(cache_topk, merge_buf));
}

// Merges a bitonic sequence of `size` elements, spread over the 32 lanes of a
// warp, into a monotonic sequence.
//
// Layout: each lane holds arr_len = size / WARP_SIZE consecutive entries in
// its own val_arr / idx_arr. This general case performs one compare-and-swap
// step between entry i and entry i + arr_len / 2 of the same lane (no
// cross-lane traffic), then recursively merges both halves. The recursion
// bottoms out at the BitonicMerge<32, ...> specialization, which works across
// lanes with shuffles.
//
// Template parameters:
//   size       total elements across the warp; power of two and
//              >= 2 * WARP_SIZE (static_assert below)
//   ascending  ordering passed to is_better_than for the in-lane swap
//   reverse    not used here; forwarded to the recursive merges (the 32-wide
//              specialization uses it in its stable path)
//   is_stable  when true, ties are broken by index (is_better_than overload
//              with indices)
template <int size, bool ascending, bool reverse, typename T, typename idxT,
          bool is_stable>
struct BitonicMerge {
  // input should be a bitonic sequence, and sort it to be a monotonic sequence
  __device__ static void merge(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
    static_assert(isPowerOf2(size));
    static_assert(size >= 2 * WARP_SIZE);
    constexpr int arr_len = size / WARP_SIZE;

    // Pair each entry of the lower half with its counterpart in the upper
    // half (same lane) and swap them when the lower one is "better".
    constexpr int stride = arr_len / 2;
    for (int i = 0; i < stride; ++i) {
      int const other_i = i + stride;
      T& val = val_arr[i];
      T& other_val = val_arr[other_i];
      bool is_better;
      if constexpr (is_stable) {
        is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
                                              idx_arr[other_i]);
      } else {
        is_better = is_better_than<ascending>(val, other_val);
      }

      if (is_better) {
        // Swap value and index together so they stay paired.
        T tmp = val;
        val = other_val;
        other_val = tmp;

        idxT tmp2 = idx_arr[i];
        idx_arr[i] = idx_arr[other_i];
        idx_arr[other_i] = tmp2;
      }
    }

    // Each half is now bitonic on its own; merge them independently.
    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
        val_arr, idx_arr);
    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
        val_arr + arr_len / 2, idx_arr + arr_len / 2);
  }
};

// Sorts `size` elements spread over the warp (size / WARP_SIZE per lane).
// It recursively sorts the first half of each lane's entries ascending and the
// second half descending. Together the two halves form a bitonic sequence,
// which BitonicMerge then merges in the requested direction. The 32-element
// base case is the BitonicSort<32, ...> specialization.
template <int size, bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort {
  __device__ static void sort(T* __restrict__ val_arr,
                              idxT* __restrict__ idx_arr) {
    static_assert(isPowerOf2(size));
    static_assert(size >= 2 * WARP_SIZE);
    constexpr int arr_len = size / WARP_SIZE;

    // Opposite directions for the two halves produce a bitonic sequence.
    BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
    BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
        val_arr + arr_len / 2, idx_arr + arr_len / 2);
    BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
        val_arr, idx_arr);
  }
};

// Base case: sorts 32 elements, exactly one per lane (*val_arr / *idx_arr),
// using cross-lane __shfl_xor_sync exchanges. Must be called by all 32 lanes
// (FULL_WARP_MASK).
template <bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort<32, ascending, T, idxT, is_stable> {
  __device__ static void sort(T* __restrict__ val_arr,
                              idxT* __restrict__ idx_arr) {
    int const lane = threadIdx.x % WARP_SIZE;

    // ascending doesn't matter before merging since all we need is a bitonic
    // sequence
    // Stage s (0..3) builds sorted runs of length 2^(s+1). Adjacent runs
    // alternate direction via `reverse` (bit s+1 of the lane id), so after
    // stage 3 the two 16-long runs form one 32-long bitonic sequence.
    for (int stage = 0; stage < 4; ++stage) {
      for (int stride = (1 << stage); stride > 0; stride /= 2) {
        bool reverse = (lane >> stage) & 2;
        bool is_second = lane & stride;

        // Fetch the partner lane's entry (lane ^ stride).
        T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
        idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);

        // Decide whether this lane should take the partner's entry; the
        // stable variants break value ties on the index.
        bool is_better;
        if constexpr (is_stable) {
          if constexpr (ascending) {
            is_better = ((*val_arr > other) ||
                         ((*val_arr == other) && (*idx_arr < other_idx))) !=
                        (reverse != is_second);
          } else {
            is_better = ((*val_arr > other) ||
                         ((*val_arr == other) && (*idx_arr > other_idx))) !=
                        (reverse != is_second);
          }
        } else {
          is_better = (*val_arr != other &&
                       (*val_arr > other) != (reverse != is_second));
        }
        if (is_better) {
          *val_arr = other;
          *idx_arr = other_idx;
        }
      }
    }

    BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
                                                                      idx_arr);
  }
};

// Base case: merges a 32-element bitonic sequence held one element per lane.
// Uses strides 16, 8, 4, 2, 1 with __shfl_xor_sync; each lane keeps either its
// own entry or its partner's, depending on which half of the pair it is in
// (is_second). Must be called by all 32 lanes (FULL_WARP_MASK).
template <bool ascending, bool reverse, typename T, typename idxT,
          bool is_stable>
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
  __device__ static void merge(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
    int const lane = threadIdx.x % WARP_SIZE;
    for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
      bool is_second = lane & stride;
      T& val = *val_arr;
      T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
      idxT& idx = *idx_arr;
      idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);

      bool is_better;
      if constexpr (is_stable) {
        if constexpr (ascending) {
          is_better = ((*val_arr > other) ||
                       ((*val_arr == other) && (*idx_arr < other_idx))) ==
                      (reverse != is_second);  // for min
        } else {
          is_better = ((*val_arr > other) ||
                       ((*val_arr == other) && (*idx_arr > other_idx))) ==
                      (reverse != is_second);  // for max
        }
      } else {
        // NOTE(review): this non-stable branch keys on `ascending`, while the
        // stable branch above keys on `reverse`. Only is_stable == true is
        // instantiated in this file (WarpSelect in
        // group_idx_and_topk_idx_kernel); confirm intent before using the
        // non-stable path.
        is_better =
            (val != other && ((val > other) == (ascending != is_second)));
      }

      if (is_better) {
        val = other;
        idx = other_idx;
      }
    }
  }
};

// Warp-wide sorted container of `capacity` (value, index) slots.
// Each of the 32 lanes owns max_arr_len_ = capacity / WARP_SIZE slots. Slot i
// of lane l is logical position i * WARP_SIZE + l (see dump()).
//
// Template parameters:
//   capacity   total slots across the warp; power of two and >= WARP_SIZE
//              (static_assert in the constructor)
//   greater    true keeps the largest values, false the smallest
//   is_stable  when true, equal values are ordered by index
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSort {
 public:
  // k:     number of entries reported by dump()/dumpIdx() (presumably
  //        k <= capacity — TODO confirm with callers)
  // dummy: sentinel value placed in every slot (with index 0) initially
  __device__ WarpSort(idxT k, T dummy)
      : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
    static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));

    for (int i = 0; i < max_arr_len_; ++i) {
      val_arr_[i] = dummy_;
      idx_arr_[i] = 0;
    }
  }

  // load and merge k sorted values
  // Reads in[start, start + k_) (with indices from in_idx) in lane-reversed
  // order, keeps the better of each incoming entry and the current slot, then
  // bitonic-merges all slots. Must be called by the whole warp.
  __device__ void load_sorted(T const* __restrict__ in,
                              idxT const* __restrict__ in_idx, idxT start) {
    idxT idx = start + WARP_SIZE - 1 - lane_;
    for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
      if (idx < start + k_) {
        T t = in[idx];
        bool is_better;
        if constexpr (is_stable) {
          is_better =
              is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
        } else {
          is_better = is_better_than<greater>(t, val_arr_[i]);
        }
        if (is_better) {
          val_arr_[i] = t;
          idx_arr_[i] = in_idx[idx];
        }
      }
    }

    BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
        val_arr_, idx_arr_);
  }

  // Writes the first k_ (value, index) pairs to out / out_idx, at position
  // i * WARP_SIZE + lane.
  __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
    for (int i = 0; i < max_arr_len_; ++i) {
      idxT out_i = i * WARP_SIZE + lane_;
      if (out_i < k_) {
        out[out_i] = val_arr_[i];
        out_idx[out_i] = idx_arr_[i];
      }
    }
  }

  // Same as dump(), but writes only the indices.
  __device__ void dumpIdx(idxT* __restrict__ out_idx) const {
    for (int i = 0; i < max_arr_len_; ++i) {
      idxT out_i = i * WARP_SIZE + lane_;
      if (out_i < k_) {
        out_idx[out_i] = idx_arr_[i];
      }
    }
  }

 protected:
  // Slots owned by each lane.
  static constexpr int max_arr_len_ = capacity / WARP_SIZE;

  T val_arr_[max_arr_len_];
  idxT idx_arr_[max_arr_len_];

  int const lane_;   // lane id within the warp
  idxT const k_;     // number of results to report
  T const dummy_;    // sentinel for empty slots

};  // end class WarpSort

// Streaming warp-wide top-k selector built on WarpSort.
//
// Flow:
//   1. add() offers one candidate per lane.
//   2. Candidates that beat the current k-th best (k_th_ / k_th_idx_) are
//      staged in a per-warp shared-memory buffer of WARP_SIZE entries.
//   3. When the buffer fills, merge_buf_() bitonic-sorts it, merges it into
//      the per-lane slots and refreshes the k-th threshold.
//   4. done() flushes whatever is still staged.
//
// Shared memory: carved out of the kernel's dynamic shared memory (smem_buf):
//   * values for warp w start at element w * WARP_SIZE (as T);
//   * indices start at a byte offset of
//     round_up<256>(num_warps * WARP_SIZE * sizeof(T)), again
//     WARP_SIZE entries per warp.
//
// Contract:
//   * add()/done() use full-warp intrinsics, so all 32 lanes must call them
//     together.
//   * done() ends with __syncthreads(), so every thread of the block must
//     reach it.
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
 public:
  // k:     number of results to select
  // dummy: sentinel for empty slots; also the initial k-th threshold
  __device__ WarpSelect(idxT k, T dummy)
      : WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
        k_th_(dummy),
        k_th_lane_((k - 1) % WARP_SIZE) {
    extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[];

    int const num_of_warp = blockDim.x / WARP_SIZE;
    int const warp_id = threadIdx.x / WARP_SIZE;
    val_smem_ = reinterpret_cast<T*>(smem_buf);
    val_smem_ += warp_id * WARP_SIZE;
    idx_smem_ = reinterpret_cast<idxT*>(
        smem_buf +
        round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
    idx_smem_ += warp_id * WARP_SIZE;
  }

  // Offers in[start, end) using the position as the index. The range is
  // walked in warp-wide steps; lanes past `end` in the final step offer
  // dummy_, so every lane always participates in add(val, idx).
  __device__ void add(T const* in, idxT start, idxT end) {
    idxT const end_for_fullwarp =
        round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
    for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
      T val = (i < end) ? in[i] : dummy_;
      add(val, i);
    }
  }

  // Offers one (val, idx) candidate per lane. Candidates better than the
  // current k-th entry are appended to the warp's staging buffer; a full
  // buffer is merged immediately, and overflow entries go to the new buffer.
  __device__ void add(T val, idxT idx) {
    bool do_add;
    if constexpr (is_stable) {
      do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
    } else {
      do_add = is_better_than<greater>(val, k_th_);
    }

    uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
    if (mask == 0) {
      return;
    }

    // Compact accepted candidates: position = buffered so far + number of
    // accepting lanes below this one.
    int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
    if (do_add && pos < WARP_SIZE) {
      val_smem_[pos] = val;
      idx_smem_[pos] = idx;
      do_add = false;
    }
    smem_buf_len_ += __popc(mask);
    if (smem_buf_len_ >= WARP_SIZE) {
      __syncwarp();
      merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
      smem_buf_len_ -= WARP_SIZE;
    }
    if (do_add) {
      // Candidates that did not fit go to the front of the emptied buffer.
      pos -= WARP_SIZE;
      val_smem_[pos] = val;
      idx_smem_[pos] = idx;
    }
    __syncwarp();
  }

  // Merges any partially-filled staging buffer, then synchronizes the block.
  __device__ void done() {
    if (smem_buf_len_) {
      T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
      idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
      merge_buf_(val, idx);
    }

    // after done(), smem is used for merging results among warps
    __syncthreads();
  }

 private:
  // Broadcasts the current k-th best entry (slot max_arr_len_ - 1 of lane
  // k_th_lane_) to the whole warp.
  __device__ void set_k_th_() {
    k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
    if constexpr (is_stable) {
      k_th_idx_ =
          __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
    }
  }

  // Sorts one warp-wide batch (one candidate per lane) and merges it into the
  // per-lane slots, then refreshes the k-th threshold.
  __device__ void merge_buf_(T val, idxT idx) {
    BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);

    T& old = val_arr_[max_arr_len_ - 1];

    bool is_better;
    if constexpr (is_stable) {
      is_better =
          is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
    } else {
      is_better = is_better_than<greater>(val, old);
    }

    if (is_better) {
      old = val;
      idx_arr_[max_arr_len_ - 1] = idx;
    }

    BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
        val_arr_, idx_arr_);

    set_k_th_();
  }

  using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;

  T* val_smem_;
  idxT* idx_smem_;
  int smem_buf_len_ = 0;

  T k_th_;
  // Index of the current k-th entry, used for tie-breaking in the stable
  // path. It must be defined before the first merge_buf_() call (which
  // invokes set_k_th_()), because add() compares against it whenever a
  // candidate ties with the initial dummy threshold. 0 matches the index
  // WarpSort assigns to its dummy slots, so it is exactly what set_k_th_()
  // would report for an all-dummy queue.
  idxT k_th_idx_ = 0;
  int const k_th_lane_;
};  // end class WarpSelect
}  // namespace warp_topk

// Device-side scalar conversion. The generic version relies on an implicit
// conversion.
//
// The half/bf16 <-> float conversions used in this file are specialized to
// call the CUDA conversion intrinsics explicitly (round-to-nearest, the same
// rounding the implicit constructors use). They therefore keep compiling when
// the build defines __CUDA_NO_HALF_CONVERSIONS__ or
// __CUDA_NO_BFLOAT16_CONVERSIONS__, which disable those implicit conversions.
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
  return val;
}

template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
  return __bfloat162float(val);
}

template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
  return __float2bfloat16(val);
}

template <>
__device__ inline float cuda_cast<float, half>(half val) {
  return __half2float(val);
}

template <>
__device__ inline half cuda_cast<half, float>(float val) {
  return __float2half(val);
}

// Returns negative infinity as a T.
// The value is built in fp32 and then narrowed, because
// cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16].
template <typename T>
__device__ inline T neg_inf() {
  float const neg_inf_f32 = -cuda::std::numeric_limits<float>::infinity();
  return cuda_cast<T, float>(neg_inf_f32);
}

// Returns true iff `val` is neither infinite nor NaN.
// With nvcc >= 12.8 the libcu++ overload is applied to T directly. On older
// toolkits the value is widened to fp32 first (presumably because older
// libcu++ lacks half/bf16 overloads — TODO confirm).
template <typename T>
__device__ inline bool is_finite(const T val) {
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
  bool const finite = cuda::std::isfinite(val);
#else
  float const widened = cuda_cast<float, T>(val);
  bool const finite = isfinite(widened);
#endif
  return finite;
}

// Activation applied to raw router scores before ranking.
// The numeric values are the `scoring_func` codes accepted from the host
// (grouped_topk documents them as 0 and 1), so they must stay fixed.
enum ScoringFunc {
  SCORING_NONE = 0,     // scores are used unchanged
  SCORING_SIGMOID = 1,  // scores are passed through a sigmoid
};

// Logistic sigmoid evaluated through the identity
//   sigmoid(x) = 0.5 * tanh(x / 2) + 0.5.
// The identity is mathematically exact, so this is not an approximation: the
// only error comes from tanhf's own rounding.
__device__ inline float sigmoid_accurate(float x) {
  float const half_tanh = tanhf(0.5f * x);
  return 0.5f * half_tanh + 0.5f;
}

// Applies sigmoid_accurate to `val`: computes in fp32 regardless of T, then
// converts the result back to T.
template <typename T>
__device__ inline T apply_sigmoid(T val) {
  float const x = cuda_cast<float, T>(val);
  return cuda_cast<T, float>(sigmoid_accurate(x));
}

// Applies the compile-time scoring function SF to a single score:
// identity for SCORING_NONE, sigmoid for SCORING_SIGMOID.
template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) {
  if constexpr (SF != SCORING_SIGMOID) {
    return val;
  } else {
    return apply_sigmoid(val);
  }
}

// Computes the score of one expert group: the sum of its two largest values of
// apply_scoring<SF>(input[i]) + bias[i], over i in [0, num_experts_per_group).
//
// Must be called by all 32 lanes of `tile` (one warp per group):
//   * each lane first reduces the experts it owns (i = lane_id + j * 32);
//   * the warp then reduces with cg::reduce;
//   * only lane 0 writes *output.
//
// NOTE(review): with num_experts_per_group == 1 there is no second value, so
// the runner-up is -inf and the group score becomes -inf.
template <typename T, ScoringFunc SF>
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
                             cg::thread_block_tile<32> const& tile,
                             int32_t const lane_id,
                             int const num_experts_per_group) {
  // Get the top2 per thread
  T largest = neg_inf<T>();
  T second_largest = neg_inf<T>();

  if (num_experts_per_group > WARP_SIZE) {
    for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
      T value = apply_scoring<SF>(input[i]);
      value = value + bias[i];

      if (value > largest) {
        second_largest = largest;
        largest = value;
      } else if (value > second_largest) {
        second_largest = value;
      }
    }
  } else {
    // At most one expert per lane, so only `largest` can be populated.
    for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
      T value = apply_scoring<SF>(input[i]);
      value = value + bias[i];
      largest = value;
    }
  }

  __syncwarp();  // Ensure all threads have valid data before reduction
  // Get the top2 warpwise
  T max1 = cg::reduce(tile, largest, cg::greater<T>());

  // If several lanes hold the maximum, the runner-up equals the maximum.
  // Otherwise the owning lane substitutes its own runner-up and the warp
  // reduces again.
  T max2 = max1;
  bool equal_to_max1 = (max1 == largest);

  int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));

  if (count_max1 == 1) {
    largest = (largest == max1) ? second_largest : largest;
    max2 = cg::reduce(tile, largest, cg::greater<T>());
  }

  if (lane_id == 0) {
    *output = max1 + max2;
  }
}

// Stage 1: computes one score per (token, group) "case" via topk_with_k2.
//
// Launch layout:
//   * blockDim.x == BLOCK_SIZE; one warp per case, with
//     case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
//   * gridDim.x >= ceil(num_cases / NUM_WARPS_PER_BLOCK);
//   * the group of a case is case_id % n_group.
//
// Parameters:
//   output  one T per case (written by lane 0)
//   input   num_experts_per_group scores per case, consecutive across cases
//   bias    one entry per expert; group g uses
//           bias[g * num_experts_per_group, (g + 1) * num_experts_per_group)
//   num_tokens is unused here.
//
// On sm_90+ the kernel waits on griddepcontrol before reading inputs and
// triggers dependents at the end, for programmatic dependent launch.
template <typename T, ScoringFunc SF>
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
                                    int64_t const num_tokens,
                                    int64_t const num_cases,
                                    int64_t const n_group,
                                    int64_t const num_experts_per_group) {
  int32_t const warp_id = threadIdx.x / WARP_SIZE;
  int32_t const lane_id = threadIdx.x % WARP_SIZE;

  // 64-bit so that blockIdx.x * NUM_WARPS_PER_BLOCK cannot overflow when
  // num_tokens * n_group is large.
  int64_t const case_id =
      static_cast<int64_t>(blockIdx.x) * NUM_WARPS_PER_BLOCK + warp_id;
  if (case_id < num_cases) {
    input += case_id * num_experts_per_group;
    // bias holds one value per expert; offset to this group's slice.
    int64_t const group_id = case_id % n_group;
    T const* group_bias = bias + group_id * num_experts_per_group;
    output += case_id;

    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");
#endif
    topk_with_k2<T, SF>(output, input, group_bias, tile, lane_id,
                        num_experts_per_group);
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

// Stage 2: per token, picks the topk_group best groups and then the topk best
// experts inside them.
//
// Group selection:
//   * uses the group scores from stage 1;
//   * non-finite group scores are ignored;
//   * ties at the cut-off are resolved in group order.
//
// Expert ranking:
//   * experts are ranked by apply_scoring<SF>(score) + bias (non-finite scores
//     rank as -inf);
//   * the bias is used only for ranking.
//
// Output:
//   * each selected expert's weight is apply_scoring<SF>(score), without bias;
//   * when `renormalize` is set, weights are divided by their sum;
//   * weights are then multiplied by routed_scaling_factor;
//   * if no group score is finite, falls back to experts [0, topk) with weight
//     1/topk each.
//
// Launch layout: blockDim.x == BLOCK_SIZE, one warp per token
// (case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id).
//
// Dynamic shared memory:
//   * size: calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
//     topk) bytes;
//   * first used by WarpSelect's staging buffers;
//   * then reused for the per-warp s_topk_idx / s_topk_value arrays.
//
// NGroup > 0: the group count is a compile-time constant for the selection
// loop (n_group is still used for the group_scores stride).
template <typename T, typename IdxT, ScoringFunc SF, int NGroup = -1>
__global__ void group_idx_and_topk_idx_kernel(
    T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
    T const* bias, int64_t const num_tokens, int64_t const n_group,
    int64_t const topk_group, int64_t const topk, int64_t const num_experts,
    int64_t const num_experts_per_group, bool renormalize,
    double routed_scaling_factor) {
  int32_t const warp_id = threadIdx.x / WARP_SIZE;
  int32_t const lane_id = threadIdx.x % WARP_SIZE;
  // one per token; 64-bit to avoid overflowing blockIdx.x * warps-per-block
  int64_t const case_id =
      static_cast<int64_t>(blockIdx.x) * NUM_WARPS_PER_BLOCK + warp_id;
  scores += case_id * num_experts;
  group_scores += case_id * n_group;
  topk_values += case_id * topk;
  topk_indices += case_id * topk;

  constexpr bool kUseStaticNGroup = (NGroup > 0);
  // use int32 to avoid implicit conversion
  int32_t const n_group_i32 =
      kUseStaticNGroup ? NGroup : static_cast<int32_t>(n_group);

  int32_t align_num_experts_per_group =
      warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);

  cg::thread_block block = cg::this_thread_block();
  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

  extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
                                      // store the target topk idx
  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
  T* s_topk_value =
      reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
      warp_id * topk;
  s_topk_idx += warp_id * topk;

  T value = neg_inf<T>();
  T topk_group_value = neg_inf<T>();
  int32_t num_equalto_topkth_group = 0;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");  // I think all prolog can be put before
                                         // acqbulk because it's ptr arithmetic
#endif

  if (case_id < num_tokens) {
    // calculate group_idx
    // Lanes >= n_group start out as -inf, so selection is complete once
    // (32 - n_group) + topk_group lanes are -inf.
    int32_t target_num_min =
        WARP_SIZE - n_group_i32 + static_cast<int32_t>(topk_group);
    // The check is necessary to avoid abnormal input
    if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) {
      value = group_scores[lane_id];
    }

    int count_equal_to_top_value = WARP_SIZE - n_group_i32;
    int pre_count_equal_to_top_value = 0;
    // Use loop to find the largest top_group
    while (count_equal_to_top_value < target_num_min) {
      __syncwarp();  // Ensure all threads have valid data before reduction
      topk_group_value = cg::reduce(tile, value, cg::greater<T>());
      if (value == topk_group_value) {
        value = neg_inf<T>();
      }
      pre_count_equal_to_top_value = count_equal_to_top_value;
      count_equal_to_top_value =
          __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
    }
    // How many groups tied at the cut-off value may still be taken.
    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
  }
  __syncthreads();

  warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
                        /* is_stable */ true>
      queue((int32_t)topk, neg_inf<T>());

  int count_equalto_topkth_group = 0;
  bool const if_proceed_next_topk = topk_group_value != neg_inf<T>();
  bool const has_work = (case_id < num_tokens) && if_proceed_next_topk;
  if (has_work) {
    auto process_group = [&](int i_group) {
      if ((group_scores[i_group] > topk_group_value) ||
          ((group_scores[i_group] == topk_group_value) &&
           (count_equalto_topkth_group < num_equalto_topkth_group))) {
        int32_t offset = i_group * num_experts_per_group;
        for (int32_t i = lane_id; i < align_num_experts_per_group;
             i += WARP_SIZE) {
          T candidates = neg_inf<T>();
          if (i < num_experts_per_group) {
            // apply scoring function (if any) and add bias
            T input = scores[offset + i];
            if (is_finite(input)) {
              T score = apply_scoring<SF>(input);
              candidates = score + bias[offset + i];
            }
          }
          queue.add(candidates, offset + i);
        }
        if (group_scores[i_group] == topk_group_value) {
          count_equalto_topkth_group++;
        }
      }
    };

    if constexpr (kUseStaticNGroup) {
#pragma unroll
      for (int i_group = 0; i_group < NGroup; ++i_group) {
        process_group(i_group);
      }
    } else {
      for (int i_group = 0; i_group < n_group_i32; ++i_group) {
        process_group(i_group);
      }
    }
  }

  // WarpSelect::done() ends with a block-wide __syncthreads(), so it must be
  // reached by every thread, including warps with no token or no valid group.
  // Idle warps have nothing staged, so for them it is only the barrier. The
  // barrier also guarantees that all warps have finished using the staging
  // buffers before s_topk_idx / s_topk_value (which alias them) are written.
  queue.done();

  if (has_work) {
    // Get the topk_idx
    queue.dumpIdx(s_topk_idx);
    __syncwarp();
  }

  // Load the valid score value
  // Calculate the summation
  float topk_sum = 1e-20;
  if (has_work) {
    for (int i = lane_id;
         i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
         i += WARP_SIZE) {
      T score = cuda_cast<T, float>(0.0f);
      if (i < topk) {
        // Load the score value (without bias) for normalization
        T input = scores[s_topk_idx[i]];
        score = apply_scoring<SF>(input);
        s_topk_value[i] = score;
      }
      if (renormalize) {
        topk_sum +=
            cg::reduce(tile, cuda_cast<float, T>(score), cg::plus<float>());
      }
    }
  }

  __syncthreads();

  if (case_id < num_tokens) {
    if (if_proceed_next_topk) {
      for (int i = lane_id; i < topk; i += WARP_SIZE) {
        float base = cuda_cast<float, T>(s_topk_value[i]);
        float value = renormalize ? (base / topk_sum * routed_scaling_factor)
                                  : (base * routed_scaling_factor);
        topk_indices[i] = s_topk_idx[i];
        topk_values[i] = value;
      }
    } else {
      // Note: when if_proceed_next_topk==false, choose the first topk experts
      // with uniform weights as the default result.
      for (int i = lane_id; i < topk; i += WARP_SIZE) {
        topk_indices[i] = i;
        topk_values[i] = 1.0f / topk;
      }
    }
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

// Launches group_idx_and_topk_idx_kernel with `config`.
// n_group values of 4/8/16/32 use a variant with a compile-time group count;
// any other value uses the generic variant.
// Throws (TORCH_CHECK) if cudaLaunchKernelEx reports an error.
template <typename T, typename IdxT, ScoringFunc SF>
inline void launch_group_idx_and_topk_kernel(
    cudaLaunchConfig_t const& config, T* scores, T* group_scores,
    float* topk_values, IdxT* topk_indices, T const* bias,
    int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
    int64_t const topk, int64_t const num_experts,
    int64_t const num_experts_per_group, bool const renormalize,
    double const routed_scaling_factor) {
  auto launch = [&](auto* kernel_instance2) {
    cudaError_t const err = cudaLaunchKernelEx(
        &config, kernel_instance2, scores, group_scores, topk_values,
        topk_indices, bias, num_tokens, n_group, topk_group, topk, num_experts,
        num_experts_per_group, renormalize, routed_scaling_factor);
    TORCH_CHECK(err == cudaSuccess,
                "group_idx_and_topk_idx_kernel launch failed: ",
                cudaGetErrorString(err));
  };

  switch (n_group) {
    case 4: {
      launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>);
      break;
    }
    case 8: {
      launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>);
      break;
    }
    case 16: {
      launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>);
      break;
    }
    case 32: {
      launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>);
      break;
    }
    default: {
      launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF>);
      break;
    }
  }
}

// Runs grouped top-k routing on `stream` as two kernels.
//
// Stage 1 (topk_with_k2_kernel):
//   writes one score per (token, group) into group_scores.
//
// Stage 2 (group_idx_and_topk_idx_kernel):
//   selects topk_group groups per token, then topk experts within them;
//   writes topk_indices and float32 topk_values.
//
// Both launches use BLOCK_SIZE threads and one warp per work item. If
// enable_pdl is set, programmatic stream serialization is allowed, so the
// kernels can overlap through griddepcontrol.
//
// Throws (TORCH_CHECK) on an unsupported scoring_func or a failed launch.
// An empty batch (num_tokens <= 0) launches nothing.
template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
                   IdxT* topk_indices, T const* bias, int64_t const num_tokens,
                   int64_t const num_experts, int64_t const n_group,
                   int64_t const topk_group, int64_t const topk,
                   bool const renormalize, double const routed_scaling_factor,
                   int const scoring_func, bool enable_pdl = false,
                   cudaStream_t const stream = 0) {
  auto const sf = static_cast<ScoringFunc>(scoring_func);
  // should be guarded by higher level checks.
  TORCH_CHECK(sf == SCORING_NONE || sf == SCORING_SIGMOID,
              "Unsupported scoring_func in invokeNoAuxTc");
  if (num_tokens <= 0) {
    return;
  }

  int64_t const num_experts_per_group = num_experts / n_group;

  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;

  // ---- Stage 1: one warp per (token, group) -------------------------------
  int64_t const num_cases = num_tokens * n_group;
  int64_t const topk_with_k2_num_blocks =
      (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
  cudaLaunchConfig_t config = {};
  config.gridDim = topk_with_k2_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  config.numAttrs = 1;
  config.attrs = attrs;

  auto launch_topk_with_k2 = [&](auto* kernel_instance1) {
    cudaError_t const err =
        cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores,
                           bias, num_tokens, num_cases, n_group,
                           num_experts_per_group);
    TORCH_CHECK(err == cudaSuccess, "topk_with_k2_kernel launch failed: ",
                cudaGetErrorString(err));
  };
  switch (sf) {
    case SCORING_NONE: {
      launch_topk_with_k2(&topk_with_k2_kernel<T, SCORING_NONE>);
      break;
    }
    case SCORING_SIGMOID: {
      launch_topk_with_k2(&topk_with_k2_kernel<T, SCORING_SIGMOID>);
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
  }

  // ---- Stage 2: one warp per token -----------------------------------------
  int64_t const topk_with_k_group_num_blocks =
      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
  size_t const dynamic_smem_in_bytes =
      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                           topk);
  config.gridDim = topk_with_k_group_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = dynamic_smem_in_bytes;
  config.stream = stream;
  config.numAttrs = 1;
  config.attrs = attrs;

  switch (sf) {
    case SCORING_NONE: {
      launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>(
          config, scores, group_scores, topk_values, topk_indices, bias,
          num_tokens, n_group, topk_group, topk, num_experts,
          num_experts_per_group, renormalize, routed_scaling_factor);
      break;
    }
    case SCORING_SIGMOID: {
      launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>(
          config, scores, group_scores, topk_values, topk_indices, bias,
          num_tokens, n_group, topk_group, topk, num_experts,
          num_experts_per_group, renormalize, routed_scaling_factor);
      break;
    }
    default:
      TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
  }

  // Surface launch-configuration errors from stage 2 (cudaLaunchKernelEx
  // records them as the thread's last runtime error).
  cudaError_t const err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
              "group_idx_and_topk_idx_kernel launch failed: ",
              cudaGetErrorString(err));
}

// Explicitly instantiates invokeNoAuxTc for a (score type, index type) pair.
// The declaration must match the template's signature exactly.
#define INSTANTIATE_NOAUX_TC(T, IdxT)                                       \
  template void invokeNoAuxTc<T, IdxT>(                                     \
      T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
      T const* bias, int64_t const num_tokens, int64_t const num_experts,   \
      int64_t const n_group, int64_t const topk_group, int64_t const topk,  \
      bool const renormalize, double const routed_scaling_factor,           \
      int const scoring_func, bool enable_pdl, cudaStream_t const stream);

// Score dtypes dispatched from grouped_topk (fp32, fp16, bf16).
INSTANTIATE_NOAUX_TC(float, int32_t);
INSTANTIATE_NOAUX_TC(half, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
}  // end namespace moe
}  // namespace vllm

// Host entry point for grouped top-k expert routing.
//
// Args:
//   scores        2-D [num_tokens, num_experts] router scores (fp16, bf16 or
//                 fp32); must be a contiguous CUDA tensor.
//   n_group       number of expert groups (1..32); must divide num_experts.
//   topk_group    number of groups kept per token (1..n_group).
//   topk          number of experts selected per token (1..32).
//   renormalize   divide the selected (bias-free) scores by their sum.
//   routed_scaling_factor  multiplier applied to every output weight.
//   bias          per-expert correction bias with num_experts elements. It
//                 must have the same dtype and device as scores and is used
//                 only for ranking.
//   scoring_func  0 = SCORING_NONE, 1 = SCORING_SIGMOID.
//
// Returns (topk_values, topk_indices):
//   topk_values   float32 [num_tokens, topk]
//   topk_indices  int32   [num_tokens, topk]
// Both are allocated on scores' device, and kernels run on that device's
// current stream.
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
    torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
    int64_t topk, bool renormalize, double routed_scaling_factor,
    torch::Tensor const& bias, int64_t scoring_func = 0) {
  // Validate the shape before reading any dimension.
  TORCH_CHECK(scores.dim() == 2, "scores must be a 2D Tensor");
  TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor");
  // The kernels index raw data pointers assuming row-major, dense layout.
  TORCH_CHECK(scores.is_contiguous(), "scores must be contiguous");

  auto const data_type = scores.scalar_type();
  int64_t const num_tokens = scores.size(0);
  int64_t const num_experts = scores.size(1);

  TORCH_CHECK(n_group > 0, "n_group must be positive");
  TORCH_CHECK(num_experts % n_group == 0,
              "num_experts should be divisible by n_group");
  TORCH_CHECK(n_group <= 32,
              "n_group should be smaller than or equal to 32 for now");
  TORCH_CHECK(topk > 0, "topk must be positive");
  TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");
  // topk_group > n_group would make the device-side group selection loop
  // unable to terminate.
  TORCH_CHECK(topk_group > 0 && topk_group <= n_group,
              "topk_group must be in the range [1, n_group]");
  TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE ||
                  scoring_func == vllm::moe::SCORING_SIGMOID,
              "scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)");
  // bias is reinterpreted with the scores' element type below.
  TORCH_CHECK(bias.scalar_type() == data_type,
              "bias must have the same dtype as scores");
  TORCH_CHECK(bias.device() == scores.device(),
              "bias must be on the same device as scores");
  TORCH_CHECK(bias.is_contiguous() && bias.numel() == num_experts,
              "bias must be a contiguous tensor with num_experts elements");

  // Allocate outputs and launch on scores' device, not the current one.
  const c10::cuda::CUDAGuard device_guard(scores.device());
  auto const options = scores.options();

  torch::Tensor group_scores = torch::empty({num_tokens, n_group}, options);
  // Always output float32 for topk_values (eliminates Python-side conversion)
  torch::Tensor topk_values =
      torch::empty({num_tokens, topk}, options.dtype(torch::kFloat32));
  torch::Tensor topk_indices =
      torch::empty({num_tokens, topk}, options.dtype(torch::kInt32));

  auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());

  switch (data_type) {
    case torch::kFloat16:
      // Handle Float16
      vllm::moe::invokeNoAuxTc<half, int32_t>(
          reinterpret_cast<half*>(scores.mutable_data_ptr()),
          reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,
          num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
      break;
    case torch::kFloat32:
      // Handle Float32
      vllm::moe::invokeNoAuxTc<float, int32_t>(
          reinterpret_cast<float*>(scores.mutable_data_ptr()),
          reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<float const*>(bias.data_ptr()), num_tokens,
          num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
      break;
    case torch::kBFloat16:
      // Handle BFloat16
      vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
          reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens,
          num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
      break;
    default:
      // Handle other data types
      throw std::invalid_argument(
          "Invalid dtype, only supports float16, float32, and bfloat16");
      break;
  }
  return {topk_values, topk_indices};
}