#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <THC/THCAtomics.cuh>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

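// Pre-reorder kernel: scatter each token's hidden vector into the per-expert
// gate/up input buffer, applying an optional input quantization scale.
// Grid: one block per token; src2dst maps (token, k) -> destination row.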
template <typename scalar_t>
__global__ void ep_pre_reorder_cuda_kernel(
    const scalar_t* __restrict__ input_ptr,
    scalar_t* __restrict__ gateup_input_ptr,
    const int* __restrict__ src2dst_ptr,
    const int* __restrict__ topk_ids_ptr,
    const float* __restrict__ a1_scales_ptr,
    int start_expert_id,
    int end_expert_id,
    int topk,
    int hidden_size,
    bool use_per_token_if_dynamic) {
  int token_idx = blockIdx.x;
  int tid = threadIdx.x;

  const scalar_t* src_ptr = input_ptr + int64_t(token_idx) * hidden_size;
  const int* token_src2dst = src2dst_ptr + token_idx * topk;
  const int* token_topk_ids = topk_ids_ptr + token_idx * topk;

  float scale = 1.0f;

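  // Dynamic per-token quantization: a single scale covers the whole token row.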
  if (a1_scales_ptr != nullptr && use_per_token_if_dynamic) {
    scale = 1.0f / a1_scales_ptr[token_idx];
  }

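  // Write this token's row once per selected expert that lives on this rank.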
  for (int k = 0; k < topk; ++k) {
    int expert_id = token_topk_ids[k];
    if (expert_id < start_expert_id || expert_id > end_expert_id) continue;

    // Static (per-expert) quantization: look up the scale for this expert.
    if (a1_scales_ptr != nullptr && !use_per_token_if_dynamic) {
      scale = 1.0f / a1_scales_ptr[expert_id - start_expert_id];
    }

    int dst_idx = token_src2dst[k];
    scalar_t* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size;

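    // Copy the row in 16-byte vectors; the remainder is handled by the scalar tail.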
    constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
    using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

    const int vec_iters = hidden_size / vec_size;
    const int vec_elements = vec_iters * vec_size;
    for (int idx = tid; idx < vec_iters; idx += blockDim.x) {
      vec_t input_vec, output_vec;
      input_vec.cast_load(src_ptr + idx * vec_size);
#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        float val = static_cast<float>(input_vec[i]);
        output_vec[i] = static_cast<scalar_t>(val * scale);
      }
      output_vec.cast_store(dst_ptr + idx * vec_size);
    }

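    // Scalar tail for hidden sizes not divisible by vec_size.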
    for (int idx = vec_elements + tid; idx < hidden_size; idx += blockDim.x) {
      float val = static_cast<float>(src_ptr[idx]);
      dst_ptr[idx] = static_cast<scalar_t>(val * scale);
    }
  }
}

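// Post-reorder kernel: gather the down-projection outputs back into token
// order, accumulating a topk-weighted sum over the experts local to this rank.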
template <typename scalar_t>
__global__ void ep_post_reorder_cuda_kernel(
    const scalar_t* __restrict__ down_output_ptr,
    scalar_t* __restrict__ output_ptr,
    const int* __restrict__ src2dst_ptr,
    const int* __restrict__ topk_ids_ptr,
    const scalar_t* __restrict__ topk_weights_ptr,
    int start_expert_id,
    int end_expert_id,
    int topk,
    int hidden_size) {
  const int token_idx = blockIdx.x;
  const int tid = threadIdx.x;

  const int* token_src2dst = src2dst_ptr + token_idx * topk;
  const int* token_topk_ids = topk_ids_ptr + token_idx * topk;
  const scalar_t* token_topk_weights = topk_weights_ptr + token_idx * topk;

  scalar_t* dst_ptr = output_ptr + static_cast<int64_t>(token_idx) * hidden_size;

  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

  const int vec_iters = hidden_size / vec_size;
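  // Each thread owns a strided set of 16-byte vector chunks of the output row.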
  for (int idx = tid; idx < vec_iters; idx += blockDim.x) {
    float acc[vec_size] = {0};

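    // Accumulate weighted contributions from each selected expert on this rank.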
    for (int k = 0; k < topk; ++k) {
      const int expert_id = token_topk_ids[k];
      if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
      const int src_row = token_src2dst[k];
      const scalar_t* src_ptr = down_output_ptr + static_cast<int64_t>(src_row) * hidden_size;
      const float weight = static_cast<float>(token_topk_weights[k]);

      vec_t src_vec;
      src_vec.cast_load(src_ptr + idx * vec_size);

#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        acc[i] += static_cast<float>(src_vec[i]) * weight;
      }
    }
    vec_t out_vec;
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i)
      out_vec[i] = static_cast<scalar_t>(acc[i]);

    out_vec.cast_store(dst_ptr + idx * vec_size);
  }
}

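// Host launcher: one 512-thread block per input token.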
void ep_moe_pre_reorder(
    torch::Tensor input,
    torch::Tensor gateup_input,
    torch::Tensor src2dst,
    torch::Tensor topk_ids,
    torch::Tensor a1_scales,
    int64_t start_expert_id,
    int64_t end_expert_id,
    int64_t topk,
    bool use_per_token_if_dynamic) {
  const int total_blocks = input.size(0);
  const int block_size = 512;
  dim3 grid(total_blocks);
  dim3 block(block_size);
  const int hidden_size = input.size(1);
  // Launch on PyTorch's current CUDA stream rather than the legacy default stream.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    ep_pre_reorder_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<scalar_t*>(gateup_input.data_ptr()),
        src2dst.data_ptr<int>(),
        topk_ids.data_ptr<int>(),
        a1_scales.defined() ? a1_scales.data_ptr<float>() : nullptr,
        static_cast<int>(start_expert_id),
        static_cast<int>(end_expert_id),
        static_cast<int>(topk),
        hidden_size,
        use_per_token_if_dynamic);
    return true;
  });
}

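// Host launcher: one 512-thread block per output token.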
void ep_moe_post_reorder(
    torch::Tensor down_output,
    torch::Tensor output,
    torch::Tensor src2dst,
    torch::Tensor topk_ids,
    torch::Tensor topk_weights,
    int64_t start_expert_id,
    int64_t end_expert_id,
    int64_t topk) {
  const int total_tokens = output.size(0);
  const int block_size = 512;
  dim3 grid(total_tokens);
  dim3 block(block_size);
  const int hidden_size = output.size(1);
  // Launch on PyTorch's current CUDA stream rather than the legacy default stream.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(down_output.scalar_type(), scalar_t, [&] {
    ep_post_reorder_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(down_output.data_ptr()),
        static_cast<scalar_t*>(output.data_ptr()),
        src2dst.data_ptr<int>(),
        topk_ids.data_ptr<int>(),
        static_cast<scalar_t*>(topk_weights.data_ptr()),
        static_cast<int>(start_expert_id),
        static_cast<int>(end_expert_id),
        static_cast<int>(topk),
        hidden_size);
    return true;
  });
}