"configs/_base_/datasets/refcoco.py" did not exist on "1189a8adcc55e3a2530ce8b716c55a67b2d53af1"
ep_moe_reorder_kernel.cu 2.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <cstdint>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <THC/THCAtomics.cuh>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

// Scatters each token's activation row into the expert-parallel input buffer.
//
// Launch layout: grid.x == number of tokens (one block per token); any
// blockDim.x works, since threads of a block stride over the hidden dimension.
// No shared memory. All pointers refer to contiguous row-major buffers.
//
// For token t and each k in [0, topk):
//   expert = topk_ids[t, k]; experts outside [start_expert_id, end_expert_id]
//   (inclusive bounds) are skipped. Otherwise row t of `input` is written to
//   row src2dst[t, k] of `gateup_input`, multiplied by `scale`, where
//     scale = 1                                         if a1_scales_ptr == nullptr
//     scale = 1 / a1_scales[t]                          if use_per_token_if_dynamic
//     scale = 1 / a1_scales[expert - start_expert_id]   otherwise.
//
// 16-byte vector loads/stores are used only when both the source and the
// destination row are 16-byte aligned. Remaining elements (hidden_size not a
// multiple of 4, or misaligned rows) are copied with a scalar loop, so every
// element of the row is written for any hidden_size.
__global__ void ep_pre_reorder_cuda_kernel(
    const float* __restrict__ input_ptr,
    float* __restrict__ gateup_input_ptr,
    const int* __restrict__ src2dst_ptr,
    const int* __restrict__ topk_ids_ptr,
    const float* __restrict__ a1_scales_ptr,
    int start_expert_id,
    int end_expert_id,
    int topk,
    int hidden_size,
    bool use_per_token_if_dynamic) {
  const int token_idx = blockIdx.x;
  const int tid = threadIdx.x;

  // 64-bit offsets so large (tokens x hidden) buffers do not overflow int.
  const float* src_ptr = input_ptr + int64_t(token_idx) * hidden_size;
  const int* token_src2dst = src2dst_ptr + int64_t(token_idx) * topk;
  const int* token_topk_ids = topk_ids_ptr + int64_t(token_idx) * topk;

  constexpr int kVecSize = 16 / sizeof(float);
  using vec_t = flashinfer::vec_t<float, kVecSize>;

  for (int k = 0; k < topk; ++k) {
    const int expert_id = token_topk_ids[k];
    // Only experts owned by this partition are materialized (inclusive range).
    if (expert_id < start_expert_id || expert_id > end_expert_id) continue;

    float scale = 1.0f;
    if (a1_scales_ptr != nullptr) {
      const float s = use_per_token_if_dynamic ? a1_scales_ptr[token_idx]
                                               : a1_scales_ptr[expert_id - start_expert_id];
      scale = 1.0f / s;
    }

    const int dst_idx = token_src2dst[k];
    float* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size;

    // Vector access is legal only on 16-byte-aligned addresses. When
    // hidden_size % 4 != 0, rows after the first are not aligned, so fall
    // back to the scalar path for those rows.
    const bool rows_aligned =
        ((reinterpret_cast<uintptr_t>(src_ptr) | reinterpret_cast<uintptr_t>(dst_ptr)) % 16) == 0;
    const int num_vecs = rows_aligned ? hidden_size / kVecSize : 0;

    for (int v = tid; v < num_vecs; v += blockDim.x) {
      vec_t in_vec, out_vec;
      in_vec.cast_load(src_ptr + v * kVecSize);
#pragma unroll
      for (int i = 0; i < kVecSize; ++i) {
        out_vec[i] = static_cast<float>(in_vec[i]) * scale;
      }
      out_vec.cast_store(dst_ptr + v * kVecSize);
    }

    // Scalar tail: the elements not covered by whole vectors (or the entire
    // row when the rows are not aligned).
    for (int j = num_vecs * kVecSize + tid; j < hidden_size; j += blockDim.x) {
      dst_ptr[j] = src_ptr[j] * scale;
    }
  }
}

// Host launcher for ep_pre_reorder_cuda_kernel.
//
// Launches one 512-thread block per row of `input` (one per token). The
// launch runs on `input`'s device and on that device's current PyTorch CUDA
// stream, so it is correctly ordered with surrounding torch ops.
//
// Parameters:
//   input           float32 [num_tokens, hidden_size], contiguous.
//   gateup_input    float32 output buffer with rows of hidden_size, contiguous;
//                   rows are addressed by the values in src2dst.
//   src2dst         int32, contiguous; entry [t * topk + k] is the destination row.
//   topk_ids        int32, contiguous; entry [t * topk + k] is the expert id.
//   a1_scales       float32, contiguous, or an undefined tensor for "no scaling".
//   start_expert_id / end_expert_id  inclusive range of experts handled here.
//   topk            number of routed experts per token.
//   use_per_token_if_dynamic  index a1_scales by token instead of by local expert.
// Dtypes are enforced by data_ptr<T>(); contiguity is checked explicitly,
// because the kernel indexes raw pointers as dense row-major arrays.
void ep_moe_pre_reorder(
    torch::Tensor input,
    torch::Tensor gateup_input,
    torch::Tensor src2dst,
    torch::Tensor topk_ids,
    torch::Tensor a1_scales,
    int64_t start_expert_id,
    int64_t end_expert_id,
    int64_t topk,
    bool use_per_token_if_dynamic) {
  TORCH_CHECK(input.is_contiguous(), "ep_moe_pre_reorder: input must be contiguous");
  TORCH_CHECK(gateup_input.is_contiguous(), "ep_moe_pre_reorder: gateup_input must be contiguous");
  TORCH_CHECK(src2dst.is_contiguous(), "ep_moe_pre_reorder: src2dst must be contiguous");
  TORCH_CHECK(topk_ids.is_contiguous(), "ep_moe_pre_reorder: topk_ids must be contiguous");
  TORCH_CHECK(
      !a1_scales.defined() || a1_scales.is_contiguous(), "ep_moe_pre_reorder: a1_scales must be contiguous");

  const int64_t num_tokens = input.size(0);
  const int hidden_size = static_cast<int>(input.size(1));

  // A grid with zero blocks is an invalid launch configuration; nothing to do.
  if (num_tokens == 0) return;

  // Run on the input's device and on the stream PyTorch is currently using.
  const c10::cuda::CUDAGuard device_guard(input.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  constexpr int block_size = 512;
  const dim3 grid(static_cast<unsigned int>(num_tokens));
  const dim3 block(block_size);

  ep_pre_reorder_cuda_kernel<<<grid, block, 0, stream>>>(
      input.data_ptr<float>(),
      gateup_input.data_ptr<float>(),
      src2dst.data_ptr<int>(),
      topk_ids.data_ptr<int>(),
      a1_scales.defined() ? a1_scales.data_ptr<float>() : nullptr,
      static_cast<int>(start_expert_id),
      static_cast<int>(end_expert_id),
      static_cast<int>(topk),
      hidden_size,
      use_per_token_if_dynamic);
  // Surface launch-configuration errors immediately instead of at a later sync.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}