sparse_scaled_mm_c3x.cu 12.8 KB
Newer Older
1
2
3
4
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

5
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on

using namespace cute;
using namespace vllm;

// Dispatches an FP8 (e4m3) sparse scaled GEMM to a CUTLASS 3.x configuration.
//
// The configuration is chosen in two steps:
//   1. m = a.size(0) is rounded up to the next power of two (never below 64)
//      to select an m-bucket: <= 64, <= 128, <= 256, or larger.
//   2. Within that bucket, a few specific values of n = bt_nzs.size(0) use
//      hand-picked configurations (sm90_fp8_config_1..8); every other n uses
//      the bucket's generic sm90_fp8_config_M{64,128,256,512}.
//
// Throws (via TORCH_CHECK) if `a` / `bt_nzs` are not float8_e4m3fn or
// `bt_meta` is not uint8. `args` are forwarded unchanged to the epilogue.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& bt_nzs,
                                    torch::Tensor const& bt_meta,
                                    EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);

  // Generic fallback configuration for each m-bucket.
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM256 =
      typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM512 =
      typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;

  // Configurations hand-picked for specific (m-bucket, n) pairs below.
  using Cutlass3xGemm1 =
      typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm2 =
      typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm3 =
      typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm4 =
      typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm5 =
      typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm6 =
      typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm7 =
      typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm8 =
      typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const n = bt_nzs.size(0);
  uint32_t const m = a.size(0);  // Batch size
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m <= 64
    if (n == 28672) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm2>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
    if (n == 4096 || n == 6144) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm1>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }

  if (mp2 <= 128) {
    // m in (64, 128]
    if (n == 4096) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm3>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
    if (n == 6144) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm4>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
    if (n == 28672) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm5>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }

  if (mp2 <= 256) {
    // m in (128, 256]
    if (n == 4096) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm6>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
    if (n == 6144) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
    if (n == 28672) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }

  // m in (256, inf)
  if (n == 4096) {
    return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }
  if (n == 6144 || n == 28672) {
    return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }
  return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>(
      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}

// Dispatches an FP16 sparse scaled GEMM to a CUTLASS 3.x configuration.
//
// There is no shape-based tuning on this path: every problem size uses
// sm90_config_default.
//
// Throws (via TORCH_CHECK) if `a` / `bt_nzs` are not float16 or `bt_meta` is
// not uint8. `args` are forwarded unchanged to the epilogue.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& bt_nzs,
                                     torch::Tensor const& bt_meta,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::half_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat16);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;

  // Single configuration for all m and n (no m-bucket heuristic here).
  return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}

// Dispatches a BF16 sparse scaled GEMM to a CUTLASS 3.x configuration.
//
// There is no shape-based tuning on this path: every problem size uses
// sm90_config_default.
//
// Throws (via TORCH_CHECK) if `a` / `bt_nzs` are not bfloat16 or `bt_meta` is
// not uint8. `args` are forwarded unchanged to the epilogue.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& bt_nzs,
                                     torch::Tensor const& bt_meta,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::bfloat16_t>());
  TORCH_CHECK(a.dtype() == torch::kBFloat16);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;

  // Single configuration for all m and n (no m-bucket heuristic here).
  return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}

// Dispatches an INT8 sparse scaled GEMM to a CUTLASS 3.x configuration.
//
// m = a.size(0) is rounded up to a power of two (never below 32) and used to
// pick a bucket. For the smallest bucket the choice also depends on whether
// n = out.size(1) is below 8192. Larger buckets ignore n.
//
// Throws (via TORCH_CHECK) if `a` / `bt_nzs` are not int8 or `bt_meta` is not
// uint8. `args` are forwarded unchanged to the epilogue.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& bt_nzs,
                                     torch::Tensor const& bt_meta,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);

  // Candidate kernels, listed from the smallest m-bucket to the largest.
  using GemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;
  using GemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using GemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using GemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using GemmLargeM =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const n = out.size(1);
  uint32_t const m = a.size(0);
  // Power-of-two bucket for m, clamped from below at 32.
  uint32_t const m_bucket =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));

  // Test the largest buckets first; each guard implies the ones above failed.
  if (m_bucket > 128) {
    // m in (128, inf)
    return cutlass_sparse_gemm_caller<GemmLargeM>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }
  if (m_bucket > 64) {
    // m in (64, 128]
    return cutlass_sparse_gemm_caller<GemmM128>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }
  if (m_bucket > 32) {
    // m in (32, 64]
    return cutlass_sparse_gemm_caller<GemmM64>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }

  // m in [1, 32]: split on n at 8192.
  if (n >= 8192) {
    return cutlass_sparse_gemm_caller<GemmM32NBig>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }
  return cutlass_sparse_gemm_caller<GemmM32NSmall>(
      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}

// Routes a sparse scaled GEMM to the dtype-specific SM90 dispatcher.
//
// a.dtype() selects the path: int8, float8_e4m3fn, float16, and otherwise
// bfloat16 (checked). `bt_nzs` must share a's dtype and `bt_meta` must be
// uint8. The output must be bfloat16 or float16. All checks are TORCH_CHECKs.
// `epilogue_args` are forwarded unchanged to the chosen dispatcher.
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& bt_nzs,
                                            torch::Tensor const& bt_meta,
                                            EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);

  // Any non-bf16 output is verified to be fp16 inside each branch below.
  bool const out_is_bf16 = out.dtype() == torch::kBFloat16;

  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
    if (!out_is_bf16) {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
    return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                           Epilogue>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
  }

  if (a.dtype() == torch::kFloat8_e4m3fn) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
    if (!out_is_bf16) {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
    return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                          cutlass::bfloat16_t, Epilogue>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
  }

  if (a.dtype() == torch::kFloat16) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
    if (!out_is_bf16) {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, cutlass::half_t,
                                             Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
    return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t,
                                           cutlass::bfloat16_t, Epilogue>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
  }

  // Remaining case: the input must be bfloat16.
  TORCH_CHECK(a.dtype() == torch::kBFloat16);
  TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
  if (!out_is_bf16) {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
                                           cutlass::half_t, Epilogue>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
  }
  return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
                                         cutlass::bfloat16_t, Epilogue>(
      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
}

// SM90 entry point for the sparse scaled GEMM.
//
// `bt_nzs` / `bt_meta` are presumably the compressed non-zero values and the
// sparsity metadata of the transposed B operand (TODO confirm against the
// compression op). `a_scales` and `b_scales` must be float32.
//
// If `bias` is present, its dtype must equal out.dtype(), and the
// c3x::ScaledEpilogueBias epilogue is used. Otherwise c3x::ScaledEpilogue is
// used. Throws (via TORCH_CHECK) on any dtype mismatch.
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& bt_nzs,
                                   torch::Tensor const& bt_meta,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // NOTE(review): b_scales is passed before a_scales in both calls below.
  // This presumably reflects the sparse operand being the first GEMM operand
  // inside the kernel (i.e. the product is computed transposed). Confirm
  // against the c3x epilogue argument order before changing.
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
        out, a, bt_nzs, bt_meta, b_scales, a_scales, *bias);
  } else {
    return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
        out, a, bt_nzs, bt_meta, b_scales, a_scales);
  }
}

#endif