moe_align_sum_kernels.cu 25 KB
Newer Older
1
#include <torch/all.h>
2
#include <ATen/cuda/CUDAContext.h>
3
#include <c10/cuda/CUDAGuard.h>
4
5
6
7

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

8
9
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
10

11
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
12
13

namespace vllm {
namespace moe {

namespace {

// Row-major offset of element (row, col) in a matrix with `total_col` columns.
// Overflow is not a concern: the matrices indexed with this helper hold at
// most (blockDim.x + 1) * num_experts entries and num_experts is small.
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  return row * total_col + col;
}

// Maps a global expert id to its local column in
// [0, end_expert - start_expert), or returns -1 when the id lies outside
// [start_expert, end_expert). The range check is done in scalar_t so wide
// (e.g. int64) ids are not truncated before being compared.
template <typename scalar_t>
__device__ __forceinline__ int32_t local_expert_id(scalar_t id,
                                                   int32_t start_expert,
                                                   int32_t end_expert) {
  return (id >= start_expert && id < end_expert)
             ? static_cast<int32_t>(id - start_expert)
             : -1;
}

/**
 * Shared body of moe_align_block_size_kernel,
 * moe_align_block_size_global_mem_kernel and ep_moe_align_block_size_kernel.
 *
 * Launch contract: a single block with blockDim.x >= num_experts (the
 * per-expert phases use thread e for expert e). Every thread of the block must
 * call this function because it contains __syncthreads().
 *
 * Scratch memory (shared or global, supplied by the caller):
 *   tokens_cnts - (blockDim.x + 1) x num_experts matrix, row-major.
 *   cumsum      - num_experts + 1 entries.
 *
 * Only ids in [start_expert, end_expert) are placed; they are renumbered to
 * local ids starting at 0, and end_expert - start_expert must not exceed
 * num_experts. Ids outside that range are ignored rather than being used as
 * out-of-bounds indices.
 *
 * Example: topk_ids = [0,1,2,1,2,3,0,3,4] with block_size = 4 yields
 * sorted_token_ids = [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *,
 * *, *], where * is the padding value preset by the caller (in Python).
 */
template <typename scalar_t, typename token_cnts_t>
__device__ __forceinline__ void align_block_size_impl(
    const scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t start_expert,
    int32_t end_expert, token_cnts_t* tokens_cnts, int32_t* cumsum) {
  const int32_t tid = threadIdx.x;
  const int32_t num_threads = blockDim.x;

  // Each thread owns the contiguous shard [start_idx, end_idx) of topk_ids.
  // Use size_t indices so the bound check against numel is not a
  // signed/unsigned comparison.
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;
  const size_t shard_end = start_idx + tokens_per_thread;
  const size_t end_idx = shard_end < numel ? shard_end : numel;

  // Step 1: tokens_cnts[tid + 1][e] = number of tokens in this thread's shard
  // routed to local expert e. Row 0 is filled in step 2.
  for (int32_t e = 0; e < num_experts; ++e) {
    tokens_cnts[index(num_experts, tid + 1, e)] = 0;
  }
  for (size_t i = start_idx; i < end_idx; ++i) {
    const int32_t e = local_expert_id(topk_ids[i], start_expert, end_expert);
    if (e >= 0) {
      ++tokens_cnts[index(num_experts, tid + 1, e)];
    }
  }

  __syncthreads();

  // Step 2: thread e turns column e into a running sum over threads. Afterwards
  // tokens_cnts[t][e] counts expert-e tokens in shards 0..t-1, and
  // tokens_cnts[num_threads][e] is expert e's total.
  if (tid < num_experts) {
    tokens_cnts[index(num_experts, 0, tid)] = 0;
    for (int32_t t = 1; t <= num_threads; ++t) {
      tokens_cnts[index(num_experts, t, tid)] +=
          tokens_cnts[index(num_experts, t - 1, tid)];
    }
  }

  __syncthreads();

  // Step 3: thread 0 lays out the padded segments. Expert e occupies
  // [cumsum[e], cumsum[e + 1]), its token count rounded up to block_size.
  if (tid == 0) {
    cumsum[0] = 0;
    for (int32_t e = 0; e < num_experts; ++e) {
      const int32_t count = tokens_cnts[index(num_experts, num_threads, e)];
      cumsum[e + 1] = cumsum[e] + CEILDIV(count, block_size) * block_size;
    }
    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  // Step 4: tag every block of the padded layout with its (local) expert id.
  if (tid < num_experts) {
    for (int32_t i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) {
      expert_ids[i / block_size] = tid;
    }
  }

  // Step 5: scatter token indices. Row `tid` now holds this shard's starting
  // rank inside each expert's segment. From here on only this thread touches
  // that row and cumsum is read-only, so no further barrier is needed.
  for (size_t i = start_idx; i < end_idx; ++i) {
    const int32_t e = local_expert_id(topk_ids[i], start_expert, end_expert);
    if (e < 0) {
      continue;
    }
    token_cnts_t& rank_in_expert = tokens_cnts[index(num_experts, tid, e)];
    sorted_token_ids[cumsum[e] + rank_in_expert] = static_cast<int32_t>(i);
    ++rank_in_expert;
  }
}

}  // namespace

/**
 * Groups the flattened top-k expert assignments by expert and pads each
 * expert's segment to a multiple of block_size, using dynamic shared memory
 * as scratch.
 *
 * Launch: one block whose blockDim.x >= num_experts. Dynamic shared memory
 * must hold cumsum ((num_experts + 1) int32) followed by the
 * (blockDim.x + 1) x num_experts count matrix of token_cnts_t. token_cnts_t =
 * uint16_t is only safe when numel <= 65535, so that no count can overflow.
 */
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                            int32_t* sorted_token_ids,
                                            int32_t* expert_ids,
                                            int32_t* total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size, size_t numel) {
  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;  // num_experts + 1 entries
  token_cnts_t* tokens_cnts = reinterpret_cast<token_cnts_t*>(
      shared_mem + num_experts + 1);  // (blockDim.x + 1) x num_experts
  align_block_size_impl<scalar_t, token_cnts_t>(
      topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad,
      num_experts, block_size, numel, /*start_expert=*/0,
      /*end_expert=*/num_experts, tokens_cnts, cumsum);
}

// TODO(simon): this is temporarily adapted from
// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
// we did this to unblock Deepseek V3 but there should be a better
// implementation to manage shared memory.
/**
 * Same computation as moe_align_block_size_kernel, but with the scratch
 * buffers in caller-provided global memory. tokens_cnts must hold
 * (blockDim.x + 1) * num_experts int32 entries and cumsum num_experts + 1.
 * Launch: one block whose blockDim.x >= num_experts, no dynamic shared memory.
 */
template <typename scalar_t>
__global__ void moe_align_block_size_global_mem_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
  align_block_size_impl<scalar_t, int32_t>(
      topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad,
      num_experts, block_size, numel, /*start_expert=*/0,
      /*end_expert=*/num_experts, tokens_cnts, cumsum);
}

