moe_align_sum_kernels.cu 17.5 KB
Newer Older
1
#include <torch/all.h>
2
#include <ATen/cuda/CUDAContext.h>
3
#include <c10/cuda/CUDAGuard.h>
4
5

#include <ATen/ATen.h>
6
#include <ATen/cuda/Atomic.cuh>
7

8
9
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
10

11
// Integer ceiling division: the quotient x / y rounded up to the next whole
// number (intended for non-negative x and positive y). This is a function-like
// macro and `y` appears twice in the expansion, so do not pass an argument
// with side effects for `y`.
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
12
13

namespace vllm {
14
namespace moe {
15
16

namespace {

/**
 * Flattens a (row, col) coordinate of a row-major 2D table into a linear
 * element offset.
 *
 * @param total_col Number of columns in one row of the table.
 * @param row       Row coordinate.
 * @param col       Column coordinate.
 * @return row * total_col + col, evaluated in 32-bit arithmetic.
 */
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  // don't worry about overflow because num_experts is relatively small
  return row * total_col + col;
}

}  // namespace
23

24
/**
 * Groups token indices by expert id and pads every expert's group to a
 * multiple of `block_size`, using dynamic shared memory as scratch space.
 *
 * Launch requirements (taken from how the body uses the launch geometry):
 *  - All cooperation goes through __syncthreads() and only threadIdx.x is
 *    used for work assignment, so the kernel is meant to run as one block.
 *  - blockDim.x must be >= num_experts: the per-expert phases are executed by
 *    the threads with threadIdx.x < num_experts.
 *  - Dynamic shared memory must hold, back to back:
 *      int32_t      cumsum[num_experts + 1]
 *      token_cnts_t tokens_cnts[blockDim.x + 1][num_experts]
 *
 * @tparam scalar_t     Element type of topk_ids.
 * @tparam token_cnts_t Counter type of the tokens_cnts table.
 *
 * @param topk_ids              `numel` expert ids. Each value is used
 *                              unchecked as a column index, so it must lie in
 *                              [0, num_experts).
 * @param sorted_token_ids      Output. Receives the index of each token at
 *                              its position inside its expert's padded
 *                              region; slots that get no token are not
 *                              written here.
 * @param expert_ids            Output. expert_ids[b] is the expert owning the
 *                              b-th group of `block_size` slots.
 * @param total_tokens_post_pad Output. Total slot count after padding.
 * @param num_experts           Number of experts (table width).
 * @param block_size            Padding granularity; must be non-zero.
 * @param numel                 Number of entries in topk_ids.
 */
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                            int32_t* sorted_token_ids,
                                            int32_t* expert_ids,
                                            int32_t* total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size, size_t numel) {
  // Each thread owns the contiguous shard [shard_begin, shard_end) of
  // topk_ids. The shard is empty when shard_begin is already past numel.
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t shard_begin = threadIdx.x * tokens_per_thread;
  const size_t shard_limit = shard_begin + tokens_per_thread;
  const size_t shard_end = shard_limit < numel ? shard_limit : numel;

  const int32_t tid = static_cast<int32_t>(threadIdx.x);
  const int32_t num_threads = static_cast<int32_t>(blockDim.x);

  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1)
  // 2d tensor with shape (blockDim.x + 1, num_experts), placed right after
  // cumsum.
  token_cnts_t* tokens_cnts =
      reinterpret_cast<token_cnts_t*>(shared_mem + num_experts + 1);

  // Clear this thread's own row (row tid + 1) before counting into it.
  for (int32_t e = 0; e < num_experts; ++e) {
    tokens_cnts[index(num_experts, tid + 1, e)] = 0;
  }

  /**
   * Step 1: tokens_cnts[tid + 1][e] counts how many tokens of this thread's
   * shard are assigned to expert e.
   */
  for (size_t i = shard_begin; i < shard_end; ++i) {
    ++tokens_cnts[index(num_experts, tid + 1, topk_ids[i])];
  }

  __syncthreads();

  // Step 2: one thread per expert turns that expert's column into a running
  // sum over the rows. Afterwards row r holds the number of tokens found in
  // the shards of threads 0..r-1 and the last row holds the expert's total.
  if (tid < num_experts) {
    tokens_cnts[index(num_experts, 0, tid)] = 0;
    for (int32_t row = 1; row <= num_threads; ++row) {
      tokens_cnts[index(num_experts, row, tid)] +=
          tokens_cnts[index(num_experts, row - 1, tid)];
    }
  }

  __syncthreads();

  // Step 3: thread 0 accumulates the per-expert totals, each rounded up to a
  // multiple of block_size, into cumsum. cumsum[e] is the first slot that
  // belongs to expert e.
  if (tid == 0) {
    cumsum[0] = 0;
    for (int32_t e = 1; e <= num_experts; ++e) {
      cumsum[e] = cumsum[e - 1] +
                  CEILDIV(tokens_cnts[index(num_experts, num_threads, e - 1)],
                          block_size) *
                      block_size;
    }
    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  /**
   * Step 4: for each expert, the matching thread walks the expert's padded
   * region one block at a time and records the expert id for every block.
   */
  if (tid < num_experts) {
    for (int32_t pos = cumsum[tid]; pos < cumsum[tid + 1]; pos += block_size) {
      expert_ids[pos / block_size] = tid;
    }
  }

  /**
   * Step 5: each thread places the tokens of its shard. Given the example
   * topk_ids = [0,1,2,1,2,3,0,3,4] and block_size = 4, the output would be
   * [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where *
   * represents a padding value (preset in python).
   */
  for (size_t i = shard_begin; i < shard_end; ++i) {
    const int32_t expert_id = topk_ids[i];
    // cumsum[expert_id] is the first slot of the expert's region, and
    // tokens_cnts[tid][expert_id] is the next free offset inside that region
    // for this thread. Row tid is touched only by thread tid in this phase,
    // so it can be used as a private cursor and bumped after every store.
    const int32_t slot = index(num_experts, tid, expert_id);
    const int32_t rank_post_pad = tokens_cnts[slot] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = static_cast<int32_t>(i);
    ++tokens_cnts[slot];
  }
}
113

