topk_softmax_kernels.cu 30.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/*
 * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
 * Copyright (c) 2024, The vLLM team.
 * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
19
#include <type_traits>
20
#include <torch/all.h>
21
22
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
23
#include "../cuda_compat.h"
Aidyn-A's avatar
Aidyn-A committed
24
#include "../cub_helpers.h"
25

26
27
28
29
30
31
32
33
34
35
#ifndef USE_ROCM
    #include <cuda_bf16.h>
    #include <cuda_fp16.h>
#else
    #include <hip/hip_bf16.h>
    #include <hip/hip_fp16.h>
    typedef __hip_bfloat16 __nv_bfloat16;
    typedef __hip_bfloat162 __nv_bfloat162;
#endif

36
37
// Constant-expression-friendly max/min used for compile-time launch constants
// (they accept mixed operand types such as int and size_t). Arguments may be
// evaluated twice, so only pass side-effect-free expressions.
#define MAX(a, b) (((a) > (b)) ? (a) : (b))
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
38
39
40
41
42
43
44
45
46
47
48
49

namespace vllm {
namespace moe {

/// Fixed-size array of N elements of type T whose alignment defaults to the
/// array's total byte size. In this file it serves as the element type of a
/// single vectorized global-memory load.
///
/// Template parameters:
///   T         - element type
///   N         - number of elements
///   Alignment - required alignment in bytes (defaults to sizeof(T) * N)
template <typename T, int N, int Alignment = sizeof(T) * N>
struct alignas(Alignment) AlignedArray {
    /// Element storage; inherits the alignment of the enclosing struct.
    T data[N];
};

54
55
56
57
58
59
60
61
62
63
64
/// Converts a gating logit of type T to float.
/// Supported T: float (identity), __nv_bfloat16 and __half (widening conversion).
/// Any other T is rejected at compile time; previously such an instantiation
/// compiled but fell off the end of the function without returning a value.
template <typename T>
__device__ __forceinline__ float toFloat(T value) {
    if constexpr (std::is_same_v<T, float>) {
        return value;
    } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
        return __bfloat162float(value);
    } else if constexpr (std::is_same_v<T, __half>) {
        return __half2float(value);
    } else {
        // Dependent-false condition so the assertion only fires when this
        // branch is actually instantiated.
        static_assert(!std::is_same_v<T, T>, "toFloat: T must be float, __nv_bfloat16, or __half");
    }
}

65
66
67
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
68
/// Row-wise softmax over a [num_rows, num_cols] matrix of gating logits.
///
/// Launch: one block of TPB threads per row (row index = blockIdx.x). The block
/// strides across the row's columns with step TPB.
/// input:    logits, read as row-major with num_cols elements per row
///           (float, __nv_bfloat16 or __half; converted to float).
/// finished: optional per-row flags; rows flagged true are skipped entirely and
///           their output is left untouched. May be nullptr.
/// output:   float probabilities, same layout as input.
template <int TPB, typename InputType>
__launch_bounds__(TPB) __global__
    void moeSoftmax(const InputType* input, const bool* finished, float* output, const int num_cols)
{
    using BlockReduce = cub::BlockReduce<float, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;

    __shared__ float normalizing_factor;
    __shared__ float float_max;

    // Computed in 64 bits: blockIdx.x * num_cols can exceed INT_MAX for large
    // token counts with many experts.
    const int64_t thread_row_offset = static_cast<int64_t>(blockIdx.x) * num_cols;

    float threadData(-FLT_MAX);

    // Don't touch finished rows. The condition is uniform across the block, so
    // returning here cannot strand other threads at a __syncthreads().
    if ((finished != nullptr) && finished[blockIdx.x])
    {
        return;
    }

    // Pass 1: row maximum, for a numerically stable exponent.
    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int64_t idx = thread_row_offset + ii;
        const float val = toFloat(input[idx]);
        threadData = max(val, threadData);
    }

    // The reduced value is only valid in thread 0, so it is broadcast via shared memory.
    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp());
    if (threadIdx.x == 0)
    {
        float_max = maxElem;
    }
    // Also separates the two uses of tmpStorage.
    __syncthreads();

    // Pass 2: sum of exp(x - max).
    threadData = 0;

    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int64_t idx = thread_row_offset + ii;
        const float val = toFloat(input[idx]);
        threadData += expf(val - float_max);
    }

    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp());

    if (threadIdx.x == 0)
    {
        normalizing_factor = 1.f / Z;
    }
    __syncthreads();

    // Pass 3: write normalized probabilities.
    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int64_t idx = thread_row_offset + ii;
        const float val = toFloat(input[idx]);
        const float softmax_val = expf(val - float_max) * normalizing_factor;
        output[idx] = softmax_val;
    }
}

128
129
130
131
132
133
134
135
136
137
/// Block-per-row top-k selection over already-softmaxed probabilities.
///
/// Launch: one block of TPB threads per row; num_rows is taken from gridDim.x.
/// For each row, the k largest probabilities are selected in descending order.
/// Ties go to the lower expert index, matching cub::ArgMax.
///
/// output[row * k + i]      - i-th selected probability (renormalized to sum to 1
///                            over the k picks when `renormalize` is set).
/// indices[row * k + i]     - expert id relative to start_expert, or num_experts
///                            when the row is finished or the expert lies outside
///                            [start_expert, end_expert).
/// source_rows[row * k + i] - i * num_rows + row.
template <int TPB, typename IndType>
__launch_bounds__(TPB) __global__ void moeTopK(
    const float* inputs_after_softmax,
    const bool* finished,
    float* output,
    IndType* indices,
    int* source_rows,
    const int num_experts,
    const int k,
    const int start_expert,
    const int end_expert,
    const bool renormalize)
{
    using cub_kvp = cub::KeyValuePair<int, float>;
    using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;

    // Winner of each k iteration, published by thread 0 and read by all threads in
    // the next iteration. The arrays are double-buffered (slot = k_idx & 1), so
    // thread 0 writing iteration k's winner never races with threads still reading
    // iteration k-1's winner. The __syncthreads() ending each iteration makes the
    // write visible.
    __shared__ float prev_winner_value[2];
    __shared__ int prev_winner_key[2];

    cub::ArgMax arg_max;

    const int num_rows = gridDim.x;
    const int block_row = blockIdx.x;

    const bool row_is_active = finished ? !finished[block_row] : true;
    // 64-bit so block_row * num_experts cannot wrap.
    const int64_t thread_read_offset = static_cast<int64_t>(block_row) * num_experts;
    float selected_sum = 0.f;

    for (int k_idx = 0; k_idx < k; ++k_idx)
    {
        cub_kvp thread_kvp;
        thread_kvp.key = 0;
        thread_kvp.value = -1.f; // This is OK because inputs are probabilities

        // Candidates are ranked by (value descending, expert ascending), the same
        // total order cub::ArgMax uses. Picks come out in that order, so every
        // expert already chosen ranks at or above the previous winner. Skipping
        // exactly those experts removes all prior picks. This does not rely on
        // `indices`, which holds shifted or sentinel values.
        const bool has_prev = k_idx > 0;
        const float prev_value = has_prev ? prev_winner_value[(k_idx - 1) & 1] : 0.f;
        const int prev_key = has_prev ? prev_winner_key[(k_idx - 1) & 1] : 0;

        for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
        {
            const float value = inputs_after_softmax[thread_read_offset + expert];

            if (has_prev && (value > prev_value || (value == prev_value && expert <= prev_key)))
            {
                continue;
            }

            cub_kvp inp_kvp;
            inp_kvp.key = expert;
            inp_kvp.value = value;
            thread_kvp = arg_max(inp_kvp, thread_kvp);
        }

        // Result is only valid in thread 0.
        const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
        if (threadIdx.x == 0)
        {
            // Ignore experts the node isn't responsible for with expert parallelism
            const int expert = result_kvp.key;
            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
            const bool should_process_row = row_is_active && node_uses_expert;
            const int local_expert = should_process_row ? (expert - start_expert) : num_experts;
            assert(local_expert >= 0);

            const int64_t idx = static_cast<int64_t>(k) * block_row + k_idx;
            output[idx] = result_kvp.value;
            indices[idx] = local_expert;
            source_rows[idx] = k_idx * num_rows + block_row;

            if (renormalize) {
                selected_sum += result_kvp.value;
            }

            prev_winner_value[k_idx & 1] = result_kvp.value;
            prev_winner_key[k_idx & 1] = expert;
        }
        // Publishes the winner and allows tmpStorage to be reused.
        __syncthreads();
    }

    // Renormalize the k weights for this row to sum to 1, if requested.
    if (renormalize) {
        if (threadIdx.x == 0) {
            const float denom = selected_sum > 0.f ? selected_sum : 1.f;
            for (int k_idx = 0; k_idx < k; ++k_idx) {
                const int64_t idx = static_cast<int64_t>(k) * block_row + k_idx;
                output[idx] = output[idx] / denom;
            }
        }
    }
}

// ====================== TopK softmax things ===============================

/*
  A Top-K gating softmax written to exploit the case where the number of experts in the MoE
  layers is a small power of 2. This allows us to cleanly share the rows among the threads in
  a single warp and eliminate communication between warps (so there is no need for shared memory).

  It fuses the softmax, max and argmax into a single kernel.

  Limitations:
  1) This implementation is optimized for when the number of experts is a small power of 2.
     It also supports a number of experts that is a multiple of 64, which is still faster
     than computing the softmax and top-k separately (only tested on CUDA so far).
  2) This implementation assumes k is small, but will work for any k.
*/