// taken from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
/**
 * Atomic-counter variant of the align kernel. Per-expert counts and write
 * offsets live in fixed-size shared arrays, so num_experts must be <= 256.
 * The block must have blockDim.x >= num_experts (the host launches 1024
 * threads). cumsum is a caller-provided buffer of num_experts + 1 int32.
 * Token order within an expert's segment depends on atomic ordering and is
 * therefore not deterministic.
 */
template <typename scalar_t>
__global__ void sgl_moe_align_block_size_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t* cumsum) {
  // Counters laid out as [group][expert-within-group]: 32 x 8 = 256 experts.
  constexpr int kExpertsPerWarp = 8;
  __shared__ int32_t shared_counts[32][kExpertsPerWarp];
  __shared__ int32_t local_offsets[256];

  const int tid = threadIdx.x;

  // Zero the counters of every expert in use. The block-strided loop does not
  // depend on the warp layout.
  for (int e = tid; e < num_experts; e += blockDim.x) {
    shared_counts[e / kExpertsPerWarp][e % kExpertsPerWarp] = 0;
  }
  // Counters must be zero before any thread increments them. Without this
  // barrier, another warp's atomicAdd can be overwritten by a late zeroing.
  __syncthreads();

  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;
  const size_t shard_end = start_idx + tokens_per_thread;
  const size_t end_idx = shard_end < numel ? shard_end : numel;

  for (size_t i = start_idx; i < end_idx; ++i) {
    const int expert_id = topk_ids[i];
    atomicAdd(&shared_counts[expert_id / kExpertsPerWarp]
                            [expert_id % kExpertsPerWarp],
              1);
  }

  __syncthreads();

  // Thread 0 computes each expert's padded segment start.
  if (tid == 0) {
    cumsum[0] = 0;
    for (int e = 0; e < num_experts; ++e) {
      const int count = shared_counts[e / kExpertsPerWarp][e % kExpertsPerWarp];
      cumsum[e + 1] = cumsum[e] + CEILDIV(count, block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  // One thread per expert tags its blocks and seeds its write cursor.
  if (tid < num_experts) {
    for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) {
      expert_ids[i / block_size] = tid;
    }
    local_offsets[tid] = cumsum[tid];
  }

  __syncthreads();

  // Claim a slot in the owning expert's segment for every token.
  for (size_t i = start_idx; i < end_idx; ++i) {
    const int32_t expert_id = topk_ids[i];
    const int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
    sorted_token_ids[rank_post_pad] = static_cast<int32_t>(i);
  }
}

/**
 * out[t, :] = sum over k < TOPK of input[t, k, :].
 *
 * Launch: one block per token (blockIdx.x is the token index). The block's
 * threads stride over the d hidden elements. Both buffers are indexed as
 * dense row-major storage.
 */
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., topk, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  const scalar_t* token_input = input + token_idx * TOPK * d;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    scalar_t x = static_cast<scalar_t>(0.0f);
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {
      x += VLLM_LDG(&token_input[k * d + idx]);
    }
    out[token_idx * d + idx] = x;
  }
}