Simon Mo's avatar
Simon Mo committed
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
// TODO(simon): this is temporarily adapted from
// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
// we did this to unblock Deepseek V3 but there should be a better
// implementation to manage shared memory.
/**
 * Same token-by-expert grouping as the shared-memory kernel, but with the
 * scratch tables supplied by the caller as plain pointers.
 *
 * Only threadIdx.x is used to assign work and all phases are separated by
 * __syncthreads(), so the kernel is meant to be launched as a single block
 * with blockDim.x >= num_experts.
 *
 * @param topk_ids              `numel` expert ids, used unchecked as column
 *                              indices.
 * @param sorted_token_ids      Output: token indices laid out per expert.
 * @param expert_ids            Output: owning expert of every block of
 *                              `block_size` slots.
 * @param total_tokens_post_pad Output: total slot count after padding.
 * @param num_experts           Number of experts (table width).
 * @param block_size            Padding granularity.
 * @param numel                 Number of entries in topk_ids.
 * @param tokens_cnts           Scratch table indexed as
 *                              [blockDim.x + 1][num_experts].
 * @param cumsum                Scratch array indexed as [num_experts + 1].
 */
template <typename scalar_t>
__global__ void moe_align_block_size_global_mem_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
  // Contiguous slice of topk_ids handled by this thread.
  const size_t shard_len = CEILDIV(numel, blockDim.x);
  const size_t shard_first = threadIdx.x * shard_len;
  const size_t shard_limit = shard_first + shard_len;

  // Row (threadIdx.x + 1) is this thread's private histogram: entry e counts
  // the tokens of the thread's shard that are routed to expert e.
  int32_t* own_hist = tokens_cnts + index(num_experts, threadIdx.x + 1, 0);
  for (int e = 0; e < num_experts; ++e) {
    own_hist[e] = 0;
  }
  for (int tok = shard_first; tok < numel && tok < shard_limit; ++tok) {
    own_hist[static_cast<int32_t>(topk_ids[tok])] += 1;
  }

  __syncthreads();

  // One thread per expert: replace the expert's column by its running sum
  // over the rows, with row 0 fixed at zero. The last row ends up holding the
  // expert's total token count.
  if (threadIdx.x < num_experts) {
    const int32_t expert = threadIdx.x;
    int32_t running = 0;
    tokens_cnts[index(num_experts, 0, expert)] = 0;
    for (int row = 1; row <= blockDim.x; ++row) {
      running += tokens_cnts[index(num_experts, row, expert)];
      tokens_cnts[index(num_experts, row, expert)] = running;
    }
  }

  __syncthreads();

  // Thread 0 builds the padded start offset of every expert.
  if (threadIdx.x == 0) {
    int32_t padded_total = 0;
    cumsum[0] = padded_total;
    for (int e = 0; e < num_experts; ++e) {
      const int32_t count = tokens_cnts[index(num_experts, blockDim.x, e)];
      padded_total += CEILDIV(count, block_size) * block_size;
      cumsum[e + 1] = padded_total;
    }
    *total_tokens_post_pad = padded_total;
  }

  __syncthreads();

  // One thread per expert: tag each block of the expert's padded region with
  // the expert id.
  if (threadIdx.x < num_experts) {
    const int32_t region_end = cumsum[threadIdx.x + 1];
    for (int pos = cumsum[threadIdx.x]; pos < region_end; pos += block_size) {
      expert_ids[pos / block_size] = threadIdx.x;
    }
  }

  // Scatter the shard's tokens. For topk_ids = [0,1,2,1,2,3,0,3,4] and
  // block_size = 4 the result is [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *,
  // *, 8, *, *, *], where * is a padding value (preset in python).
  //
  // Row threadIdx.x holds, per expert, how many tokens precede this thread's
  // shard; it is used as a cursor and advanced after each store.
  int32_t* cursor = tokens_cnts + index(num_experts, threadIdx.x, 0);
  for (int tok = shard_first; tok < numel && tok < shard_limit; ++tok) {
    const int32_t expert_id = topk_ids[tok];
    const int32_t rank_post_pad = cursor[expert_id] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = tok;
    ++cursor[expert_id];
  }
}

