// moe_align_kernel.cu
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16
17
18
19
20
21
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <THC/THCAtomics.cuh>
#include <algorithm>

#include "utils.h"
23

maxiao's avatar
maxiao committed
24
// Number of lanes per warp assumed by the shuffle-based scan, the lane/warp
// index arithmetic and the launch sizing in this file.
#define WARP_SIZE 64
// Number of int32 elements written per vectorized store when padding
// sorted_token_ids (one Vec holds x, y, z, w).
#define VEC_SIZE 4
// Vector type used for the padding fill of sorted_token_ids.
using Vec = int4;
27

28
/**
 * Scatters flat token indices into `sorted_token_ids`, grouped by expert.
 *
 * For each element `i` of `topk_ids`, the slot index is obtained by atomically
 * incrementing `cumsum_buffer[topk_ids[i] + 1]`; `i` is then stored at the
 * value the counter held before the increment. `cumsum_buffer` is therefore
 * modified by this kernel.
 *
 * Launch layout: 1-D grid of 1-D blocks; any grid size is valid because the
 * loop strides by the total number of launched threads.
 *
 * @param topk_ids          expert id per element; the kernel indexes with
 *                          `topk_ids[i] + 1`
 * @param sorted_token_ids  output, written at the reserved slot
 * @param cumsum_buffer     per-expert running write cursor (read-modify-write)
 * @param numel             number of elements in `topk_ids`
 */
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ cumsum_buffer,
    size_t numel) {
  // Widen before multiplying so the flat index and the stride are computed in
  // size_t instead of wrapping in 32-bit unsigned arithmetic.
  const size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (size_t i = tid; i < numel; i += stride) {
    // Expert ids are shifted by one before being used as an index.
    const int32_t expert_id = topk_ids[i] + 1;
    // atomicAdd returns the previous cursor value, i.e. this token's slot.
    const int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
    sorted_token_ids[rank_post_pad] = static_cast<int32_t>(i);
  }
}
43

44
45
46
47
48
#ifdef __CUDA_ARCH__
// Exclusive prefix sum of `v` across the lanes of one warp of WARP_SIZE
// threads: each lane receives the sum of `v` over all lower-numbered lanes
// (lane 0 receives 0).
//
// `mask` is accepted for interface compatibility but is not consulted: the
// exchange uses the mask-less `__shfl_up`.
// NOTE(review): toolchains that only provide the `*_sync` shuffles, or whose
// hardware warp is narrower than WARP_SIZE, would need a different
// implementation - confirm against the intended target.
__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) {
  (void)mask;
  const int lane = threadIdx.x & (WARP_SIZE - 1);
  int inclusive = v;
#pragma unroll
  for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
    // Every lane takes part in the shuffle; only lanes that actually have a
    // source lane `offset` positions below accumulate the received value.
    const int from_below = __shfl_up(inclusive, offset);
    if (lane >= offset) inclusive += from_below;
  }
  // Inclusive sum minus the lane's own contribution gives the exclusive sum.
  return inclusive - v;
}
#endif