228
/// Fused softmax + top-k for small/regular expert counts.
///
/// Launch: blockDim = (WARP_SIZE_PARAM, WARPS_PER_CTA). Each row is handled by a
/// sub-group of THREADS_PER_ROW lanes. Each lane holds VPT values of the row in
/// registers, so no shared memory is used. Writes k results per row to
/// output/indices/source_rows at [row * k + i]. When `renormalize` is set, the
/// k weights are rescaled to sum to 1.
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType, typename InputType = float>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
    void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices,
        int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize)
{
    static_assert(std::is_same_v<InputType, float> || std::is_same_v<InputType, __nv_bfloat16> ||
                      std::is_same_v<InputType, __half>,
                  "InputType must be float, __nv_bfloat16, or __half");

    // We begin by enforcing compile time assertions and setting up compile time constants.
    static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
    static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

    // Number of elements each thread pulls in per load
    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
    static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
    static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
    static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

    if constexpr (std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) {
        static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0,
            "ELTS_PER_LDG must be 1 or even for 16-bit conversion");
    }

    // Restrictions based on previous section.
    static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
    static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
    static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
    static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM, "THREADS_PER_ROW can be at most warp size");

    // We have NUM_EXPERTS elements per row. We specialize for small #experts
    static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
    static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
    static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;

    // Restrictions for previous section.
    static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");

    // ===================== From this point, we finally start computing run-time variables. ========================

    // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
    // Thus, each block processes a chunk of rows. We start by computing the start row for each block.
    const int cta_base_row = blockIdx.x * ROWS_PER_CTA;

    // Now, using the base row per thread block, we compute the base row per warp.
    const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;

    // The threads in a warp are split into sub-groups that will work on a row.
    // We compute row offset for each thread sub-group
    const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
    const int thread_row = warp_base_row + thread_row_in_warp;

    // Threads with indices out of bounds should early exit here.
    if (thread_row >= num_rows)
    {
        return;
    }
    const bool row_is_active = finished ? !finished[thread_row] : true;

    // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
    // row it will read. The offset is computed in 64 bits so thread_row * NUM_EXPERTS cannot wrap.
    const InputType* thread_row_ptr = input + static_cast<int64_t>(thread_row) * ELTS_PER_ROW;

    // Now, we compute the group each thread belong to in order to determine the first column to start loads.
    const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
    const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
    const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

    // Finally, we pull in the data from global mem
    float row_chunk[VPT];

    // NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert to float
    if constexpr (std::is_same_v<InputType, float>) {
        using VecType = AlignedArray<float, ELTS_PER_LDG>;
        VecType* row_chunk_vec_ptr = reinterpret_cast<VecType*>(&row_chunk);
        const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
        for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
            row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
        }
    } else if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
        if constexpr (ELTS_PER_LDG >= 2) {
            using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>;
            float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
            const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
            for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
                VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
                int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
                for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
                    row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2(
                        *reinterpret_cast<const __nv_bfloat162*>(vec.data + jj * 2)
                    );
                }
            }
        } else { // ELTS_PER_LDG == 1
#pragma unroll
            for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
                const __nv_bfloat16* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
                row_chunk[ii] = __bfloat162float(*scalar_ptr);
            }
        }
    } else if constexpr (std::is_same_v<InputType, __half>) {
        if constexpr (ELTS_PER_LDG >= 2) {
            using VecType = AlignedArray<__half, ELTS_PER_LDG>;
            float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
            const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
            for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
                VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
                int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
                for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
                    row_chunk_f2[base_idx_f2 + jj] = __half22float2(
                        *reinterpret_cast<const __half2*>(vec.data + jj * 2)
                    );
                }
            }
        } else { // ELTS_PER_LDG == 1
#pragma unroll
            for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
                const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
                row_chunk[ii] = __half2float(*scalar_ptr);
            }
        }
    }

    // First, we perform a max reduce within the thread. All values are already float at this point.
    float thread_max = row_chunk[0];