200
// taken from
201
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
202
203
204
205
206
207
208
/**
 * Counts the tokens routed to each expert, derives the padded start offset of
 * every expert and tags each block of `block_size` slots with its expert.
 *
 * Requirements implied by the body:
 *  - Only threadIdx.x is used and phases are separated by __syncthreads(), so
 *    the kernel is meant to run as a single block.
 *  - The counters live in shared_counts[32][8], indexed as
 *    [expert / 8][expert % 8]; that covers at most 256 experts.
 *  - Row g of the counters is zeroed by the threads whose threadIdx.x / 32
 *    equals g, so blockDim.x must provide a group of 32 threads for every row
 *    in use (blockDim.x >= 32 * ceil(num_experts / 8), at most 1024).
 *  - blockDim.x must be >= num_experts for the final per-expert phase.
 *
 * @param topk_ids              `numel` expert ids, used unchecked as indices.
 * @param sorted_token_ids      Not referenced by this kernel.
 * @param expert_ids            Output: owning expert of every block.
 * @param total_tokens_post_pad Output: total slot count after padding.
 * @param num_experts           Number of experts.
 * @param block_size            Padding granularity; must be non-zero.
 * @param numel                 Number of entries in topk_ids.
 * @param cumsum                Output: cumsum[e] is the first slot of expert
 *                              e, cumsum[num_experts] the padded total.
 */
