// moe_align_sum_kernels.cu
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <ATen/ATen.h>
// NOTE(review): nothing visible in this file uses THC atomics, and THC headers
// were removed from newer PyTorch releases; confirm whether this is still needed.
#include <THC/THCAtomics.cuh>

#include <algorithm>
#include <type_traits>

#include "../cuda_compat.h"
#include "../dispatch_utils.h"

// Integer ceiling division ceil(x / y), valid for non-negative x and positive y.
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

// Dynamic shared-memory budget in bytes for moe_align_block_size_kernel. When
// its tables would need more than this, the per-thread count table is kept in
// global memory instead. Parenthesised so the macro expands correctly inside
// larger expressions (e.g. `x / MAX_SHARED_MEM_SIZE`).
#define MAX_SHARED_MEM_SIZE (64 * 1024)

namespace vllm {
namespace moe {

namespace {

// Row-major offset of element (row, col) in a 2-D int32 table that has
// `total_col` columns.
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  // don't worry about overflow because num_experts is relatively small
  return row * total_col + col;
}

}  // namespace

/**
 * Groups the flattened (token, expert) assignments in `topk_ids` by expert and
 * pads every expert's group to a multiple of `block_size`.
 *
 * Launch layout: a single thread block; only threadIdx.x / blockDim.x are
 * used. blockDim.x must be >= num_experts, because thread e alone computes the
 * running counts and the expert_ids entries of expert e.
 *
 * Scratch tables:
 *   tokens_cnts  (blockDim.x + 1) x num_experts int32 values. Row t + 1 first
 *                counts how many entries of thread t's shard go to each expert;
 *                after the column-wise prefix sum, row t holds the position of
 *                thread t's first entry inside each expert's group.
 *   cumsum       num_experts + 1 int32 padded group start offsets.
 * If `experts_num_exceed_limit` is false, both tables live in dynamic shared
 * memory, which must hold ((blockDim.x + 1) * num_experts + num_experts + 1)
 * int32 values. If it is true, tokens_cnts is the caller-provided global
 * buffer `global_tokens_cnts_ptr` (at least (blockDim.x + 1) * num_experts
 * int32 values; every entry is written before it is read, so it need not be
 * initialised). In that case dynamic shared memory only holds cumsum.
 *
 * Outputs:
 *   sorted_token_ids[p]     flat index into topk_ids of the entry placed at
 *                           padded position p. Padding slots are not written
 *                           and keep whatever value the caller preset.
 *   expert_ids[b]           expert that owns block b of sorted_token_ids.
 *   *total_tokens_post_pad  total padded length (cumsum[num_experts]).
 *
 * Every value of topk_ids is used directly as an expert index and is assumed
 * to lie in [0, num_experts); out-of-range ids are not checked.
 */
template <typename scalar_t, bool experts_num_exceed_limit>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                            int32_t* sorted_token_ids,
                                            int32_t* expert_ids,
                                            int32_t* total_tokens_post_pad,
                                            int32_t* global_tokens_cnts_ptr,
                                            int32_t num_experts,
                                            int32_t block_size, size_t numel) {
  // Each thread owns the contiguous shard [start_idx, end_idx) of topk_ids.
  // size_t indices avoid the signed/unsigned mix (and int overflow) of an
  // `int` loop variable compared against `numel`.
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;
  const size_t shard_end = start_idx + tokens_per_thread;
  const size_t end_idx = shard_end < numel ? shard_end : numel;

  extern __shared__ int32_t shared_mem[];

  int32_t* tokens_cnts = nullptr;
  int32_t* cumsum = nullptr;
  if (experts_num_exceed_limit) {
    // 2d table with shape (blockDim.x + 1, num_experts) in global memory.
    tokens_cnts = global_tokens_cnts_ptr;
    // 1d table with shape (num_experts + 1) in shared memory.
    cumsum = shared_mem;
  } else {
    // Both tables in shared memory, tokens_cnts first, cumsum right after it.
    tokens_cnts = shared_mem;
    cumsum = shared_mem + (blockDim.x + 1) * num_experts;
  }

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index.
   */
  for (size_t i = start_idx; i < end_idx; ++i) {
    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads,
  // so row t ends up holding the offset of thread t's tokens within expert e.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the padded token counts of all experts in thread 0; the last
  // row of tokens_cnts holds each expert's total.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                          block_size) *
                      block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids =
   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
   * padding value(preset in python).
   */
  for (size_t i = start_idx; i < end_idx; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = static_cast<int32_t>(i);
    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
  }
}

/**
 * Sums the TOPK expert outputs of every token:
 *   out[t, j] = sum_k input[t, k, j]   for j in [0, d).
 *
 * Launch layout: one block per token (blockIdx.x is the token index); the
 * threads of the block stride over the hidden dimension `d`.
 *
 * The sum is accumulated in float (double for double inputs) and rounded to
 * scalar_t once at the end. Half/bfloat16 inputs are therefore not re-rounded
 * after every addition, which matches the float accumulation of the
 * `at::sum_out` fallback used for other top-k values.
 */
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., topk, d]
    const int d) {
  using acc_t =
      typename std::conditional<std::is_same<scalar_t, double>::value, double,
                                float>::type;
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    acc_t x = 0;
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {
      x += static_cast<acc_t>(
          VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]));
    }
    out[token_idx * d + idx] = static_cast<scalar_t>(x);
  }
}

}  // namespace moe
}  // namespace vllm

