moe_align_kernel.cu 4.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16
17
18
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
19
#include <torch/extension.h>
20
21
22

#include <THC/THCAtomics.cuh>

23
#include "utils.h"
24

25
#define WARP_SIZE 32
26

27
// Scatters flat token indices into `sorted_token_ids`, grouped by expert id.
//
// For each element `idx` of `topk_ids`, the kernel atomically post-increments
// `cumsum_buffer[expert]` (expert = topk_ids[idx]) and writes `idx` into
// `sorted_token_ids` at the value the counter held before the increment.
//
// Parameters:
//   topk_ids          Expert id for every flat token slot (`numel` entries).
//   sorted_token_ids  Destination array; written at the reserved positions.
//   cumsum_buffer     Per-expert write cursors. Modified in place: each entry
//                     grows by the number of tokens routed to that expert.
//                     NOTE(review): presumably holds each expert's start
//                     offset on entry, as produced by the align kernel.
//   numel             Number of entries in `topk_ids`.
//
// Launch layout: 1-D grid of 1-D blocks; the grid-stride loop makes any grid
// size valid. No shared memory is used. The relative order of tokens that map
// to the same expert depends on the order in which the atomics execute.
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ cumsum_buffer,
    size_t numel) {
  const size_t step = blockDim.x * gridDim.x;
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;

  while (idx < numel) {
    const int32_t expert = topk_ids[idx];
    // atomicAdd returns the pre-increment value, i.e. the slot reserved for
    // this token inside the expert's range.
    const int32_t slot = atomicAdd(&cumsum_buffer[expert], 1);
    sorted_token_ids[slot] = idx;
    idx += step;
  }
}
42

43
// Counts how many tokens are routed to each expert and derives a layout in
// which every expert's token range is padded up to a multiple of `block_size`.
//
// Launch layout: intended for a single thread block (all indexing is based on
// threadIdx.x only); any 1-D block size works because every phase uses a
// block-stride loop. Requires at least `num_experts * sizeof(int32_t)` bytes of
// dynamic shared memory for the per-expert counters.
//
// Parameters:
//   topk_ids               Expert id for every flat token slot (`numel` entries).
//   sorted_token_ids       Not accessed by this kernel; kept for signature
//                          compatibility.
//   expert_ids             Output: expert_ids[b] is the expert that owns the
//                          b-th `block_size`-wide slice of the padded layout.
//   total_tokens_post_pad  Output: total padded token count, cumsum[num_experts].
//   num_experts            Number of experts / number of shared counters.
//   experts_per_warp       Not used by the computation (the counter of expert e
//                          lives at shared_counts[e]); kept for signature
//                          compatibility.
//   block_size             Padding granularity.
//   numel                  Number of entries in `topk_ids`.
//   cumsum                 Output, num_experts + 1 entries: exclusive prefix sum
//                          of the padded per-expert counts.
//
// Preconditions (not checked here): block_size > 0 and every value in
// `topk_ids` lies in [0, num_experts).
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ expert_ids,
    int32_t* __restrict__ total_tokens_post_pad,
    int32_t num_experts,
    int32_t experts_per_warp,
    int32_t block_size,
    size_t numel,
    int32_t* __restrict__ cumsum) {
  extern __shared__ int32_t shared_counts[];

  const int32_t lane = static_cast<int32_t>(threadIdx.x);
  const int32_t num_threads = static_cast<int32_t>(blockDim.x);

  // Phase 1: zero the per-expert counters. The block-stride loop covers all
  // experts regardless of how many threads the block has.
  for (int32_t e = lane; e < num_experts; e += num_threads) {
    shared_counts[e] = 0;
  }

  __syncthreads();

  // Phase 2: histogram of expert ids. Adjacent threads read adjacent elements
  // of topk_ids, and the size_t index avoids narrowing for large numel.
  for (size_t i = threadIdx.x; i < numel; i += blockDim.x) {
    const int32_t expert_id = topk_ids[i];
    atomicAdd(&shared_counts[expert_id], 1);
  }

  __syncthreads();

  // Phase 3: a single thread builds the exclusive prefix sum of the padded
  // counts. The running total is kept in a register instead of re-reading
  // cumsum[] from global memory on every step.
  if (threadIdx.x == 0) {
    int32_t running = 0;
    cumsum[0] = 0;
    for (int32_t e = 0; e < num_experts; ++e) {
      running += CEILDIV(shared_counts[e], block_size) * block_size;
      cumsum[e + 1] = running;
    }
    *total_tokens_post_pad = running;
  }

  // Makes thread 0's writes to cumsum[] visible to the rest of the block.
  __syncthreads();

  // Phase 4: label every block_size-wide slice with its owning expert. The
  // block-stride loop also covers experts whose id is >= blockDim.x.
  for (int32_t e = lane; e < num_experts; e += num_threads) {
    const int32_t range_end = cumsum[e + 1];
    for (int32_t pos = cumsum[e]; pos < range_end; pos += block_size) {
      expert_ids[pos / block_size] = e;
    }
  }
}

