ep_moe_reorder_kernel.cu

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <THC/THCAtomics.cuh>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"
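
// ep_pre_reorder_cuda_kernel scatters each token's hidden states into the
// expert-grouped gateup_input buffer: one copy per (token, selected expert)
// pair owned by this rank, optionally multiplied by the reciprocal of the
// per-token or per-expert activation scale.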

template <typename scalar_t>
__global__ void ep_pre_reorder_cuda_kernel(
    const scalar_t* __restrict__ input_ptr,
    scalar_t* __restrict__ gateup_input_ptr,
    const int* __restrict__ src2dst_ptr,
    const int* __restrict__ topk_ids_ptr,
    const float* __restrict__ a1_scales_ptr,
    int start_expert_id,
    int end_expert_id,
    int topk,
    int hidden_size,
    bool use_per_token_if_dynamic) {
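  // One thread block per source token; threads stride over hidden_size.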
  int token_idx = blockIdx.x;
  int tid = threadIdx.x;

  const scalar_t* src_ptr = input_ptr + int64_t(token_idx) * hidden_size;
  const int* token_src2dst = src2dst_ptr + token_idx * topk;
  const int* token_topk_ids = topk_ids_ptr + token_idx * topk;

  float scale = 1.0f;

  if (a1_scales_ptr != nullptr && use_per_token_if_dynamic) {
    // Per-token dynamic quantization: one reciprocal scale covers the whole token.
    scale = 1.0f / a1_scales_ptr[token_idx];
  }

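  // Write one copy of the token per selected expert; experts outside this
  // rank's range [start_expert_id, end_expert_id] are skipped.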
  for (int k = 0; k < topk; ++k) {
    int expert_id = token_topk_ids[k];
    if (expert_id < start_expert_id || expert_id > end_expert_id) continue;

    // Per-expert (static) quantization: index the scale by the rank-local expert id.
    if (a1_scales_ptr != nullptr && !use_per_token_if_dynamic) {
      scale = 1.0f / a1_scales_ptr[expert_id - start_expert_id];
    }

    int dst_idx = token_src2dst[k];
    scalar_t* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size;

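    // Vectorized path: 16 bytes per thread per iteration via flashinfer::vec_t,
    // scaling each element as it is copied.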
    constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
    using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

    int vec_elements = (hidden_size / vec_size) * vec_size;
    for (int idx = tid; idx < hidden_size / vec_size; idx += blockDim.x) {
      vec_t input_vec, output_vec;
      input_vec.cast_load(src_ptr + idx * vec_size);
#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        float val = static_cast<float>(input_vec[i]);
        output_vec[i] = static_cast<scalar_t>(val * scale);
      }
      output_vec.cast_store(dst_ptr + idx * vec_size);
    }

    // Scalar tail for the remaining hidden_size % vec_size elements.
    for (int idx = vec_elements + tid; idx < hidden_size; idx += blockDim.x) {
      float val = static_cast<float>(src_ptr[idx]);
      dst_ptr[idx] = static_cast<scalar_t>(val * scale);
    }
  }
}

void ep_moe_pre_reorder(
    torch::Tensor input,
    torch::Tensor gateup_input,
    torch::Tensor src2dst,
    torch::Tensor topk_ids,
    torch::Tensor a1_scales,
    int64_t start_expert_id,
    int64_t end_expert_id,
    int64_t topk,
    bool use_per_token_if_dynamic) {
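  // One CUDA block per input token row; 512 threads cooperate over hidden_size.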
  int total_blocks = input.size(0);
  int block_size = 512;
  dim3 grid(total_blocks);
  dim3 block(block_size);
  int hidden_size = input.size(1);

  // Run on the current CUDA stream of input's device rather than the default stream.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    ep_pre_reorder_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<scalar_t*>(gateup_input.data_ptr()),
        src2dst.data_ptr<int>(),
        topk_ids.data_ptr<int>(),
        a1_scales.defined() ? a1_scales.data_ptr<float>() : nullptr,
        start_expert_id,
        end_expert_id,
        topk,
        hidden_size,
        use_per_token_if_dynamic);
    return true;
  });
}
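
// Illustrative call sketch: the tensor names and shapes below are assumptions
// inferred from the indexing above, not definitions from this file.
//
//   input:        [num_tokens, hidden_size]        activations (dtype handled by
//                                                  the dispatch macro)
//   gateup_input: [num_tokens * topk, hidden_size] scatter destination, same dtype
//   src2dst:      [num_tokens, topk]               int32 destination row per expert
//   topk_ids:     [num_tokens, topk]               int32 global expert ids
//   a1_scales:    [num_tokens] for per-token dynamic quant, otherwise one entry
//                 per expert on this rank; float32, may be an undefined tensor
//
//   ep_moe_pre_reorder(input, gateup_input, src2dst, topk_ids, a1_scales,
//                      start_expert_id, end_expert_id, topk,
//                      /*use_per_token_if_dynamic=*/true);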