#pragma unroll
    for (int ii = 1; ii < VPT; ++ii)
    {
        thread_max = max(thread_max, row_chunk[ii]);
    }

// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
    {
        thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
    }

    // From this point, thread max in all the threads have the max within the row.
    // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
    float row_sum = 0;
#pragma unroll
    for (int ii = 0; ii < VPT; ++ii)
    {
        row_chunk[ii] = expf(row_chunk[ii] - thread_max);
        row_sum += row_chunk[ii];
    }

// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
    {
        row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
    }

    // From this point, all threads have the max and the sum for their rows in the thread_max and row_sum variables
    // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
    // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
    // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
    // argmax after computing the softmax.
    const float reciprocal_row_sum = 1.f / row_sum;

#pragma unroll
    for (int ii = 0; ii < VPT; ++ii)
    {
        row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
    }

    // Now, row_chunk contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
    // with the max index.
    int start_col = first_elt_read_by_thread;
    static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

    float selected_sum = 0.f;
    for (int k_idx = 0; k_idx < k; ++k_idx)
    {
        // First, each thread does the local argmax
        float max_val = row_chunk[0];
        int expert = start_col;
#pragma unroll
        for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
        {
#pragma unroll
            for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
            {
                float val = row_chunk[ldg * ELTS_PER_LDG + ii];

                // No check on the experts here since columns with the smallest index are processed first and only
                // updated if > (not >=)
                if (val > max_val)
                {
                    max_val = val;
                    expert = col + ii;
                }
            }
        }

// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
// then blank out their max with -inf and the warp can run more iterations...
#pragma unroll
        for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
        {
            float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
            int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);

            // We want lower indices to "win" in every thread so we break ties this way
            if (other_max > max_val || (other_max == max_val && other_expert < expert))
            {
                max_val = other_max;
                expert = other_expert;
            }
        }

        // Write the max for this k iteration to global memory.
        if (thread_group_idx == 0)
        {
            // Add a guard to ignore experts not included by this node
            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
            const bool should_process_row = row_is_active && node_uses_expert;

            // The lead thread from each sub-group will write out the final results to global memory. (This will be a
            // single) thread per row of the input/output matrices.
            const int64_t idx = static_cast<int64_t>(k) * thread_row + k_idx;
            output[idx] = max_val;
            indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
            source_rows[idx] = k_idx * num_rows + thread_row;

            if (renormalize) {
                selected_sum += max_val;
            }
        }

        // Finally, we clear the value in the thread with the current max if there is another iteration to run.
        if (k_idx + 1 < k)
        {
            const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
            const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;

            // Only the thread in the group which produced the max will reset the "winning" value to -inf.
            if (thread_group_idx == thread_to_clear_in_group)
            {
                const int offset_for_expert = expert % ELTS_PER_LDG;
                // Safe to set to any negative value since row_chunk values must be between 0 and 1.
                row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
            }
        }
    }

    // Renormalize the k weights for this row to sum to 1, if requested.
    if (renormalize) {
        if (thread_group_idx == 0)
        {
            const float denom = selected_sum > 0.f ? selected_sum : 1.f;
            for (int k_idx = 0; k_idx < k; ++k_idx)
            {
                const int64_t idx = static_cast<int64_t>(k) * thread_row + k_idx;
                output[idx] = output[idx] / denom;
            }
        }
    }
}