template <typename scalar_t>
__global__ void sgl_moe_align_block_size_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t* cumsum) {
  __shared__ int32_t shared_counts[32][8];

  const int warp_id = threadIdx.x / 32;
  const int experts_per_warp = 8;
  const int my_expert_start = warp_id * experts_per_warp;

  // Initialize shared_counts for this warp's experts
  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < num_experts) {
      shared_counts[warp_id][i] = 0;
    }
  }

  __syncthreads();

  // Each thread histograms the contiguous shard [shard_begin, shard_end) of
  // topk_ids. The shard is empty when shard_begin is already past numel.
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t shard_begin = threadIdx.x * tokens_per_thread;
  const size_t shard_limit = shard_begin + tokens_per_thread;
  const size_t shard_end = shard_limit < numel ? shard_limit : numel;

  for (size_t i = shard_begin; i < shard_end; ++i) {
    const int expert_id = topk_ids[i];
    const int warp_idx = expert_id / experts_per_warp;
    const int expert_offset = expert_id % experts_per_warp;
    // Several threads may hit the same counter, hence the atomic.
    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
  }

  __syncthreads();

  // Single thread computes cumulative sum and total tokens
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      const int warp_idx = (i - 1) / experts_per_warp;
      const int expert_offset = (i - 1) % experts_per_warp;
      const int expert_count = shared_counts[warp_idx][expert_offset];

      // Each expert's region is rounded up to a multiple of block_size.
      cumsum[i] =
          cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  // Assign expert IDs to blocks
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }
}
259

260
261
262
263
264
265
266
267
268
269
270
// taken from
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
/**
 * Scatters token indices into their expert's region of sorted_token_ids.
 *
 * Works for any 1D grid: every thread starts at its global index and advances
 * by the total number of launched threads (grid-stride loop).
 *
 * @param topk_ids         `numel` expert ids, used unchecked as indices into
 *                         cumsum_buffer.
 * @param sorted_token_ids Output. Slot handed out for token i receives i.
 * @param cumsum_buffer    In/out. Entry e is read as the next free slot of
 *                         expert e and atomically incremented for every token
 *                         placed, so its contents are consumed by this kernel.
 *                         Slots are handed out in the order the atomics
 *                         execute, so the order of tokens inside one expert's
 *                         region is presumably not reproducible across runs.
 * @param numel            Number of entries in topk_ids.
 */
template <typename scalar_t>
__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids,
                                          int32_t* sorted_token_ids,
                                          int32_t* cumsum_buffer,
                                          size_t numel) {
  // Widen before multiplying so the products are formed in size_t rather than
  // in 32-bit unsigned arithmetic.
  const size_t tid =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (size_t i = tid; i < numel; i += stride) {
    const int32_t expert_id = topk_ids[i];
    // atomicAdd returns the value before the increment, i.e. the slot that is
    // now reserved for this token.
    const int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
    sorted_token_ids[rank_post_pad] = static_cast<int32_t>(i);
  }
}

277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
/**
 * Sums the TOPK rows that belong to one token.
 *
 * Block blockIdx.x handles token blockIdx.x; the threads of the block step
 * through the d columns with a stride of blockDim.x.
 *
 * @tparam scalar_t Element type; also the type the sum is accumulated in.
 * @tparam TOPK     Number of rows summed per token (compile-time constant).
 * @param out       Output, indexed as [token][d].
 * @param input     Input, indexed as [token][TOPK][d].
 * @param d         Number of columns per row.
 */
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., topk, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  // Base pointers of this token's [TOPK, d] input slab and [d] output row.
  const scalar_t* token_in = input + token_idx * TOPK * d;
  scalar_t* token_out = out + token_idx * d;

  int64_t col = threadIdx.x;
  while (col < d) {
    scalar_t acc = 0.0;
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {
      acc += VLLM_LDG(&token_in[k * d + col]);
    }
    token_out[col] = acc;
    col += blockDim.x;
  }
}

}  // namespace moe
294
295
}  // namespace vllm

