// prepare_moe_input.cu
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>

#include <algorithm>
#include <cstdint>
#include <flashinfer/vec_dtypes.cuh>
#include <iostream>
#include <optional>

#include "cutlass/array.h"
#include "utils.h"

// Cap on the number of threads launched per block by the per-expert kernels
// below; those kernels also use it as their loop stride over topk_ids.
constexpr uint64_t THREADS_PER_EXPERT = 512;

// Counts how many entries of topk_ids select each expert and records the
// per-expert shape triples.
//
// The expert handled by a block is blockIdx.x. Every thread scans topk_ids
// starting at threadIdx.x with a stride of THREADS_PER_EXPERT and adds its
// partial count to atomic_buffer[expert]. The value is accumulated, not
// overwritten, so the count read back is only the number of occurrences if
// the entry was zero on entry (the host launcher in this file passes a
// zero-filled buffer).
//
// After the block barrier, thread 0 writes:
//   problem_sizes1[expert * 3 + {0,1,2}] = {count, 2 * n, k}
//   problem_sizes2[expert * 3 + {0,1,2}] = {count, k, n}
__global__ void compute_problem_sizes(
    const int* __restrict__ topk_ids,
    int32_t* problem_sizes1,
    int32_t* problem_sizes2,
    int32_t* atomic_buffer,
    const int64_t topk_length,
    const int64_t n,
    const int64_t k) {
  const int expert = blockIdx.x;

  // Per-thread partial count of entries routed to this expert.
  int local_hits = 0;
  int idx = threadIdx.x;
  while (idx < topk_length) {
    if (topk_ids[idx] == expert) {
      ++local_hits;
    }
    idx += THREADS_PER_EXPERT;
  }

  atomicAdd(atomic_buffer + expert, local_hits);
  // All threads of the block must have contributed before the total is read.
  __syncthreads();

  if (threadIdx.x != 0) {
    return;
  }

  const int tokens_for_expert = atomic_buffer[expert];
  int32_t* shape1 = problem_sizes1 + expert * 3;
  int32_t* shape2 = problem_sizes2 + expert * 3;

  shape1[0] = tokens_for_expert;
  shape1[1] = static_cast<int32_t>(2 * n);
  shape1[2] = static_cast<int32_t>(k);

  shape2[0] = tokens_for_expert;
  shape2[1] = static_cast<int32_t>(k);
  shape2[2] = static_cast<int32_t>(n);
}

// Builds the exclusive prefix sum of the per-expert token counts stored at
// problem_sizes1[expert * 3].
//
// On return expert_offsets[e] is the sum of the counts of experts 0..e-1 for
// e in [0, num_experts], and atomic_buffer[e] holds the same start offset for
// e in [0, num_experts). The loop is sequential and executed by every
// launched thread; the host code in this file launches it as <<<1, 1>>>.
__global__ void compute_expert_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* atomic_buffer,
    const int64_t num_experts) {
  int32_t running_total = 0;
  int expert = 0;
  for (; expert < num_experts; ++expert) {
    // Start offset of this expert, recorded in both outputs.
    expert_offsets[expert] = running_total;
    atomic_buffer[expert] = running_total;
    running_total += problem_sizes1[expert * 3];
  }
  // Final entry: total number of counted tokens.
  expert_offsets[expert] = running_total;
}