/**
 * Expert-parallel variant of moe_align_block_size_kernel. Only assignments
 * to experts in [start_expert, end_expert) are placed, renumbered to local ids
 * 0..end_expert - start_expert - 1 (which must not exceed num_experts).
 * expert_ids therefore receives local ids. Launch requirements and the shared
 * memory layout are identical to moe_align_block_size_kernel.
 */
template <typename scalar_t, typename token_cnts_t>
__global__ void ep_moe_align_block_size_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t start_expert,
    int32_t end_expert) {
  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;  // num_experts + 1 entries
  token_cnts_t* tokens_cnts = reinterpret_cast<token_cnts_t*>(
      shared_mem + num_experts + 1);  // (blockDim.x + 1) x num_experts
  align_block_size_impl<scalar_t, token_cnts_t>(
      topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad,
      num_experts, block_size, numel, start_expert, end_expert, tokens_cnts,
      cumsum);
}

}  // namespace moe
}  // namespace vllm

382
383
/**
 * Host entry point: groups the flattened top-k expert assignments in
 * `topk_ids` by expert and pads each expert's segment to a multiple of
 * `block_size`.
 *
 * Outputs are written in place by a single-block kernel:
 *   sorted_token_ids    - token indices grouped by expert. Padding slots keep
 *                         the value the caller preset.
 *   experts_ids         - expert id of each block of `block_size` slots.
 *   num_tokens_post_pad - receives the total padded length (first element).
 *
 * The kernel variant is chosen by how much scratch the per-thread count
 * matrix needs: int32 counts in shared memory, uint16 counts in shared memory
 * (only while topk_ids.numel() <= 65535, so counts cannot overflow), or int32
 * counts in a temporary global-memory buffer.
 */
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad) {
  // Launch on the inputs' device rather than whatever device is current. The
  // stream below is then taken for that device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Single-block launch with at least one thread per expert.
  const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
  TORCH_CHECK(num_thread <= 1024, "moe_align_block_size: num_experts = ",
              num_experts, " exceeds the 1024 threads of a single block");

  int device_max_shared_mem = 0;
  AT_CUDA_CHECK(cudaDeviceGetAttribute(&device_max_shared_mem,
                                       cudaDevAttrMaxSharedMemoryPerBlockOptin,
                                       topk_ids.get_device()));

  // Dynamic shared memory: cumsum (num_experts + 1 int32) followed by the
  // (num_thread + 1) x num_experts token-count matrix.
  const int32_t shared_mem_i32 =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
  const int32_t shared_mem_i16 =
      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
      (num_experts + 1) * sizeof(int32_t);

  // A request exactly equal to the device limit is still valid, hence <=.
  const bool use_i32 = shared_mem_i32 <= device_max_shared_mem;
  // When numel fits in uint16, every per-expert count does too.
  const bool use_i16 = !use_i32 &&
                       shared_mem_i16 <= device_max_shared_mem &&
                       topk_ids.numel() <= 65535;

  if (use_i32) {
    VLLM_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          auto kernel =
              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
              (void*)kernel, shared_mem_i32));
          kernel<<<1, num_thread, shared_mem_i32, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
              topk_ids.numel());
        });
  } else if (use_i16) {
    VLLM_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          auto kernel =
              vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
              (void*)kernel, shared_mem_i16));
          kernel<<<1, num_thread, shared_mem_i16, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
              topk_ids.numel());
        });
  } else {
    VLLM_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
          auto options_int = torch::TensorOptions()
                                 .dtype(torch::kInt)
                                 .device(topk_ids.device());
          // The kernel indexes rows 0..num_thread of the count matrix, so it
          // needs (num_thread + 1) rows. Sizing it by num_experts + 1 would
          // under-allocate whenever num_thread > num_experts.
          torch::Tensor token_cnts_buffer =
              torch::empty({(num_thread + 1) * num_experts}, options_int);
          torch::Tensor cumsum_buffer =
              torch::empty({num_experts + 1}, options_int);

          auto kernel =
              vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
          kernel<<<1, num_thread, 0, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
              topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
              cumsum_buffer.data_ptr<int32_t>());
        });
  }
  // Kernel launches do not report configuration errors directly.
  AT_CUDA_CHECK(cudaGetLastError());
}
470