101
102
103
104
105
106
107
108
109
// Host entry point: launches the two kernels of this file on the current CUDA
// stream.
//
//   1. moe_align_block_size_kernel (one block of 1024 threads) fills
//      `cumsum_buffer`, `experts_ids` and `num_tokens_post_pad`.
//   2. count_and_sort_expert_tokens_kernel then scatters token indices into
//      `sorted_token_ids`, using `cumsum_buffer` as per-expert write cursors
//      (it is modified in place).
//
// Parameters:
//   topk_ids             Integral tensor of expert ids; its dtype selects the
//                        kernel instantiation.
//   num_experts          Must be a positive multiple of WARP_SIZE.
//   block_size           Padding granularity; must be positive.
//   sorted_token_ids     int32 output tensor.
//   experts_ids          int32 output tensor.
//   num_tokens_post_pad  int32 output tensor (first element is written).
//   token_cnts_buffer    Not used by this implementation; kept for interface
//                        compatibility.
//   cumsum_buffer        int32 scratch tensor; the align kernel writes
//                        num_experts + 1 entries.
//
// Raises (via TORCH_CHECK / C10_CUDA_KERNEL_LAUNCH_CHECK) on invalid arguments
// or when a kernel launch is rejected. Both launches are asynchronous with
// respect to the host.
void moe_align_block_size(
    torch::Tensor topk_ids,
    int64_t num_experts,
    int64_t block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
    torch::Tensor token_cnts_buffer,
    torch::Tensor cumsum_buffer) {
  TORCH_CHECK(
      num_experts > 0 && num_experts % WARP_SIZE == 0,
      "moe_align_block_size: num_experts must be a positive multiple of ",
      WARP_SIZE,
      ", got ",
      num_experts);
  // The align kernel divides by block_size on the device.
  TORCH_CHECK(block_size > 0, "moe_align_block_size: block_size must be positive, got ", block_size);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int experts_per_warp = static_cast<int>(num_experts / WARP_SIZE);
  const int64_t numel = topk_ids.numel();

  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
    auto align_kernel = moe_align_block_size_kernel<scalar_t>;

    // One int32 counter per expert (equals 32 * experts_per_warp because
    // num_experts is a multiple of 32).
    const size_t shared_mem_size = static_cast<size_t>(num_experts) * sizeof(int32_t);
    align_kernel<<<1, 1024, shared_mem_size, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        static_cast<int32_t>(num_experts),
        experts_per_warp,
        static_cast<int32_t>(block_size),
        numel,
        cumsum_buffer.data_ptr<int32_t>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // An empty input has nothing to scatter; launching with a grid of zero
    // blocks would be rejected as an invalid configuration.
    if (numel > 0) {
      constexpr int block_threads = 256;
      constexpr int64_t max_blocks = 65535;
      // Clamp in 64-bit before narrowing so a very large numel cannot wrap.
      const int64_t num_blocks = (numel + block_threads - 1) / block_threads;
      const int actual_blocks = static_cast<int>(std::min(num_blocks, max_blocks));

      auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
      sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
          topk_ids.data_ptr<scalar_t>(),
          sorted_token_ids.data_ptr<int32_t>(),
          cumsum_buffer.data_ptr<int32_t>(),
          numel);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });
}