// Variant of the expert-offset scan that additionally produces offsets in
// which every expert's token count is rounded up to a multiple of 128.
//
// On return, for e in [0, num_experts]:
//   expert_offsets[e]     = sum of counts of experts 0..e-1
//   blockscale_offsets[e] = sum of the 128-rounded counts of experts 0..e-1
// and atomic_buffer[e] = expert_offsets[e] for e in [0, num_experts).
// Counts are read from problem_sizes1[expert * 3]. The loop is sequential and
// executed by every launched thread; the host code in this file launches it
// as <<<1, 1>>>.
__global__ void compute_expert_blockscale_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* blockscale_offsets,
    int32_t* atomic_buffer,
    const int64_t num_experts) {
  int32_t packed_total = 0;
  int32_t padded_total = 0;
  int expert = 0;
  for (; expert < num_experts; ++expert) {
    expert_offsets[expert] = packed_total;
    blockscale_offsets[expert] = padded_total;
    atomic_buffer[expert] = packed_total;

    const int tokens = problem_sizes1[expert * 3];
    packed_total += tokens;
    // Round the token count up to the next multiple of 128.
    padded_total += (tokens + (128 - 1)) / 128 * 128;
  }
  expert_offsets[expert] = packed_total;
  blockscale_offsets[expert] = padded_total;
}

// Assigns every entry of topk_ids a destination slot grouped by expert.
//
// The expert handled by a block is blockIdx.x. For each position `pos` whose
// topk_ids value equals that expert, a slot is claimed by atomically
// incrementing atomic_buffer[expert]; then
//   input_permutation[slot] = pos / topk
//   output_permutation[pos] = slot
// atomic_buffer[expert] is therefore expected to hold the expert's first slot
// on entry (the host launcher in this file fills it with the expert start
// offsets beforehand). Which position receives which slot inside an expert's
// range depends on the order in which threads perform the atomicAdd.
__global__ void compute_arg_sorts(
    const int32_t* __restrict__ topk_ids,
    int32_t* input_permutation,
    int32_t* output_permutation,
    int32_t* atomic_buffer,
    const int64_t topk_length,
    const int64_t topk) {
  const int expert = blockIdx.x;

  for (int pos = threadIdx.x; pos < topk_length; pos += THREADS_PER_EXPERT) {
    if (topk_ids[pos] != expert) {
      continue;
    }
    const int slot = atomicAdd(atomic_buffer + expert, 1);
    input_permutation[slot] = pos / topk;
    output_permutation[pos] = slot;
  }
}