namespace detail
{
// Compile-time partitioning of one row of EXPERTS gating values across the lanes
// of a warp for topkGatingSoftmax.
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename InputType>
struct TopkConstants
{
    // Elements fetched by a single BYTES_PER_LDG-byte vectorized load.
    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
    // A row must either fit within one warp-wide load or span a whole number of them.
    static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0,
                  "EXPERTS must be smaller than, or a multiple of, ELTS_PER_LDG * WARP_SIZE_PARAM");
    // Vectorized loads each lane issues per row (at least one).
    static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
    // Values of the row held by each lane.
    static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
    // Lanes cooperating on one row.
    static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
    // Rows processed concurrently by one warp.
    static constexpr int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
};
} // namespace detail

511
512
513
514
// Launches topkGatingSoftmax for a compile-time expert count on `stream`.
// Chooses the widest load size not exceeding MAX_BYTES_PER_LDG or one full row,
// then sizes the grid so that WARPS_PER_TB warps per block cover all num_rows rows.
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType, typename InputType>
void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices,
    int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize,
    cudaStream_t stream)
{
    constexpr int kBytesPerLdg = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS);
    using Constants = detail::TopkConstants<EXPERTS, kBytesPerLdg, WARP_SIZE_PARAM, InputType>;
    constexpr int kRowsPerWarp = Constants::ROWS_PER_WARP;

    // Ceil-divide rows into warps, then warps into blocks.
    const int warps_needed = (num_rows + kRowsPerWarp - 1) / kRowsPerWarp;
    const int grid_size = (warps_needed + WARPS_PER_TB - 1) / WARPS_PER_TB;

    const dim3 block_shape(WARP_SIZE_PARAM, WARPS_PER_TB);
    topkGatingSoftmax<Constants::VPT, EXPERTS, WARPS_PER_TB, kBytesPerLdg, WARP_SIZE_PARAM, IndType, InputType>
        <<<grid_size, block_shape, 0, stream>>>(
            input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize);
}