王敏's avatar
王敏 committed
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
/**
 * Expert-parallel variant of moe_align_block_size. Only assignments to
 * experts in [start_expert, end_expert) are placed. They are renumbered to
 * local ids starting at 0, so experts_ids receives local ids.
 *
 * Uses the shared-memory kernel with int32 counts when they fit, otherwise
 * with uint16 counts (valid while topk_ids.numel() <= 65535). There is no
 * global-memory fallback for this variant, so an oversized configuration is
 * rejected with an error instead of silently failing to launch.
 */
void ep_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                             int64_t block_size,
                             torch::Tensor sorted_token_ids,
                             torch::Tensor experts_ids,
                             torch::Tensor num_tokens_post_pad,
                             int64_t start_expert, int64_t end_expert) {
  // Local ids index the columns of a num_experts-wide count matrix. A wider
  // range would write past the end of it.
  TORCH_CHECK(start_expert <= end_expert &&
                  end_expert - start_expert <= num_experts,
              "ep_moe_align_block_size: expert range [", start_expert, ", ",
              end_expert, ") must satisfy start <= end and span at most ",
              "num_experts = ", num_experts);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Single-block launch with at least one thread per (local) expert.
  const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
  TORCH_CHECK(num_thread <= 1024, "ep_moe_align_block_size: num_experts = ",
              num_experts, " exceeds the 1024 threads of a single block");

  int device_max_shared_mem = 0;
  AT_CUDA_CHECK(cudaDeviceGetAttribute(&device_max_shared_mem,
                                       cudaDevAttrMaxSharedMemoryPerBlockOptin,
                                       topk_ids.get_device()));

  // Dynamic shared memory: cumsum (num_experts + 1 int32) followed by the
  // (num_thread + 1) x num_experts token-count matrix.
  const int32_t shared_mem_i32 =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
  const int32_t shared_mem_i16 =
      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
      (num_experts + 1) * sizeof(int32_t);

  const bool use_i32 = shared_mem_i32 <= device_max_shared_mem;
  // When numel fits in uint16, every per-expert count does too.
  const bool use_i16 = !use_i32 &&
                       shared_mem_i16 <= device_max_shared_mem &&
                       topk_ids.numel() <= 65535;
  TORCH_CHECK(use_i32 || use_i16, "ep_moe_align_block_size: ", num_experts,
              " experts need ", shared_mem_i32,
              " bytes of shared memory but the device allows ",
              device_max_shared_mem,
              " (uint16 counts additionally require numel <= 65535)");

  if (use_i32) {
    VLLM_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "ep_moe_align_block_size_kernel", [&] {
          auto kernel =
              vllm::moe::ep_moe_align_block_size_kernel<scalar_t, int32_t>;
          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
              (void*)kernel, shared_mem_i32));
          kernel<<<1, num_thread, shared_mem_i32, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
              topk_ids.numel(), start_expert, end_expert);
        });
  } else {
    VLLM_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "ep_moe_align_block_size_kernel", [&] {
          auto kernel =
              vllm::moe::ep_moe_align_block_size_kernel<scalar_t, uint16_t>;
          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
              (void*)kernel, shared_mem_i16));
          kernel<<<1, num_thread, shared_mem_i16, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
              topk_ids.numel(), start_expert, end_expert);
        });
  }
  // Kernel launches do not report configuration errors directly.
  AT_CUDA_CHECK(cudaGetLastError());
}