// Enqueues the three-stage MoE input preparation on the current CUDA stream
// of topk_ids' device:
//   1. compute_problem_sizes  - per-expert token counts and shape triples
//   2. compute_expert_offsets or compute_expert_blockscale_offsets (the latter
//      when blockscale_offsets is provided) - prefix sums of the counts
//   3. compute_arg_sorts      - input/output permutations grouped by expert
// A zero-filled int32 scratch tensor of num_experts entries is allocated and
// shared by the three kernels (counts in stage 1, start offsets from stage 2,
// slot cursors in stage 3).
//
// All tensors are accessed through raw int32 pointers; the caller is
// responsible for dtype, contiguity and capacity.
//
// Edge cases:
//   * empty topk_ids: the per-expert kernels are still launched (with a single
//     thread) so that problem sizes and offsets are written as zero counts;
//   * num_experts == 0: the per-expert launches are skipped because a grid of
//     zero blocks is not a valid launch; the offset kernel still writes
//     element 0 of the offset tensor(s).
// Launch-configuration errors are reported via C10_CUDA_KERNEL_LAUNCH_CHECK.
void get_moe_prepare_input_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  // Launch on the device that owns the data; the stream below belongs to it.
  const c10::cuda::CUDAGuard device_guard(topk_ids.device());
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

  const int64_t topk_length = topk_ids.numel();
  // At least one thread per block so an empty topk_ids is still a valid launch.
  const int64_t capped_threads = std::min<int64_t>(static_cast<int64_t>(THREADS_PER_EXPERT), topk_length);
  const uint32_t num_threads = static_cast<uint32_t>(std::max<int64_t>(capped_threads, 1));
  const uint32_t num_blocks = static_cast<uint32_t>(num_experts);
  const bool has_experts = num_experts > 0;

  if (has_experts) {
    compute_problem_sizes<<<num_blocks, num_threads, 0, stream>>>(
        static_cast<const int32_t*>(topk_ids.data_ptr()),
        static_cast<int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(problem_sizes2.data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()),
        topk_length,
        n,
        k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }

  // Sequential scan over the experts; overwrites atomic_buffer with each
  // expert's start offset for the permutation stage.
  if (blockscale_offsets.has_value()) {
    compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()),
        num_experts);
  } else {
    compute_expert_offsets<<<1, 1, 0, stream>>>(
        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()),
        num_experts);
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  if (has_experts) {
    compute_arg_sorts<<<num_blocks, num_threads, 0, stream>>>(
        static_cast<const int32_t*>(topk_ids.data_ptr()),
        static_cast<int32_t*>(input_permutation.data_ptr()),
        static_cast<int32_t*>(output_permutation.data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()),
        topk_length,
        topk_ids.size(1));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

// Public entry point for MoE input preparation: validates the arguments and
// forwards them unchanged to get_moe_prepare_input_caller.
//
// The kernels behind the launcher read topk_ids and write every output buffer
// through raw int32 pointers with linear indexing, and use topk_ids.size(1)
// as the top-k width, so the following is enforced here:
//   * topk_ids: CUDA, int32, 2D, contiguous;
//   * expert_offsets (and blockscale_offsets when given): int32, contiguous
//     CUDA tensors with at least num_experts + 1 elements;
//   * problem_sizes1 / problem_sizes2: int32, contiguous CUDA tensors with at
//     least 3 * num_experts elements;
//   * input_permutation / output_permutation: int32, contiguous CUDA tensors
//     with at least topk_ids.numel() elements.
// Violations raise through TORCH_CHECK before any kernel is launched.
void prepare_moe_input(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  TORCH_CHECK(topk_ids.is_cuda(), "topk_ids must be a CUDA tensor");
  TORCH_CHECK(topk_ids.dtype() == torch::kInt32, "topk_ids must be int32");
  TORCH_CHECK(topk_ids.dim() == 2, "topk_ids must be 2D [rows, topk]");
  TORCH_CHECK(topk_ids.is_contiguous(), "topk_ids must be contiguous");
  TORCH_CHECK(num_experts >= 0, "num_experts must be non-negative");

  // Shared validation for every int32 buffer the kernels write through a raw pointer.
  const auto check_int32_buffer = [](const torch::Tensor& buffer, const char* name, int64_t min_numel) {
    TORCH_CHECK(buffer.is_cuda(), name, " must be a CUDA tensor");
    TORCH_CHECK(buffer.scalar_type() == torch::kInt32, name, " must be int32");
    TORCH_CHECK(buffer.is_contiguous(), name, " must be contiguous");
    TORCH_CHECK(
        buffer.numel() >= min_numel, name, " must hold at least ", min_numel, " elements, got ", buffer.numel());
  };

  check_int32_buffer(expert_offsets, "expert_offsets", num_experts + 1);
  if (blockscale_offsets.has_value()) {
    check_int32_buffer(blockscale_offsets.value(), "blockscale_offsets", num_experts + 1);
  }
  check_int32_buffer(problem_sizes1, "problem_sizes1", 3 * num_experts);
  check_int32_buffer(problem_sizes2, "problem_sizes2", 3 * num_experts);
  check_int32_buffer(input_permutation, "input_permutation", topk_ids.numel());
  check_int32_buffer(output_permutation, "output_permutation", topk_ids.numel());

  get_moe_prepare_input_caller(
      topk_ids,
      expert_offsets,
      blockscale_offsets,
      problem_sizes1,
      problem_sizes2,
      input_permutation,
      output_permutation,
      num_experts,
      n,
      k);
}

// Gathers rows: output row r is a copy of input row dst2src_map[r].
//
// Launch layout: the destination row is blockIdx.x; the threads of the block
// cooperate on the columns of that row with a stride of blockDim.x. Blocks
// whose index is not below num_dst_rows do nothing (and do not touch
// dst2src_map).
//
// The bulk of each row is copied in packets of ELEM_PER_THREAD elements
// (128 bits). Columns left over when num_cols is not a multiple of the packet
// width are copied element by element, so the whole row is always
// transferred.
//
// num_src_rows is accepted for interface symmetry but is not used; values in
// dst2src_map are not range-checked.
template <typename T>
__global__ void shuffleRowsKernel(
    const T* input,
    const int32_t* dst2src_map,
    T* output,
    int64_t num_src_rows,
    int64_t num_dst_rows,
    int64_t num_cols) {
  int64_t const dest_row_idx = blockIdx.x;
  if (dest_row_idx >= num_dst_rows) {
    return;
  }
  // Read the mapping only after the bounds check so surplus blocks never index it.
  int64_t const source_row_idx = dst2src_map[dest_row_idx];

  // 128 bits per packet.
  constexpr uint64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
  using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;

  T const* source_row = input + source_row_idx * num_cols;
  T* dest_row = output + dest_row_idx * num_cols;

  auto const* source_row_ptr = reinterpret_cast<DataElem const*>(source_row);
  auto* dest_row_ptr = reinterpret_cast<DataElem*>(dest_row);

  int64_t const elems_per_packet = static_cast<int64_t>(ELEM_PER_THREAD);
  int64_t const num_packets = num_cols / elems_per_packet;
  int64_t const first = threadIdx.x;
  int64_t const step = blockDim.x;

  // Packet-wise copy of the part of the row that is a whole number of packets.
  for (int64_t packet = first; packet < num_packets; packet += step) {
    dest_row_ptr[packet] = source_row_ptr[packet];
  }

  // Tail: columns beyond the last full packet.
  for (int64_t col = num_packets * elems_per_packet + first; col < num_cols; col += step) {
    dest_row[col] = source_row[col];
  }
}

// Explicitly instantiates shuffleRowsKernel<T> for one element type.
// (Previously this expanded to a declaration of a separate, never-defined
// non-template kernel overload instead of referring to the template.)
#define DECLARE_SHUFFLE_ROWS(T)                        \
  template __global__ void shuffleRowsKernel<T>(       \
      const T* input,                                  \
      const int32_t* dst2src_map,                      \
      T* output,                                       \
      int64_t num_src_rows,                            \
      int64_t num_dst_rows,                            \
      int64_t num_cols)

// Element types dispatched by shuffle_rows_caller.
DECLARE_SHUFFLE_ROWS(float);
DECLARE_SHUFFLE_ROWS(half);
DECLARE_SHUFFLE_ROWS(__nv_bfloat16);
DECLARE_SHUFFLE_ROWS(__nv_fp8_e4m3);
DECLARE_SHUFFLE_ROWS(uint8_t);

// Launches shuffleRowsKernel<T>. Relies on the locals of shuffle_rows_caller:
// blocks, threads, stream, input, output, dst2src_map, num_src_rows,
// num_dst_rows and num_cols.
#define SHUFFLE_ROWS(T)                                    \
  shuffleRowsKernel<T><<<blocks, threads, 0, stream>>>(    \
      reinterpret_cast<const T*>(input),                   \
      static_cast<const int32_t*>(dst2src_map.data_ptr()), \
      reinterpret_cast<T*>(output),                        \
      num_src_rows,                                        \
      num_dst_rows,                                        \
      num_cols)

// One `case` of the dtype switch: torch scalar type T -> CUDA element type CUDA_T.
#define DTYPE_DISPATCH_CASE(T, CUDA_T) \
  case T:                              \
    SHUFFLE_ROWS(CUDA_T);              \
    break;

// Host launcher for the row gather output[r, :] = input[dst2src_map[r], :].
//
// One block of 256 threads is launched per output row on the current stream
// of input_tensor's device. Supported dtypes: float16, bfloat16, float32,
// float8_e4m3fn and uint8; any other dtype raises.
//
// The kernel addresses all three tensors through raw pointers with a row
// pitch of input_tensor.size(1), so the tensors must be contiguous, 2D with
// equal widths, and dst2src_map must be int32 with one entry per output row.
// An output with zero rows returns without launching, since a grid of zero
// blocks is not a valid launch.
void shuffle_rows_caller(
    const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
  TORCH_CHECK(
      input_tensor.scalar_type() == output_tensor.scalar_type(),
      "Input and output tensors must have the same data type");
  TORCH_CHECK(input_tensor.dim() == 2, "input_tensor must be 2D [num_src_rows, num_cols]");
  TORCH_CHECK(output_tensor.dim() == 2, "output_tensor must be 2D [num_dst_rows, num_cols]");
  TORCH_CHECK(
      input_tensor.size(1) == output_tensor.size(1), "Input and output tensors must have the same number of columns");
  TORCH_CHECK(input_tensor.is_contiguous(), "input_tensor must be contiguous");
  TORCH_CHECK(output_tensor.is_contiguous(), "output_tensor must be contiguous");
  TORCH_CHECK(dst2src_map.scalar_type() == torch::kInt32, "dst2src_map must be int32");
  TORCH_CHECK(dst2src_map.is_contiguous(), "dst2src_map must be contiguous");
  TORCH_CHECK(
      dst2src_map.numel() >= output_tensor.size(0), "dst2src_map must have one entry per output row");

  int64_t num_dst_rows = output_tensor.size(0);
  int64_t num_src_rows = input_tensor.size(0);
  int64_t num_cols = input_tensor.size(1);
  if (num_dst_rows == 0) {
    return;  // nothing to copy
  }

  // Launch on the device that owns the data.
  const c10::cuda::CUDAGuard device_guard(input_tensor.device());
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  uint32_t blocks = static_cast<uint32_t>(num_dst_rows);
  uint32_t threads = 256;
  const void* input = input_tensor.data_ptr();
  void* output = output_tensor.data_ptr();

  switch (input_tensor.scalar_type()) {
    DTYPE_DISPATCH_CASE(torch::kFloat16, half);
    DTYPE_DISPATCH_CASE(torch::kBFloat16, __nv_bfloat16);
    DTYPE_DISPATCH_CASE(torch::kFloat32, float);
    DTYPE_DISPATCH_CASE(torch::kFloat8_e4m3fn, __nv_fp8_e4m3);
    DTYPE_DISPATCH_CASE(torch::kUInt8, uint8_t);
    default:
      TORCH_CHECK(false, "[moe replicate input] data type dispatch fail!");
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Public entry point for the row gather: forwards its arguments unchanged to
// shuffle_rows_caller, which performs the dtype dispatch and kernel launch.
void shuffle_rows(
    const torch::Tensor& input_tensor,
    const torch::Tensor& dst2src_map,
    torch::Tensor& output_tensor) {
  shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
}

// For every output row i (the row is blockIdx.x) computes
//
//   output[i, :] = sum_{j < topk} factors[i * topk + j] * input[permutation[i * topk + j], :]
//
// with the factor taken as 1 when `factors` is null. Accumulation is done in
// float and the result is converted back to scalar_t on store.
//
// The threads of a block split the columns of the row with a stride of
// blockDim.x. Columns are processed in 16-byte vectors when that is safe,
// i.e. when row_stride is a multiple of the vector width and both base
// pointers are 16-byte aligned, so that every row start is aligned too;
// otherwise the whole row goes through the scalar loop. The scalar loop also
// handles any columns after the last full vector.
//
// Row offsets are computed in 64 bits so that large [rows, row_stride]
// products do not overflow 32-bit arithmetic.
template <typename scalar_t>
__global__ void apply_shuffle_mul_sum_kernel(
    const scalar_t* __restrict__ input_tensor,  // [m * topk, k] (expert-major layout)
    scalar_t* __restrict__ output_tensor,       // [m, k] (token-major layout)
    const int32_t* __restrict__ permutation,    // [m * topk] (c_map: token-major-idx -> expert-major-idx)
    int m,
    int topk,
    int row_stride,
    const scalar_t* __restrict__ factors)  // [m * topk] (topk_weights, token-major layout)
{
  const int i = blockIdx.x;
  if (i >= m) {
    return;
  }

  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  constexpr int kVecElems = static_cast<int>(vec_size);
  constexpr uintptr_t kVecBytes = 16;
  using t = float;
  using vec_t = flashinfer::vec_t<t, vec_size>;
  const int thread_idx = threadIdx.x;
  const int stride = blockDim.x;

  // 64-bit bases: i * topk and i * row_stride may exceed the int range.
  const int64_t token_base = static_cast<int64_t>(i) * topk;
  scalar_t* out_row = output_tensor + static_cast<int64_t>(i) * row_stride;

  // Vector accesses are 16 bytes wide; only use them when every row start is aligned.
  const bool can_vectorize = (row_stride % kVecElems == 0) &&
                             (reinterpret_cast<uintptr_t>(input_tensor) % kVecBytes == 0) &&
                             (reinterpret_cast<uintptr_t>(output_tensor) % kVecBytes == 0);
  const int num_vecs = can_vectorize ? row_stride / kVecElems : 0;

  for (int d_vec_idx = thread_idx; d_vec_idx < num_vecs; d_vec_idx += stride) {
    const int d = d_vec_idx * kVecElems;
    vec_t sum_vec;
    sum_vec.fill(0.0f);

    for (int j = 0; j < topk; ++j) {
      const int64_t token_major_idx = token_base + j;
      const int64_t src_row = permutation[token_major_idx];

      vec_t val_vec;
      val_vec.cast_load(input_tensor + src_row * row_stride + d);

      t factor = 1.0f;
      if (factors != nullptr) {
        factor = factors[token_major_idx];
      }

#pragma unroll
      for (int e = 0; e < kVecElems; ++e) {
        sum_vec[e] += factor * val_vec[e];
      }
    }
    sum_vec.cast_store(out_row + d);
  }

  // Scalar part: the tail after the last full vector, or the entire row when
  // the vector path is disabled.
  const int remainder_start = num_vecs * kVecElems;
  for (int d = remainder_start + thread_idx; d < row_stride; d += stride) {
    t sum_val = 0.0f;
    for (int j = 0; j < topk; ++j) {
      const int64_t token_major_idx = token_base + j;
      const int64_t src_row = permutation[token_major_idx];
      t val = input_tensor[src_row * row_stride + d];

      t factor = 1.0f;
      if (factors != nullptr) {
        factor = factors[token_major_idx];
      }
      sum_val += factor * val;
    }
    out_row[d] = sum_val;
  }
}

// Host launcher for apply_shuffle_mul_sum_kernel.
//
// Derives m = output rows, row_stride = output columns and
// topk = permutation.size(0) / m, validates the arguments and launches one
// block per output row on the current stream of input_tensor's device. The
// block size is the number of 16-byte vectors per row, clamped to [1, 1024].
//
// Requirements (checked):
//   * input_tensor and output_tensor are 2D, contiguous, of the same dtype and
//     the same number of columns;
//   * permutation is 1D, contiguous, int32 (enforced by data_ptr<int32_t>())
//     and has exactly m * topk entries;
//   * factors, when present, has the output dtype, m * topk elements and is
//     contiguous.
// An output with no elements returns immediately without launching.
void get_apply_shuffle_mul_sum_caller(
    const torch::Tensor& input_tensor,                // [m * topk, row_stride], bf16/f16
    torch::Tensor& output_tensor,                     // [m, row_stride], bf16/f16
    const torch::Tensor& permutation,                 // [m * topk], int32
    const std::optional<torch::Tensor>& factors_opt)  // optional [m * topk], bf16/f16
{
  TORCH_CHECK(input_tensor.dim() == 2, "input_tensor must be 2D [m * topk, row_stride]");
  TORCH_CHECK(output_tensor.dim() == 2, "output_tensor must be 2D [m, row_stride]");
  TORCH_CHECK(permutation.dim() == 1, "permutation must be 1D [m * topk]");
  TORCH_CHECK(
      input_tensor.scalar_type() == output_tensor.scalar_type(),
      "input_tensor and output_tensor must have the same data type");
  TORCH_CHECK(
      input_tensor.size(1) == output_tensor.size(1), "input_tensor and output_tensor must have the same row_stride");
  TORCH_CHECK(input_tensor.is_contiguous(), "input_tensor must be contiguous");
  TORCH_CHECK(output_tensor.is_contiguous(), "output_tensor must be contiguous");
  TORCH_CHECK(permutation.is_contiguous(), "permutation must be contiguous");

  // Nothing to write; also avoids dividing by m == 0 and zero-sized launches.
  if (output_tensor.numel() == 0) {
    return;
  }

  const int m = static_cast<int>(output_tensor.size(0));
  const int topk = static_cast<int>(permutation.size(0) / m);
  const int row_stride = static_cast<int>(output_tensor.size(1));

  TORCH_CHECK(
      permutation.size(0) == static_cast<int64_t>(m) * topk, "permutation size must match m * topk");

  // Elements per 16-byte vector, from the tensor's element size (not from the
  // size of the ScalarType enum value).
  const int64_t vec_size = 16 / static_cast<int64_t>(output_tensor.element_size());
  const int64_t vecs_per_row = row_stride / vec_size;
  // At least one thread so narrow rows are still processed by the scalar path.
  const uint32_t threads = static_cast<uint32_t>(std::min<int64_t>(std::max<int64_t>(vecs_per_row, 1), 1024));
  dim3 block(threads);
  dim3 grid(m);  // blockIdx.x = output row (token) index

  // Launch on the device that owns the data.
  const c10::cuda::CUDAGuard device_guard(input_tensor.device());
  auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index());

  const int32_t* perm_ptr = permutation.data_ptr<int32_t>();

  void* factors_ptr = nullptr;
  if (factors_opt.has_value()) {
    TORCH_CHECK(factors_opt->dtype() == output_tensor.dtype(), "Factors must match output dtype");
    TORCH_CHECK(
        factors_opt->numel() == static_cast<int64_t>(m) * topk, "Factors must have shape [m * topk]");
    TORCH_CHECK(factors_opt->is_contiguous(), "Factors must be contiguous");
    factors_ptr = factors_opt->data_ptr();
  }

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(output_tensor.scalar_type(), scalar_t, [&] {
    apply_shuffle_mul_sum_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<const scalar_t*>(input_tensor.data_ptr()),
        static_cast<scalar_t*>(output_tensor.data_ptr()),
        perm_ptr,
        m,
        topk,
        row_stride,
        static_cast<const scalar_t*>(factors_ptr));
    return true;
  });
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

/**
 * @brief Gathers rows through a permutation, optionally scales them, and sums each group of top-k rows.
 *
 * Equivalent PyTorch expression:
 *
 *     (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
 *
 * In index form, for every output row i:
 *
 *     output[i, :] = sum over j in [0, topk) of factors[i * topk + j] * input[permutation[i * topk + j], :]
 *
 * where the factor is 1 when `factors` is not supplied. This function is a thin wrapper: all validation,
 * dtype dispatch and the kernel launch happen in get_apply_shuffle_mul_sum_caller.
 *
 * @param input        Source rows, shape (m * topk, k); corresponds to c2 above.
 * @param output       Destination tensor, shape (m, k); receives the reduced rows.
 * @param permutation  int32 index tensor (c_map above) with m * topk entries; entry i * topk + j names the
 *                     row of `input` that contributes to output row i.
 * @param factors      Optional per-row scale factors (e.g. top-k weights) with m * topk elements, i.e. shape
 *                     (m * topk) or (m, topk).
 */
void apply_shuffle_mul_sum(
    const torch::Tensor& input,
    torch::Tensor& output,
    const torch::Tensor& permutation,
    const std::optional<torch::Tensor>& factors) {
  get_apply_shuffle_mul_sum_caller(input, output, permutation, factors);
}