56
/**
 * Computes, for every expert, the number of routed tokens rounded up to a
 * multiple of `block_size`, the exclusive prefix sum of those padded counts,
 * and the expert id owning each output block.
 *
 * Launch layout: a single 1-D block (only threadIdx.x / blockDim.x are used).
 * `num_experts` must not exceed blockDim.x because counters are zeroed and
 * padded counts are produced by one thread per expert.
 *
 * Dynamic shared memory (int32 elements), in order:
 *   shared_counts[num_experts], prefix[num_experts + 1], scan_buf[scan_size],
 *   and, in the `#else // CUDA` branch, warp_sums[WARP_SIZE].
 *
 * @param topk_ids               expert id per element; indexed as `id + 1`
 * @param sorted_token_ids       when `pad_sorted_token_ids` is set, filled with
 *                               `numel` using Vec (4 x int32) stores; the last
 *                               store may extend up to VEC_SIZE - 1 elements
 *                               past the padded total, so the buffer must be
 *                               large enough and suitably aligned for Vec
 * @param expert_ids             output, one entry per block of `block_size`
 * @param total_tokens_post_pad  output, sum of all padded counts
 * @param num_experts            number of counters / prefix entries
 * @param block_size             padding granularity (must be > 0)
 * @param numel                  number of elements in `topk_ids`
 * @param cumsum                 output, `num_experts + 1` prefix entries
 * @param pad_sorted_token_ids   enables the fill of `sorted_token_ids`
 * @param scan_size              length of the scan buffer; the up/down-sweep
 *                               loops halve/double it, so a power of two
 *                               >= num_experts is expected
 */
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ expert_ids,
    int32_t* __restrict__ total_tokens_post_pad,
    int32_t num_experts,
    int32_t block_size,
    size_t numel,
    int32_t* __restrict__ cumsum,
    bool pad_sorted_token_ids,
    const int32_t scan_size) {
  extern __shared__ int32_t smem[];
  int32_t* shared_counts = smem;                  // [num_experts]
  int32_t* prefix = shared_counts + num_experts;  // [num_experts + 1]
  int32_t* scan_buf = prefix + num_experts + 1;   // [scan_size]
  __shared__ int32_t s_total_tokens_post_pad;

  const size_t tid = threadIdx.x;
  const size_t stride = blockDim.x;

  // Zero the per-expert counters (one thread per expert).
  if (tid < num_experts) {
    shared_counts[tid] = 0;
  }

  __syncthreads();

  // Histogram of (shifted) expert ids.
  for (size_t i = tid; i < numel; i += stride) {
    int expert_id = topk_ids[i] + 1;
    atomicAdd(&shared_counts[expert_id], 1);
  }

  __syncthreads();

  // Round each expert's count up to a multiple of block_size.
  int32_t padded_count = 0;
  if (tid < num_experts) {
    int32_t count = shared_counts[tid];
    padded_count = (count + block_size - 1) / block_size * block_size;
    scan_buf[tid] = padded_count;
  }

#ifndef __CUDA_ARCH__  // HIP

  // Zero the tail of the scan buffer beyond num_experts.
  if (tid >= num_experts && tid < scan_size) {
    scan_buf[tid] = 0;
  }

  __syncthreads();

  // Blelloch scan: up-sweep
  int offset = 1;
#pragma unroll
  for (int d = scan_size >> 1; d > 0; d >>= 1) {
    if (tid < d) {
      int ai = offset * (2 * tid + 1) - 1;
      int bi = offset * (2 * tid + 2) - 1;
      scan_buf[bi] += scan_buf[ai];
    }
    offset <<= 1;
    __syncthreads();
  }

  // down-sweep: the last element holds the grand total; save it, then clear
  // it so the down-sweep produces an exclusive scan.
  if (tid == 0) {
    prefix[num_experts] = scan_buf[scan_size - 1];
    scan_buf[scan_size - 1] = 0;
  }
  __syncthreads();

#pragma unroll
  for (int d = 1; d < scan_size; d <<= 1) {
    offset >>= 1;
    if (tid < d) {
      int ai = offset * (2 * tid + 1) - 1;
      int bi = offset * (2 * tid + 2) - 1;
      if (bi < scan_size) {
        int temp = scan_buf[ai];
        scan_buf[ai] = scan_buf[bi];
        scan_buf[bi] += temp;
      }
    }
    __syncthreads();
  }

  if (tid < num_experts) {
    prefix[tid] = scan_buf[tid];
  }

  if (tid == 0) {
    s_total_tokens_post_pad = prefix[num_experts];
    *total_tokens_post_pad = s_total_tokens_post_pad;
  }
  // Publishes prefix[] and s_total_tokens_post_pad to the whole block.
  __syncthreads();

#else  // CUDA

  // Intra warp prefix sum
  int32_t* warp_sums = scan_buf + scan_size;  // [WARP_SIZE]
  const int warp_id = tid / WARP_SIZE;
  const int lane_id = tid & (WARP_SIZE - 1);
  const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE;
  const int warp_sum = warp_exclusive_scan(padded_count) + padded_count;
  if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum;
  __syncthreads();

  // warp0 accumulate all the block's prefix sum
  if (tid < WARP_SIZE) {
    int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0;
    int incl = warp_exclusive_scan(val) + val;
    warp_sums[tid] = incl;
  }
  __syncthreads();

  // Thread 0 publishes the block-wide total
  if (tid == 0) {
    prefix[num_experts] = warp_sums[num_warps_for_scan - 1];
    s_total_tokens_post_pad = prefix[num_experts];
    *total_tokens_post_pad = s_total_tokens_post_pad;
  }
  __syncthreads();

  // Fill 0 to scan_buf extended area (tid >= num_expert)
  if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0;
  __syncthreads();

  // Perform 2 level exclusive-prefix-sum to scan_buf
  int v = (tid < scan_size) ? scan_buf[tid] : 0;
  int pre = warp_exclusive_scan(v);
  if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v;
  __syncthreads();

  if (warp_id == 0) {
    int val = (lane_id < num_warps_for_scan) ? warp_sums[lane_id] : 0;
    warp_sums[lane_id] = warp_exclusive_scan(val);
  }
  __syncthreads();

  int offset = warp_sums[warp_id];
  if (tid < scan_size) scan_buf[tid] = pre + offset;
  __syncthreads();

  // Write prefix[0..num_experts - 1]
  if (tid < num_experts) prefix[tid] = scan_buf[tid];
  // The binary search below reads prefix[] entries written by other threads,
  // so the writes above must be made visible to the whole block first.
  __syncthreads();
#endif

  // Export the prefix sums (num_experts + 1 entries).
  if (tid <= num_experts) {
    cumsum[tid] = prefix[tid];
  }
  // fill expert_ids
  const int32_t num_blocks = s_total_tokens_post_pad / block_size;
  for (int32_t i = tid; i < num_blocks; i += stride) {
    int32_t block_start = i * block_size;
    // Find the first index whose prefix is strictly greater than block_start.
    int left = 0, right = num_experts;
    while (left < right) {
      int mid = (left + right) >> 1;
      if (prefix[mid] <= block_start) {
        left = mid + 1;
      } else {
        right = mid;
      }
    }
    // left - 1 is the owning (shifted) slot; one more is subtracted to undo
    // the +1 shift applied when counting.
    expert_ids[i] = left - 2;
  }

  if (pad_sorted_token_ids) {
    // Pre-fill the padded output with `numel` using 4-wide stores.
    Vec fill_vec;
    fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = static_cast<int32_t>(numel);
    int32_t total_vecs = (s_total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE;
    Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
    for (int32_t i = tid; i < total_vecs; i += stride) {
      out_ptr[i] = fill_vec;
    }
  }
}

231
232
233
234
235
236
237
238
/**
 * Single-block variant that counts, pads, prefix-sums and scatters in one
 * kernel using per-thread histograms in shared memory.
 *
 * Launch layout: one 1-D block with blockDim.x >= num_experts (the column
 * prefix pass and the expert_ids fill use one thread per expert).
 *
 * Dynamic shared memory (int32 elements), in order:
 *   cumsum[num_experts + 1]
 *   tokens_cnts[(blockDim.x + 1) * num_experts]  - row 0 is a zero row, row
 *   t + 1 holds thread t's histogram before the column prefix pass.
 *
 * Expert ids are shifted by one (`topk_ids[i] + 1`) before indexing, and the
 * ids written to `expert_ids` are shifted back (`threadIdx.x - 1`).
 */
template <typename scalar_t>
__global__ void moe_align_block_size_small_batch_expert_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ expert_ids,
    int32_t* __restrict__ total_tokens_post_pad,
    int32_t num_experts,
    int32_t block_size,
    size_t numel,
    bool pad_sorted_token_ids) {
  const size_t tid = threadIdx.x;
  const size_t stride = blockDim.x;

  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;
  int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);

  // Phase 1: each thread clears and fills its private histogram row (t + 1).
  int32_t* hist_row = tokens_cnts + (threadIdx.x + 1) * num_experts;
  for (int e = 0; e < num_experts; ++e) {
    hist_row[e] = 0;
  }

  for (size_t i = tid; i < numel; i += stride) {
    ++hist_row[topk_ids[i] + 1];
  }

  __syncthreads();

  // Phase 2: one thread per expert turns its column into a running sum over
  // rows, so row t ends up holding the count contributed by threads < t and
  // row blockDim.x holds the expert's total.
  if (threadIdx.x < num_experts) {
    int32_t* cell = tokens_cnts + threadIdx.x;
    *cell = 0;
    int32_t running = 0;
    for (int row = 1; row <= blockDim.x; ++row) {
      cell += num_experts;
      running += *cell;
      *cell = running;
    }
  }

  __syncthreads();

  // Phase 3: thread 0 builds the prefix sum of block_size-padded totals.
  if (threadIdx.x == 0) {
    const int32_t* totals = tokens_cnts + blockDim.x * num_experts;
    int32_t acc = 0;
    cumsum[0] = acc;
    for (int e = 0; e < num_experts; ++e) {
      acc += CEILDIV(totals[e], block_size) * block_size;
      cumsum[e + 1] = acc;
    }
    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  // Phase 4: label every block of the expert's padded range with its id.
  if (threadIdx.x < num_experts) {
    const int32_t range_end = cumsum[threadIdx.x + 1];
    int pos = cumsum[threadIdx.x];
    while (pos < range_end) {
      expert_ids[pos / block_size] = threadIdx.x - 1;
      pos += block_size;
    }
  }

  // Phase 5 (optional): pre-fill the padded output with `numel`, 4 at a time.
  if (pad_sorted_token_ids) {
    Vec fill_vec;
    fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel;
    const int32_t total_vecs = (*total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE;
    Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
    for (int32_t v = tid; v < total_vecs; v += stride) {
      out_ptr[v] = fill_vec;
    }
  }

  // The scatter below overwrites some of the filled slots, so the fill must
  // have completed block-wide first.
  __syncthreads();

  // Phase 6: scatter. Row t now holds this thread's starting offset inside
  // each expert's range; it is advanced after every placement.
  int32_t* cursor_row = tokens_cnts + threadIdx.x * num_experts;
  for (size_t i = tid; i < numel; i += stride) {
    const int32_t expert_id = topk_ids[i] + 1;
    const int32_t rank_post_pad = cursor_row[expert_id] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++cursor_row[expert_id];
  }
}

303
304
305
306
307
308
309
/**
 * Host entry point: groups token indices by expert and pads each expert's
 * group to a multiple of `block_size`.
 *
 * Two execution paths are used on the current stream:
 *  - fewer than 1024 elements and at most 64 experts: one launch of
 *    moe_align_block_size_small_batch_expert_kernel;
 *  - otherwise: moe_align_block_size_kernel (single block, fills `cumsum_buffer`,
 *    `experts_ids`, `num_tokens_post_pad`) followed by
 *    count_and_sort_expert_tokens_kernel (fills `sorted_token_ids`).
 *
 * @param topk_ids             integral tensor of expert ids
 * @param num_experts          number of expert slots the kernels index
 * @param block_size           padding granularity; must be positive
 * @param sorted_token_ids     int32 output tensor
 * @param experts_ids          int32 output tensor, one entry per block
 * @param num_tokens_post_pad  int32 output tensor (single value is written)
 * @param cumsum_buffer        int32 scratch/output, used only on the
 *                             two-kernel path
 * @param pad_sorted_token_ids when true the kernels pre-fill
 *                             `sorted_token_ids` with `topk_ids.numel()`
 *
 * Raises (via TORCH_CHECK / launch check) when `block_size` is not positive,
 * when the two-kernel path is asked for more experts than one block has
 * threads, or when a kernel launch fails.
 */
void moe_align_block_size(
    torch::Tensor topk_ids,
    int64_t num_experts,
    int64_t block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
    torch::Tensor cumsum_buffer,
    bool pad_sorted_token_ids) {
  // The kernels divide by block_size.
  TORCH_CHECK(block_size > 0, "moe_align_block_size: block_size must be positive, got ", block_size);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int threads = 1024;

  // Round the block size up to a whole number of warps.
  threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;

  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
    const int64_t numel = topk_ids.numel();
    bool small_batch_expert_mode = (numel < 1024) && (num_experts <= 64);

    if (small_batch_expert_mode) {
      // At least one thread per expert, and never less than one warp.
      const int32_t small_threads = std::max<int32_t>((int32_t)num_experts, WARP_SIZE);
      // (threads + 1) histogram rows of num_experts, plus num_experts + 1 prefix entries.
      const int32_t shared_mem_size = ((small_threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);

      auto small_batch_expert_kernel = moe_align_block_size_small_batch_expert_kernel<scalar_t>;
      small_batch_expert_kernel<<<1, small_threads, shared_mem_size, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_ids.data_ptr<int32_t>(),
          experts_ids.data_ptr<int32_t>(),
          num_tokens_post_pad.data_ptr<int32_t>(),
          num_experts,
          block_size,
          numel,
          pad_sorted_token_ids);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      // The align kernel runs in a single block and handles one expert per
      // thread when zeroing counters and computing padded counts.
      TORCH_CHECK(
          num_experts <= threads,
          "moe_align_block_size: num_experts (",
          num_experts,
          ") exceeds the supported maximum of ",
          threads);

      auto align_kernel = moe_align_block_size_kernel<scalar_t>;

      const size_t scan_size = next_pow2(num_experts);
      // counts + prefix + scan buffer + per-warp partial sums.
      const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * sizeof(int32_t);
      align_kernel<<<1, threads, shared_mem_size, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_ids.data_ptr<int32_t>(),
          experts_ids.data_ptr<int32_t>(),
          num_tokens_post_pad.data_ptr<int32_t>(),
          num_experts,
          block_size,
          numel,
          cumsum_buffer.data_ptr<int32_t>(),
          pad_sorted_token_ids,
          scan_size);
      C10_CUDA_KERNEL_LAUNCH_CHECK();

      // With no elements there is nothing to scatter, and a zero-sized grid
      // would be an invalid launch configuration.
      if (numel > 0) {
        const int block_threads = std::min(256, (int)threads);
        // Clamp in 64-bit before narrowing; the kernel strides over the rest.
        const int64_t num_blocks = (numel + block_threads - 1) / block_threads;
        const int64_t max_blocks = 65535;
        const int actual_blocks = static_cast<int>(std::min(num_blocks, max_blocks));

        auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
        sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            sorted_token_ids.data_ptr<int32_t>(),
            cumsum_buffer.data_ptr<int32_t>(),
            numel);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    }
  });
}