574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
/**
 * Host entry point for the atomic-counter align kernel (adapted from sglang).
 * The kernel keeps per-expert counts and offsets in fixed shared arrays sized
 * for 256 experts and is launched as one block of 1024 threads. A temporary
 * int32 buffer of num_experts + 1 entries holds the padded prefix sums.
 */
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad) {
  // shared_counts[32][8] / local_offsets[256] in the kernel cap the expert
  // count. Anything larger would index shared memory out of bounds.
  TORCH_CHECK(num_experts <= 256,
              "sgl_moe_align_block_size: supports at most 256 experts, got ",
              num_experts);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
        auto options_int =
            torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
        torch::Tensor cumsum_buffer =
            torch::empty({num_experts + 1}, options_int);

        auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
        kernel<<<1, 1024, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
      });
  // Kernel launches do not report configuration errors directly.
  AT_CUDA_CHECK(cudaGetLastError());
}

600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
namespace {
// Launches vllm::moe::moe_sum_kernel specialised for a compile-time TOPK on
// `stream` and checks the launch. Both tensors must be contiguous.
template <int TOPK>
void launch_moe_sum_kernel(torch::Tensor& input, torch::Tensor& output,
                           int hidden_size, const dim3& grid,
                           const dim3& block, cudaStream_t stream) {
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
    vllm::moe::moe_sum_kernel<scalar_t, TOPK><<<grid, block, 0, stream>>>(
        output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
  });
  AT_CUDA_CHECK(cudaGetLastError());
}
}  // namespace

/**
 * output = input.sum(dim=1): reduces the top-k expert outputs of every token.
 * topk of 2, 3 or 4 on contiguous tensors uses the specialised kernel (one
 * block per token). Every other case falls back to at::sum_out.
 */
void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
             torch::Tensor& output)  // [num_tokens, hidden_size]
{
  // Nothing to reduce. Returning early also avoids dividing by a zero
  // hidden_size below and launching an invalid empty grid.
  if (output.numel() == 0) {
    return;
  }

  const int hidden_size = input.size(-1);
  const int num_tokens = output.numel() / hidden_size;
  const int topk = input.size(1);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));

  // The specialised kernel indexes both tensors as dense row-major buffers.
  // Strided views would be read and written at the wrong offsets, so let
  // ATen handle them.
  if (!input.is_contiguous() || !output.is_contiguous()) {
    at::sum_out(output, input, 1);
    return;
  }

  const dim3 grid(num_tokens);
  const dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  switch (topk) {
    case 2:
      launch_moe_sum_kernel<2>(input, output, hidden_size, grid, block, stream);
      break;
    case 3:
      launch_moe_sum_kernel<3>(input, output, hidden_size, grid, block, stream);
      break;
    case 4:
      launch_moe_sum_kernel<4>(input, output, hidden_size, grid, block, stream);
      break;
    default:
      at::sum_out(output, input, 1);
      break;
  }
}