/*
 * Adapted from SGLang's sgl-kernel implementation, which was adapted from
 * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
 * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
 *
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <stdexcept>

#include "core/registration.h"
#include "dsv3_router_gemm_utils.h"

static constexpr int DEFAULT_NUM_EXPERTS = 256;
static constexpr int KIMI_K2_NUM_EXPERTS = 384;
static constexpr int DEFAULT_HIDDEN_DIM = 7168;

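// The op computes router logits: output = mat_a @ mat_b^T, where mat_a holds
// bf16 token activations of shape [num_tokens, hidden_dim] and mat_b holds
// the bf16 router weight of shape [num_experts, hidden_dim]. Each supported
// shape maps to one of the specialized kernels declared below.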
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b,
                                 cudaStream_t stream);

template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a,
                                T const* mat_b, cudaStream_t stream);

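// LoopUnroller maps the runtime num_tokens onto a compile-time constant: it
// recursively compares num_tokens against every value in [kBegin, kEnd] and
// invokes the matching kernel instantiation, so each supported batch size
// runs with statically known dimensions.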
template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
  static void unroll_float_output(int num_tokens, float* output,
                                  __nv_bfloat16 const* input,
                                  __nv_bfloat16 const* weights,
                                  cudaStream_t stream) {
    if (num_tokens == kBegin) {
      invokeRouterGemmFloatOutput<__nv_bfloat16, kBegin, kNumExperts,
                                  kHiddenDim>(output, input, weights, stream);
    } else {
      LoopUnroller<kBegin + 1, kEnd, kNumExperts,
                   kHiddenDim>::unroll_float_output(num_tokens, output, input,
                                                    weights, stream);
    }
  }

  static void unroll_bf16_output(int num_tokens, __nv_bfloat16* output,
                                 __nv_bfloat16 const* input,
                                 __nv_bfloat16 const* weights,
                                 cudaStream_t stream) {
    if (num_tokens == kBegin) {
      invokeRouterGemmBf16Output<__nv_bfloat16, kBegin, kNumExperts,
                                 kHiddenDim>(output, input, weights, stream);
    } else {
      LoopUnroller<kBegin + 1, kEnd, kNumExperts,
                   kHiddenDim>::unroll_bf16_output(num_tokens, output, input,
                                                   weights, stream);
    }
  }
};

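// Recursion base case (kBegin == kEnd): if num_tokens still does not match,
// it falls outside the supported range and we fail loudly.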
template <int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
  static void unroll_float_output(int num_tokens, float* output,
                                  __nv_bfloat16 const* input,
                                  __nv_bfloat16 const* weights,
                                  cudaStream_t stream) {
    if (num_tokens == kEnd) {
      invokeRouterGemmFloatOutput<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(
          output, input, weights, stream);
    } else {
      throw std::invalid_argument(
          "Invalid num_tokens: only values 1 to 16 are supported");
    }
  }

  static void unroll_bf16_output(int num_tokens, __nv_bfloat16* output,
                                 __nv_bfloat16 const* input,
                                 __nv_bfloat16 const* weights,
                                 cudaStream_t stream) {
    if (num_tokens == kEnd) {
      invokeRouterGemmBf16Output<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(
          output, input, weights, stream);
    } else {
      throw std::invalid_argument(
          "Invalid num_tokens: only values 1 to 16 are supported");
    }
  }
};

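// PyTorch-facing entry point: validates shapes, dtypes, and SM version, then
// dispatches to the shape-specialized kernel for the requested output dtype
// and expert count.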
void dsv3_router_gemm(at::Tensor& output,       // [num_tokens, num_experts]
                      const at::Tensor& mat_a,  // [num_tokens, hidden_dim]
                      const at::Tensor& mat_b   // [num_experts, hidden_dim]
) {
  TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2,
              "output, mat_a and mat_b must all be 2-D tensors");

  const int num_tokens = mat_a.size(0);
  const int num_experts = mat_b.size(0);
  const int hidden_dim = mat_a.size(1);

  TORCH_CHECK(mat_a.size(1) == mat_b.size(1),
              "mat_a and mat_b must have the same hidden_dim");
  TORCH_CHECK(hidden_dim == DEFAULT_HIDDEN_DIM,
              "Expected hidden_dim=", DEFAULT_HIDDEN_DIM,
              ", but got hidden_dim=", hidden_dim);
  TORCH_CHECK(
      num_experts == DEFAULT_NUM_EXPERTS || num_experts == KIMI_K2_NUM_EXPERTS,
      "Expected num_experts=", DEFAULT_NUM_EXPERTS,
      " or num_experts=", KIMI_K2_NUM_EXPERTS,
      ", but got num_experts=", num_experts);
  TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16,
              "Expected 1 <= num_tokens <= 16 for router_gemm, but got "
              "num_tokens=", num_tokens);
  TORCH_CHECK(mat_a.dtype() == at::kBFloat16, "mat_a must be bf16");
  TORCH_CHECK(mat_b.dtype() == at::kBFloat16, "mat_b must be bf16");
  TORCH_CHECK(output.dtype() == at::kFloat || output.dtype() == at::kBFloat16,
              "output must be float32 or bf16");

  auto const sm = getSMVersion();
  TORCH_CHECK(sm >= 90 && sm <= 103,
              "Expected SM90 <= CUDA arch <= SM103, but got SM", sm);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

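  // num_experts was validated above, so exactly one branch below fires for
  // each output dtype.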
  if (output.dtype() == at::kFloat) {
    if (num_experts == DEFAULT_NUM_EXPERTS) {
      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_float_output(
              num_tokens, reinterpret_cast<float*>(output.mutable_data_ptr()),
              reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
              reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream);
    } else if (num_experts == KIMI_K2_NUM_EXPERTS) {
      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_float_output(
              num_tokens, reinterpret_cast<float*>(output.mutable_data_ptr()),
              reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
              reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream);
    }
  } else if (output.dtype() == at::kBFloat16) {
    if (num_experts == DEFAULT_NUM_EXPERTS) {
      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_bf16_output(
              num_tokens,
              reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
              reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
              reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream);
    } else if (num_experts == KIMI_K2_NUM_EXPERTS) {
      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_bf16_output(
              num_tokens,
              reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
              reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
              reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream);
    }
  }
}

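// A minimal host-side sketch of a call that satisfies the checks above
// (illustrative only; the tensor construction here is an assumption, not
// part of this extension):
//
//   auto opts = at::device(at::kCUDA).dtype(at::kBFloat16);
//   at::Tensor mat_a = at::randn({4, 7168}, opts);    // token activations
//   at::Tensor mat_b = at::randn({256, 7168}, opts);  // router weight
//   at::Tensor out = at::empty({4, 256}, opts.dtype(at::kFloat));
//   dsv3_router_gemm(out, mat_a, mat_b);  // out = mat_a @ mat_b^T as fp32
//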
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("dsv3_router_gemm", &dsv3_router_gemm);
}