528
// LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) launches the fused
// topkGatingSoftmax path for a compile-time expert count. It expects the
// enclosing scope to define gating_output, topk_weights, topk_indices,
// token_expert_indices, num_tokens, topk, num_experts, renormalize and stream.
// On ROCm an unsupported warp size raises an error instead of a host assert(),
// which is compiled out under NDEBUG and would silently skip the launch.
#ifndef USE_ROCM
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES)                          \
    static_assert(WARP_SIZE == 32,                                                    \
                  "Unsupported warp size. Only 32 is supported for CUDA");            \
    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
        gating_output, nullptr, topk_weights, topk_indices, token_expert_indices,     \
        num_tokens, topk, 0, num_experts, renormalize, stream);
#else
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES)                                \
    if (WARP_SIZE == 64) {                                                                  \
        topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>(          \
            gating_output, nullptr, topk_weights, topk_indices, token_expert_indices,       \
            num_tokens, topk, 0, num_experts, renormalize, stream);                         \
    } else if (WARP_SIZE == 32) {                                                           \
        topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>(          \
            gating_output, nullptr, topk_weights, topk_indices, token_expert_indices,       \
            num_tokens, topk, 0, num_experts, renormalize, stream);                         \
    } else {                                                                                \
        TORCH_CHECK(false, "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
    }
#endif
549

550
// Selects and launches a top-k gating softmax implementation on `stream`:
// - power-of-two expert counts up to 512 (and, on CUDA only, 192/320/384/448/576)
//   use the fused topkGatingSoftmax kernel;
// - any other count falls back to moeSoftmax + moeTopK, which requires
//   softmax_workspace to hold num_tokens * num_experts floats.
template <typename IndType, typename InputType>
void topkGatingSoftmaxKernelLauncher(
    const InputType* gating_output,
    float* topk_weights,
    IndType* topk_indices,
    int* token_expert_indices,
    float* softmax_workspace,
    const int num_tokens,
    const int num_experts,
    const int topk,
    const bool renormalize,
    cudaStream_t stream) {
    // A launch with a zero-sized grid is an invalid configuration. The fallback
    // path would also demand a workspace even though there is nothing to
    // compute. An empty batch therefore does no work.
    if (num_tokens <= 0) {
        return;
    }

    static constexpr int WARPS_PER_TB = 4;
    static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
#ifndef USE_ROCM
    // for 16-bit dtypes (bfloat16/half), we need 4 bytes loading to make sure
    // num_experts elements can be loaded by a warp
    static constexpr int BYTES_PER_LDG_MULTIPLE_64 =
    (std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) ? 4 : 8;
#endif
    switch (num_experts) {
        case 1:
            LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
            break;
        case 2:
            LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
            break;
        case 4:
            LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
            break;
        case 8:
            LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
            break;
        case 16:
            LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
            break;
        case 32:
            LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
            break;
        case 64:
            LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
            break;
        case 128:
            LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
            break;
        case 256:
            LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
            break;
        case 512:
            LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
            break;
        // (CUDA only) support multiples of 64 when num_experts is not power of 2.
        // ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of num_experts,
        // alternatively we can test 4 bytes loading and enable it in future.
#ifndef USE_ROCM
        case 192:
            LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
            break;
        case 320:
            LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
            break;
        case 384:
            LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
            break;
        case 448:
            LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
            break;
        case 576:
            LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
            break;
#endif
        default: {
            TORCH_CHECK(softmax_workspace != nullptr,
                "softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
            static constexpr int TPB = 256;
            // One block per token for both fallback kernels.
            moeSoftmax<TPB, InputType><<<num_tokens, TPB, 0, stream>>>(
                gating_output, nullptr, softmax_workspace, num_experts);
            moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
                softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices,
                num_experts, topk, 0, num_experts, renormalize);
        }
    }
}

} // namespace moe
} // namespace vllm