296
297
/**
 * Host entry point: groups the token indices of `topk_ids` by expert and pads
 * each group to a multiple of `block_size`.
 *
 * One of three single-block kernels is launched on the current stream of
 * topk_ids' device, depending on how much opt-in shared memory the device
 * offers:
 *  1. int32 counters in shared memory, if they fit;
 *  2. uint16 counters in shared memory, if those fit and topk_ids has at most
 *     65535 elements (so no counter can overflow);
 *  3. otherwise counters in freshly allocated global-memory scratch tensors.
 *
 * @param topk_ids            Integral tensor of expert ids.
 * @param num_experts         Number of experts; must be positive.
 * @param block_size          Padding granularity; must be positive.
 * @param sorted_token_ids    int32 output tensor (token indices per expert).
 * @param experts_ids         int32 output tensor (expert of each block).
 * @param num_tokens_post_pad int32 output tensor (padded total, 1 element
 *                            is written).
 */
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad) {
  TORCH_CHECK(num_experts > 0, "moe_align_block_size: num_experts must be > 0");
  // The kernels divide by block_size.
  TORCH_CHECK(block_size > 0, "moe_align_block_size: block_size must be > 0");

  // Make topk_ids' device current so that the stream lookup, the scratch
  // allocations and the launches below all target that device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int device_max_shared_mem;
  auto dev = topk_ids.get_device();
  AT_CUDA_CHECK(cudaDeviceGetAttribute(
      &device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));

  // One thread per expert, but never less than a full warp.
  const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);

  // Bytes of dynamic shared memory for the (num_thread + 1) x num_experts
  // counter table plus the (num_experts + 1) int32 cumsum array, with int32
  // and with uint16 counters respectively.
  const int32_t shared_mem_i32 =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
  const int32_t shared_mem_i16 =
      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
      (num_experts + 1) * sizeof(int32_t);

  const bool fits_i32 = shared_mem_i32 < device_max_shared_mem;
  // when nelements of topk_ids is at most 65535 (max value of uint16), every
  // element of token_cnts is also at most 65535, so uint16 counters suffice.
  const bool use_i16 = !fits_i32 && shared_mem_i16 < device_max_shared_mem &&
                       topk_ids.numel() <= 65535;
  const bool use_global_memory = !fits_i32 && !use_i16;

  if (use_global_memory) {
    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
          auto options_int = torch::TensorOptions()
                                 .dtype(torch::kInt)
                                 .device(topk_ids.device());
          // The kernel indexes the counter table as
          // [blockDim.x + 1][num_experts] with blockDim.x == num_thread, so
          // the row count has to come from num_thread, not from num_experts.
          torch::Tensor token_cnts_buffer =
              torch::empty({(num_thread + 1) * num_experts}, options_int);
          torch::Tensor cumsum_buffer =
              torch::empty({num_experts + 1}, options_int);

          auto kernel =
              vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
          kernel<<<1, num_thread, 0, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
              topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
              cumsum_buffer.data_ptr<int32_t>());
          // A launch does not report errors by itself.
          AT_CUDA_CHECK(cudaGetLastError());
        });
  } else if (use_i16) {
    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          // set dynamic shared mem
          auto kernel =
              vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
              (void*)kernel, shared_mem_i16));
          kernel<<<1, num_thread, shared_mem_i16, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
              topk_ids.numel());
          AT_CUDA_CHECK(cudaGetLastError());
        });
  } else {
    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          // set dynamic shared mem
          auto kernel =
              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
              (void*)kernel, shared_mem_i32));
          kernel<<<1, num_thread, shared_mem_i32, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
              topk_ids.numel());
          AT_CUDA_CHECK(cudaGetLastError());
        });
  }
}
384

