"examples/model_compress/pruning/legacy/amc/amc_search.py" did not exist on "92f6754e6e2b76dbe9c26743d3bd2898208c6b04"
// Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/moe/topk_softmax_kernels.cu
// which is originally adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#ifndef USE_ROCM
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#include <cuda/functional>
#else
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
#endif

#include "utils.h"

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Define reduction operators based on CUDA version.
// CUDA 12.9 deprecated cub::Max/Min (removed in CUDA 13) in favor of cuda::maximum/minimum.
#if CUDA_VERSION >= 12090
using MaxReduceOp = cuda::maximum<>;
using MinReduceOp = cuda::minimum<>;
#else
using MaxReduceOp = cub::Max;
using MinReduceOp = cub::Min;
#endif

/// Aligned array type
template <
    typename T,
    /// Number of elements in the array
    int N,
    /// Alignment requirement in bytes
    int Alignment = sizeof(T) * N>
class alignas(Alignment) AlignedArray {
  T data[N];
};

// ========================== Util functions to convert types ==========================
template <typename T>
__device__ float convert_to_float(T x) {
  if constexpr (std::is_same_v<T, __half>) {
    return __half2float(x);
  } else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
    return __bfloat162float(x);
  } else if constexpr (std::is_same_v<T, float>) {
    return x;
  } else {
    return static_cast<float>(x);
  }
}

// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moeSoftmax(
    const T* input,
    const bool* finished,
    float* output,
    const int num_cols,
    const float moe_softcapping,
    const float* correction_bias) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  __shared__ float normalizing_factor;
  __shared__ float float_max;

  const int thread_row_offset = blockIdx.x * num_cols;

  float threadData(-FLT_MAX);

  // Don't touch finished rows.
  if ((finished != nullptr) && finished[blockIdx.x]) {
    return;
  }

  // First pass: Apply transformation, find max, and write transformed values to output
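  // (the output buffer doubles as scratch space, so the softcapping/bias transform is computed once and simply
  // re-read by the later passes)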
  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    float val = convert_to_float<T>(input[idx]);

    // Apply tanh softcapping if enabled
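    // (moe_softcapping * tanhf(val / moe_softcapping) smoothly bounds val to (-moe_softcapping, +moe_softcapping))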
    if (moe_softcapping != 0.0f) {
      val = tanhf(val / moe_softcapping) * moe_softcapping;
    }

    // Apply correction bias if provided
    if (correction_bias != nullptr) {
      val = val + correction_bias[ii];
    }

    output[idx] = val;  // Store transformed value
    threadData = max(val, threadData);
  }

  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());

  if (threadIdx.x == 0) {
    float_max = maxElem;
  }
  __syncthreads();

  // Second pass: Compute sum using transformed values from output
  threadData = 0;
  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    threadData += exp((output[idx] - float_max));
  }

  const auto Z = BlockReduce(tmpStorage).Sum(threadData);

  if (threadIdx.x == 0) {
    normalizing_factor = 1.f / Z;
  }
  __syncthreads();

  // Third pass: Compute final softmax using transformed values from output
  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int idx = thread_row_offset + ii;
    const float softmax_val = exp((output[idx] - float_max)) * normalizing_factor;
    output[idx] = softmax_val;
  }
}

template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(
    const float* inputs_after_softmax,
    const bool* finished,
    float* output,
    int* indices,
    const int num_experts,
    const int k,
    const int start_expert,
    const int end_expert,
    const bool renormalize) {
  using cub_kvp = cub::KeyValuePair<int, float>;
  using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  cub_kvp thread_kvp;
  cub::ArgMax arg_max;

  const int block_row = blockIdx.x;

  const bool row_is_active = finished ? !finished[block_row] : true;
  const int thread_read_offset = blockIdx.x * num_experts;
  float row_sum_for_renormalize = 0;
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
    thread_kvp.value = -1.f;  // This is OK because inputs are probabilities

    cub_kvp inp_kvp;
    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
      const int idx = thread_read_offset + expert;
      inp_kvp.key = expert;
      inp_kvp.value = inputs_after_softmax[idx];
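      // Candidates that already won a previous top-k iteration are knocked out below by overwriting them with
      // the running best, so the same expert is never selected twice.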

      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
        const int prior_winning_expert = indices[k * block_row + prior_k];

        if (prior_winning_expert == expert) {
          inp_kvp = thread_kvp;
        }
      }

      thread_kvp = arg_max(inp_kvp, thread_kvp);
    }

    const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
    if (threadIdx.x == 0) {
      // Ignore experts the node isn't responsible for with expert parallelism
      const int expert = result_kvp.key;
      const bool node_uses_expert = expert >= start_expert && expert < end_expert;
      const bool should_process_row = row_is_active && node_uses_expert;

      const int idx = k * block_row + k_idx;
      output[idx] = result_kvp.value;
      indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
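      // (num_experts acts as a sentinel index meaning the row is not processed here, either because it is
      // already finished or because the winning expert belongs to another expert-parallel rank)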
      assert(indices[idx] >= 0);
      row_sum_for_renormalize += result_kvp.value;
    }
    __syncthreads();
  }

  if (renormalize && threadIdx.x == 0) {
    float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize;
    for (int k_idx = 0; k_idx < k; ++k_idx) {
      const int idx = k * block_row + k_idx;
      output[idx] = output[idx] * row_sum_for_renormalize_inv;
    }
  }
}

// ====================== TopK softmax things ===============================

/*
  A Top-K gating softmax written to exploit the case where the number of experts in the MoE layers
  is a small power of 2. This allows us to cleanly share the rows among the threads in
  a single warp and eliminate communication between warps (so no need to use shared mem).

  It fuses the softmax, max and argmax into a single kernel.

  Limitations:
  1) This implementation is intended for when the number of experts is a small power of 2.
  2) This implementation assumes k is small, but will work for any k.
*/

template <typename T, int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
    const T* input,
    const bool* finished,
    float* output,
    const int num_rows,
    int* indices,
    const int k,
    const int start_expert,
    const int end_expert,
    const bool renormalize,
    const float moe_softcapping,
    const float* correction_bias) {
  // We begin by enforcing compile time assertions and setting up compile time constants.
  static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
  static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
  static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
  static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

  // Number of bytes each thread pulls in per load
  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
  static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
  static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
  static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

  // Restrictions based on previous section.
  static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
  static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
  static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
  static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");

  // We have NUM_EXPERTS elements per row. We specialize for small #experts
  static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
  static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
  static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;

  // Restrictions for previous section.
  static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");

  // ===================== From this point, we finally start computing run-time variables. ========================

  // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
  // Thus, each block processes a chunk of rows. We start by computing the start row for each block.
  const int cta_base_row = blockIdx.x * ROWS_PER_CTA;

  // Now, using the base row per thread block, we compute the base row per warp.
  const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;

  // The threads in a warp are split into sub-groups that will work on a row.
  // We compute row offset for each thread sub-group
  const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
  const int thread_row = warp_base_row + thread_row_in_warp;

  // Threads with indices out of bounds should early exit here.
  if (thread_row >= num_rows) {
    return;
  }
  const bool row_is_active = finished ? !finished[thread_row] : true;

  // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
  // row it will read.
  const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;

  // Now, we compute the group each thread belongs to in order to determine the first column to start loads.
  const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
  const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
  const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

  // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
  // this can support all powers of 2 up to 16.
  // NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
  // We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
  using AccessType = AlignedArray<T, ELTS_PER_LDG>;

  // Finally, we pull in the data from global mem
  T row_chunk_temp[VPT];
  AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk_temp);
  const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
  // Note(Byron): interleaved loads to achieve better memory coalescing
  // | thread[0] | thread[1] | thread[2] | thread[3] | thread[0] | thread[1] | thread[2] | thread[3] | ...
  for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
    row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
  }

  float row_chunk[VPT];
#pragma unroll
  // Note(Byron): upcast logits to float32
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = convert_to_float<T>(row_chunk_temp[ii]);
  }

  // Apply tanh softcapping and correction bias
  if (moe_softcapping != 0.0f || correction_bias != nullptr) {
#pragma unroll
    for (int ii = 0; ii < VPT; ++ii) {
      float val = row_chunk[ii];

      // Apply tanh softcapping if enabled
      if (moe_softcapping != 0.0f) {
        val = tanhf(val / moe_softcapping) * moe_softcapping;
      }

      // Apply correction bias if provided
      if (correction_bias != nullptr) {
        /*
        LDG is interleaved
        |thread0 LDG| |thread1 LDG| |thread0 LDG| |thread1 LDG|
        |--------- group0 --------| |----------group1 --------|
                                      ^ local2
        */
        const int group_id = ii / ELTS_PER_LDG;
        const int local_id = ii % ELTS_PER_LDG;
        const int expert_idx = first_elt_read_by_thread + group_id * THREADS_PER_ROW * ELTS_PER_LDG + local_id;
        val = val + correction_bias[expert_idx];
      }

      row_chunk[ii] = val;
    }
  }

  // First, we perform a max reduce within the thread. The logits were already upcast to float above, so the
  // max and the subsequent exp + sum reduction are all done in float.
  float thread_max = row_chunk[0];
#pragma unroll
  for (int ii = 1; ii < VPT; ++ii) {
    thread_max = max(thread_max, row_chunk[ii]);
  }

  /*********************************/
  /********* Softmax Begin *********/
  /*********************************/

// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
// lane id: 0-31 within a warp
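// Illustration (assuming THREADS_PER_ROW == 4): mask = 2 exchanges lanes 0<->2 and 1<->3, then mask = 1
// exchanges lanes 0<->1 and 2<->3, so after log2(THREADS_PER_ROW) steps every lane holds the row max.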
#pragma unroll
  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
    // butterfly reduce with (lane id ^ mask)
    thread_max = max(thread_max, SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, thread_max, mask, THREADS_PER_ROW));
  }

  // From this point, thread_max in every thread holds the max within the row.
  // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
  float row_sum = 0;
#pragma unroll
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = expf(row_chunk[ii] - thread_max);
    row_sum += row_chunk[ii];
  }

// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
#pragma unroll
  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
    row_sum += SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, row_sum, mask, THREADS_PER_ROW);
  }

  // From this point, all threads have the max and the sum for their rows in the thread_max and row_sum variables
  // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
  // compute the entire softmax row. We could likely look at the maxes and only compute for the top-k values in the
  // row. However, this kernel will likely not be a bottleneck, and it seems better to closely match torch and find
  // the argmax after computing the softmax.
  const float reciprocal_row_sum = 1.f / row_sum;

#pragma unroll
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
  }
  /*******************************/
  /********* Softmax End *********/
  /*******************************/

  // Now, row_chunk contains the softmax of this thread's chunk of the row. Next, we find the top-k elements in
  // each row, along with their indices.
  int start_col = first_elt_read_by_thread;
  static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

  float row_sum_for_renormalize = 0;

  for (int k_idx = 0; k_idx < k; ++k_idx) {
    // First, each thread does the local argmax
    float max_val = row_chunk[0];
    int expert = start_col;
#pragma unroll
    for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) {
#pragma unroll
      for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
        float val = row_chunk[ldg * ELTS_PER_LDG + ii];

        // No check on the experts here since columns with the smallest index are processed first and only
        // updated if > (not >=)
        if (val > max_val) {
          max_val = val;
          expert = col + ii;
        }
      }
    }

// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
// then blank out their max with -inf and the warp can run more iterations...
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
      float other_max = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, max_val, mask, THREADS_PER_ROW);
      int other_expert = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, expert, mask, THREADS_PER_ROW);

      // We want lower indices to "win" in every thread so we break ties this way
      if (other_max > max_val || (other_max == max_val && other_expert < expert)) {
        max_val = other_max;
        expert = other_expert;
      }
    }

    // Write the max for this k iteration to global memory.
    if (thread_group_idx == 0) {
      // Add a guard to ignore experts not included by this node
      const bool node_uses_expert = expert >= start_expert && expert < end_expert;
      const bool should_process_row = row_is_active && node_uses_expert;

      // The lead thread from each sub-group will write out the final results to global memory. (This will be a
      // single thread per row of the input/output matrices.)
      const int idx = k * thread_row + k_idx;
      output[idx] = max_val;
      indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
      row_sum_for_renormalize += max_val;
    }

    // Finally, we clear the value in the thread with the current max if there is another iteration to run.
    if (k_idx + 1 < k) {
      const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
      const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;

      // Only the thread in the group which produced the max will reset the "winning" value to -inf.
      if (thread_group_idx == thread_to_clear_in_group) {
        const int offset_for_expert = expert % ELTS_PER_LDG;
        // Safe to set to any negative value since row_chunk values must be between 0 and 1.
        row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
      }
    }
  }

  // Fuse renormalization of topk_weights into this kernel
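  // (each selected weight is divided by the sum of the k selected weights so that the row's routing weights sum
  // to 1)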
  if (renormalize && thread_group_idx == 0) {
    float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize;
#pragma unroll
    for (int k_idx = 0; k_idx < k; ++k_idx) {
      const int idx = k * thread_row + k_idx;
      output[idx] = output[idx] * row_sum_for_renormalize_inv;
    }
  }
}

namespace detail {
// Constructs some constants needed to partition the work across threads at compile time.
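// Worked example (illustrative, assuming WARP_SIZE == 32): for T = float, EXPERTS = 64, BYTES_PER_LDG = 16,
// this gives ELTS_PER_LDG = 4, VPT = 4, THREADS_PER_ROW = 16, ROWS_PER_WARP = 2, i.e. 16 threads cooperate on
// one row and each warp covers 2 rows.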
template <typename T, int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants {
  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
  static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
  static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
  static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
  static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
  static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
}  // namespace detail

template <typename T, int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(
    const T* input,
    const bool* finished,
    float* output,
    int* indices,
    const int num_rows,
    const int k,
    const int start_expert,
    const int end_expert,
    const bool renormalize,
    const float moe_softcapping,
    const float* correction_bias,
    cudaStream_t stream) {
  static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

  static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
  using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;
  static constexpr int VPT = Constants::VPT;
  static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
  const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
  const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

  dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
  topkGatingSoftmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
      input,
      finished,
      output,
      num_rows,
      indices,
      k,
      start_expert,
      end_expert,
      renormalize,
      moe_softcapping,
      correction_bias);
}

#define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB)             \
  topkGatingSoftmaxLauncherHelper<TYPE, NUM_EXPERTS, WARPS_PER_TB>( \
      gating_output,                                                \
      nullptr,                                                      \
      topk_weights,                                                 \
      topk_indices,                                                 \
      num_tokens,                                                   \
      topk,                                                         \
      0,                                                            \
      num_experts,                                                  \
      renormalize,                                                  \
      moe_softcapping,                                              \
      correction_bias,                                              \
      stream);

template <typename T>
void topkGatingSoftmaxKernelLauncher(
    const T* gating_output,
    float* topk_weights,
    int* topk_indices,
    float* softmax_workspace,
    const int num_tokens,
    const int num_experts,
    const int topk,
    const bool renormalize,
    const float moe_softcapping,
    const float* correction_bias,
    cudaStream_t stream) {
  static constexpr int WARPS_PER_TB = 4;
  switch (num_experts) {
    case 1:
      LAUNCH_SOFTMAX(T, 1, WARPS_PER_TB);
      break;
    case 2:
      LAUNCH_SOFTMAX(T, 2, WARPS_PER_TB);
      break;
    case 4:
      LAUNCH_SOFTMAX(T, 4, WARPS_PER_TB);
      break;
    case 8:
      LAUNCH_SOFTMAX(T, 8, WARPS_PER_TB);
      break;
    case 16:
      LAUNCH_SOFTMAX(T, 16, WARPS_PER_TB);
      break;
    case 32:
      LAUNCH_SOFTMAX(T, 32, WARPS_PER_TB);
      break;
    case 64:
      LAUNCH_SOFTMAX(T, 64, WARPS_PER_TB);
      break;
    case 128:
      LAUNCH_SOFTMAX(T, 128, WARPS_PER_TB);
      break;
    case 256:
      LAUNCH_SOFTMAX(T, 256, WARPS_PER_TB);
      break;
    default: {
      TORCH_CHECK(
          softmax_workspace != nullptr,
          "softmax_workspace must be provided when num_experts is not a power of 2 in [1, 256].");
      static constexpr int TPB = 256;
      moeSoftmax<T, TPB><<<num_tokens, TPB, 0, stream>>>(
          gating_output, nullptr, softmax_workspace, num_experts, moe_softcapping, correction_bias);
      moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
          softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize);
    }
  }
}
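
// Illustrative host-side usage (a sketch, not part of this file's API surface; tensor names are placeholders):
//   auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
//   auto gating_output = torch::randn({num_tokens, num_experts}, opts);
//   auto topk_weights = torch::empty({num_tokens, topk}, opts);
//   auto topk_indices = torch::empty({num_tokens, topk}, opts.dtype(torch::kInt32));
//   topk_softmax(topk_weights, topk_indices, gating_output,
//                /*renormalize=*/true, /*moe_softcapping=*/0.0, /*correction_bias=*/c10::nullopt);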

void topk_softmax(
    torch::Tensor& topk_weights,   // [num_tokens, topk]
    torch::Tensor& topk_indices,   // [num_tokens, topk]
    torch::Tensor& gating_output,  // [num_tokens, num_experts]
    const bool renormalize,
    const double moe_softcapping,
    const c10::optional<torch::Tensor>& correction_bias) {
  // Check data type
  TORCH_CHECK(
      gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half ||
          gating_output.scalar_type() == at::ScalarType::BFloat16,
      "gating_output must be float32, float16, or bfloat16");

  // Check dimensions
  TORCH_CHECK(gating_output.dim() == 2, "gating_output must be 2D tensor [num_tokens, num_experts]");
  TORCH_CHECK(topk_weights.dim() == 2, "topk_weights must be 2D tensor [num_tokens, topk]");
  TORCH_CHECK(topk_indices.dim() == 2, "topk_indices must be 2D tensor [num_tokens, topk]");

  // Check shapes
  TORCH_CHECK(
      gating_output.size(0) == topk_weights.size(0),
      "First dimension of topk_weights must match num_tokens in gating_output");
  TORCH_CHECK(
      gating_output.size(0) == topk_indices.size(0),
      "First dimension of topk_indices must match num_tokens in gating_output");
  TORCH_CHECK(
      topk_weights.size(-1) == topk_indices.size(-1),
      "Second dimension of topk_indices must match topk in topk_weights");
  TORCH_CHECK(topk_weights.size(-1) <= gating_output.size(-1), "topk must be less than or equal to num_experts");

  const int num_experts = static_cast<int>(gating_output.size(-1));
  const int num_tokens = static_cast<int>(gating_output.size(0));
  const int topk = static_cast<int>(topk_weights.size(-1));

  const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
  const bool needs_workspace = !is_pow_2 || num_experts > 256;
  const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
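  // (the fused topkGatingSoftmax kernel only handles power-of-two expert counts up to 256; any other shape falls
  // back to the generic moeSoftmax + moeTopK path, which needs this float workspace)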

  const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  torch::Tensor softmax_workspace =
      torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float));

  const at::ScalarType dtype = gating_output.scalar_type();

  // Validate correction_bias if provided - must always be float32
  const float* bias_ptr = nullptr;
  if (correction_bias.has_value()) {
    const torch::Tensor& bias_tensor = correction_bias.value();
    TORCH_CHECK(bias_tensor.dim() == 1, "correction_bias must be 1D tensor [num_experts]");
    TORCH_CHECK(bias_tensor.size(0) == num_experts, "correction_bias size must match num_experts");
    TORCH_CHECK(
        bias_tensor.scalar_type() == at::ScalarType::Float,
        "correction_bias must be float32, got ",
        bias_tensor.scalar_type());
    bias_ptr = bias_tensor.data_ptr<float>();
  }

  // Cast moe_softcapping from double to float for CUDA kernels
  const float moe_softcapping_f = static_cast<float>(moe_softcapping);

  if (dtype == at::ScalarType::Float) {
    topkGatingSoftmaxKernelLauncher<float>(
        gating_output.data_ptr<float>(),
        topk_weights.data_ptr<float>(),
        topk_indices.data_ptr<int>(),
        softmax_workspace.data_ptr<float>(),
        num_tokens,
        num_experts,
        topk,
        renormalize,
        moe_softcapping_f,
        bias_ptr,
        stream);
  } else if (dtype == at::ScalarType::Half) {
    topkGatingSoftmaxKernelLauncher<__half>(
        reinterpret_cast<const __half*>(gating_output.data_ptr<at::Half>()),
        topk_weights.data_ptr<float>(),
        topk_indices.data_ptr<int>(),
        softmax_workspace.data_ptr<float>(),
        num_tokens,
        num_experts,
        topk,
        renormalize,
        moe_softcapping_f,
        bias_ptr,
        stream);
  } else if (dtype == at::ScalarType::BFloat16) {
    topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
        reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
        topk_weights.data_ptr<float>(),
        topk_indices.data_ptr<int>(),
        softmax_workspace.data_ptr<float>(),
        num_tokens,
        num_experts,
        topk,
        renormalize,
        moe_softcapping_f,
        bias_ptr,
        stream);
  } else {
    TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype);
  }
}