637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674

// Dispatches topkGatingSoftmaxKernelLauncher on the dtype of topk_indices
// (int32, uint32 or int64). The gating logits are reinterpreted as ComputeType
// without conversion. Raises for any other index dtype.
template<typename ComputeType>
void dispatch_topk_softmax_launch(
    torch::Tensor& gating_output,
    torch::Tensor& topk_weights,
    torch::Tensor& topk_indices,
    torch::Tensor& token_expert_indices,
    torch::Tensor& softmax_workspace,
    int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream)
{
    // Single launch body shared by all index dtypes; only IndType differs.
    auto launch = [&](auto* indices_ptr) {
        using IndType = std::remove_pointer_t<decltype(indices_ptr)>;
        vllm::moe::topkGatingSoftmaxKernelLauncher<IndType, ComputeType>(
            reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
            topk_weights.data_ptr<float>(),
            indices_ptr,
            token_expert_indices.data_ptr<int>(),
            softmax_workspace.data_ptr<float>(),
            num_tokens, num_experts, topk, renormalize, stream);
    };

    switch (topk_indices.scalar_type()) {
        case at::ScalarType::Int:
            launch(topk_indices.data_ptr<int>());
            break;
        case at::ScalarType::UInt32:
            launch(topk_indices.data_ptr<uint32_t>());
            break;
        case at::ScalarType::Long:
            launch(topk_indices.data_ptr<int64_t>());
            break;
        default:
            TORCH_CHECK(false, "Unsupported topk_indices dtype: ", topk_indices.scalar_type(),
                        " (expected int32, uint32 or int64)");
    }
}

675
676
677
678
// Computes softmax over the last dimension of gating_output and writes, per token,
// the top-k probabilities (topk_weights), expert ids (topk_indices) and source-row
// indices (token_expert_indices). When `renormalize` is set, the selected weights
// are rescaled to sum to 1. Runs on the current CUDA stream of gating_output's device.
void topk_softmax(
    torch::Tensor& topk_weights,                // [num_tokens, topk]
    torch::Tensor& topk_indices,                // [num_tokens, topk]
    torch::Tensor& token_expert_indices,        // [num_tokens, topk]
    torch::Tensor& gating_output,               // [num_tokens, num_experts]
    bool renormalize)
{
    const int num_experts = gating_output.size(-1);
    // Guards the division below; the kernels also index gating_output as a dense
    // row-major buffer, so a strided view would be read incorrectly.
    TORCH_CHECK(num_experts > 0, "gating_output must have a non-empty last dimension");
    TORCH_CHECK(gating_output.is_contiguous(), "gating_output must be contiguous");
    const auto num_tokens = gating_output.numel() / num_experts;
    const int topk = topk_weights.size(-1);
    // Selecting more experts than exist would make the kernels emit repeated or
    // placeholder entries.
    TORCH_CHECK(topk <= num_experts,
                "topk (", topk, ") must not exceed num_experts (", num_experts, ")");

    // Nothing to compute; also avoids launching kernels with an empty grid.
    if (num_tokens == 0) {
        return;
    }

    // Power-of-two expert counts up to 512 always take the fused kernel and never
    // touch the workspace. Other counts may fall back to moeSoftmax + moeTopK,
    // which need num_tokens * num_experts floats of scratch space.
    const bool is_pow_2 = (num_experts & (num_experts - 1)) == 0;
    const bool needs_workspace = !is_pow_2 || num_experts > 512;
    const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float);
    torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options);

    if (gating_output.scalar_type() == at::ScalarType::Float) {
        dispatch_topk_softmax_launch<float>(gating_output, topk_weights, topk_indices,
            token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
    } else if (gating_output.scalar_type() == at::ScalarType::Half) {
        dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices,
            token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
    } else if (gating_output.scalar_type() == at::ScalarType::BFloat16) {
        dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices,
            token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
    } else {
        TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type());
    }
}