385
386
387
388
389
390
/**
 * Host entry point for the two-kernel alignment path, which is restricted to
 * num_experts == 256.
 *
 * Kernel 1 (one block of 1024 threads) counts tokens per expert, writes the
 * padded start offsets into a scratch cumsum tensor and fills experts_ids and
 * num_tokens_post_pad. Kernel 2 then scatters the token indices into
 * sorted_token_ids, consuming the cumsum tensor as per-expert cursors. Both
 * run on the current stream of topk_ids' device.
 *
 * @param topk_ids            Integral tensor of expert ids.
 * @param num_experts         Must be 256.
 * @param block_size          Padding granularity; must be positive.
 * @param sorted_token_ids    int32 output tensor (token indices per expert).
 * @param experts_ids         int32 output tensor (expert of each block).
 * @param num_tokens_post_pad int32 output tensor (padded total).
 */
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad) {
  TORCH_CHECK(num_experts == 256,
              "sgl_moe_align_block_size kernel only supports deepseek v3.");
  // The align kernel divides by block_size.
  TORCH_CHECK(block_size > 0,
              "sgl_moe_align_block_size: block_size must be > 0");

  // Make topk_ids' device current so that the stream lookup, the scratch
  // allocation and the launches below all target that device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
      topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
        // Device scratch tensor holding the per-expert start offsets; it is
        // written by the align kernel and consumed by the sort kernel.
        auto options_int =
            torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
        torch::Tensor cumsum_buffer =
            torch::zeros({num_experts + 1}, options_int);

        auto align_kernel =
            vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
        align_kernel<<<1, 1024, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
        // A launch does not report errors by itself.
        AT_CUDA_CHECK(cudaGetLastError());

        // With no tokens there is nothing to scatter, and a grid of zero
        // blocks is not a valid launch configuration.
        if (topk_ids.numel() > 0) {
          const int64_t block_threads = 256;
          const int64_t max_blocks = 65535;
          // Clamp in 64 bits before narrowing; the sort kernel strides over
          // the input, so a capped grid still visits every token.
          const int64_t num_blocks =
              (topk_ids.numel() + block_threads - 1) / block_threads;
          const int actual_blocks =
              static_cast<int>(std::min(num_blocks, max_blocks));
          auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>;
          sort_kernel<<<actual_blocks, static_cast<int>(block_threads), 0,
                        stream>>>(topk_ids.data_ptr<scalar_t>(),
                                  sorted_token_ids.data_ptr<int32_t>(),
                                  cumsum_buffer.data_ptr<int32_t>(),
                                  topk_ids.numel());
          AT_CUDA_CHECK(cudaGetLastError());
        }
      });
}

422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
/**
 * Sums `input` over its topk dimension (dim 1) into `output`.
 *
 * topk values 2, 3 and 4 use the specialised moe_sum_kernel (one block per
 * token, up to 1024 threads striding over hidden_size); any other topk falls
 * back to at::sum_out. An empty output is a no-op.
 *
 * @param input  Tensor laid out as [num_tokens, topk, hidden_size].
 * @param output Tensor laid out as [num_tokens, hidden_size]; overwritten.
 */
void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
             torch::Tensor& output)  // [num_tokens, hidden_size]
{
  // Nothing to write. Returning here also keeps the code below from dividing
  // by a zero hidden_size or launching a grid/block with zero extent.
  if (output.numel() == 0) {
    return;
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int hidden_size = input.size(-1);
  TORCH_CHECK(hidden_size > 0,
              "moe_sum: input has an empty last dimension but output is "
              "not empty");
  const int num_tokens = output.numel() / hidden_size;
  const int topk = input.size(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));

  switch (topk) {
    case 2:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 3:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 4:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    default:
      at::sum_out(output, input, 1);
      break;
  }

  // Kernel launches do not report errors by themselves; surface a bad launch
  // here instead of at some unrelated later call.
  AT_CUDA_CHECK(cudaGetLastError());
}