// Host launcher for vllm::moe::moe_align_block_size_kernel.
//
//   topk_ids             flattened (token, expert) assignments, integral dtype.
//   num_experts          number of experts. The kernel runs one thread per
//                        expert, so this must not exceed the device's
//                        maxThreadsPerBlock.
//   block_size           alignment unit for each expert's group of tokens.
//   sorted_token_ids     int32 output; padding slots are expected to be preset
//                        by the caller.
//   experts_ids          int32 output, one expert id per block.
//   num_tokens_post_pad  int32 output receiving the total padded length.
//
// The kernel's count table needs (num_experts + 1) * num_experts int32 values.
// When that table plus cumsum would exceed MAX_SHARED_MEM_SIZE bytes, the
// table is placed in a scratch device tensor and only cumsum uses shared
// memory.
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad) {
  // Allocate and launch on topk_ids' device rather than the current device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int max_threads =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
  TORCH_CHECK(num_experts > 0 && num_experts <= max_threads,
              "moe_align_block_size: num_experts (", num_experts,
              ") must be in [1, ", max_threads,
              "] because the kernel runs one thread per expert");

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
        constexpr int32_t kInt32Bytes = static_cast<int32_t>(sizeof(int32_t));
        // blockDim.x of the launch below. The kernel sizes its tables from
        // blockDim.x, so the memory computed here must use the same value.
        const int32_t num_thread = static_cast<int32_t>(num_experts);
        // tokens_cnts: (num_thread + 1) rows x num_experts columns.
        const int32_t tokens_cnts_numel = (num_thread + 1) * num_thread;
        // cumsum: num_experts + 1 entries.
        const int32_t cumsum_bytes = (num_thread + 1) * kInt32Bytes;
        const int32_t shared_mem_normal =
            tokens_cnts_numel * kInt32Bytes + cumsum_bytes;

        const bool experts_num_exceed_limit =
            shared_mem_normal > MAX_SHARED_MEM_SIZE;
        // When tokens_cnts moves to global memory only cumsum stays shared.
        const int32_t shared_mem =
            experts_num_exceed_limit ? cumsum_bytes : shared_mem_normal;

        if (experts_num_exceed_limit) {
          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t, true>;
          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
              (void*)kernel, shared_mem));

          // Scratch tokens_cnts allocated directly on the device: no host
          // stack array and no blocking host-to-device copy. The kernel writes
          // every entry before reading it, so it may stay uninitialised.
          torch::Tensor tokens_cnts_buffer =
              torch::empty({static_cast<int64_t>(tokens_cnts_numel)},
                           topk_ids.options().dtype(torch::kInt32));

          kernel<<<1, num_thread, shared_mem, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(),
              tokens_cnts_buffer.data_ptr<int32_t>(), num_thread, block_size,
              topk_ids.numel());
        } else {
          auto kernel =
              vllm::moe::moe_align_block_size_kernel<scalar_t, false>;
          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
              (void*)kernel, shared_mem));

          kernel<<<1, num_thread, shared_mem, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), nullptr, num_thread,
              block_size, topk_ids.numel());
        }
        // Surface launch-configuration errors here instead of at some later,
        // unrelated CUDA call.
        AT_CUDA_CHECK(cudaGetLastError());
      });
}

// Sums the top-k expert outputs per token: output[t, :] = sum_k input[t, k, :].
// topk values 2, 3 and 4 use vllm::moe::moe_sum_kernel; any other value falls
// back to at::sum_out over dim 1.
void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
             torch::Tensor& output)  // [num_tokens, hidden_size]
{
  // Nothing to sum. Returning here also avoids the integer division by a zero
  // hidden_size below and a zero-sized grid launch, which CUDA rejects as an
  // invalid configuration.
  if (output.numel() == 0) {
    return;
  }

  const int hidden_size = input.size(-1);
  const int num_tokens = output.numel() / hidden_size;
  const int topk = input.size(1);

  // One block per token; the block's threads stride over hidden_size.
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  switch (topk) {
    case 2:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 3:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 4:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    default:
      // Other top-k values use ATen's reduction over the top-k dimension.
      at::sum_out(output, input, 1);
      break;
  }
  // Surface launch-configuration errors here instead of at a later CUDA call.
  AT_CUDA_CHECK(cudaGetLastError());
}