sparse_scaled_mm_c3x.cu 12.1 KB
Newer Older
1
2
3
4
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

5
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
6
7
8
9
10
11
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on

using namespace cute;
using namespace vllm;

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// Dispatch policy used when the selected kernel configuration should run the
// sparse GEMM itself. The call produces no value.
struct GemmCallerTraits {
  using return_type = void;

  // Forwards `call_args` unchanged to cutlass_sparse_gemm_caller<Config>.
  template <typename Config, typename... CallArgs>
  static auto invoke(CallArgs&&... call_args) -> return_type {
    return cutlass_sparse_gemm_caller<Config>(
        std::forward<CallArgs>(call_args)...);
  }
};

// Dispatch policy used when the selected kernel configuration should run the
// sparse compressor. The call yields a CompressorResult.
struct GemmCompressorTraits {
  using return_type = CompressorResult;

  // Forwards `call_args` unchanged to cutlass_sparse_compress<Config>.
  template <typename Config, typename... CallArgs>
  static auto invoke(CallArgs&&... call_args) -> return_type {
    return cutlass_sparse_compress<Config>(
        std::forward<CallArgs>(call_args)...);
  }
};

30
31
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
32
33
34
35
          typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type cutlass_gemm_sm90_fp8_dispatch(
    uint32_t m, uint32_t n, Args&&... args) {
  static_assert(std::is_same_v<InType, cutlass::float_e4m3_t>);
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM256 =
      typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM512 =
      typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;

  using Cutlass3xGemm1 =
      typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm2 =
      typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm3 =
      typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm4 =
      typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm5 =
      typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm6 =
      typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm7 =
      typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm8 =
      typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    if (n == 28672) {
70
71
      return DispatchFunc::template invoke<Cutlass3xGemm2>(
          std::forward<Args>(args)...);
72
    } else if (n == 4096 || n == 6144) {
73
74
      return DispatchFunc::template invoke<Cutlass3xGemm1>(
          std::forward<Args>(args)...);
75
76
77
    }
  } else if (mp2 <= 128) {
    if (n == 4096) {
78
79
      return DispatchFunc::template invoke<Cutlass3xGemm3>(
          std::forward<Args>(args)...);
80
    } else if (n == 28672) {
81
82
      return DispatchFunc::template invoke<Cutlass3xGemm5>(
          std::forward<Args>(args)...);
83
    } else if (n == 6144) {
84
85
      return DispatchFunc::template invoke<Cutlass3xGemm4>(
          std::forward<Args>(args)...);
86
87
88
    }
  } else if (mp2 <= 256) {
    if (n == 4096) {
89
90
      return DispatchFunc::template invoke<Cutlass3xGemm6>(
          std::forward<Args>(args)...);
91
    } else if (n == 28672) {
92
93
      return DispatchFunc::template invoke<Cutlass3xGemm8>(
          std::forward<Args>(args)...);
94
    } else if (n == 6144) {
95
96
      return DispatchFunc::template invoke<Cutlass3xGemm7>(
          std::forward<Args>(args)...);
97
98
99
    }
  } else {
    if (n == 6144 || n == 28672) {
100
101
      return DispatchFunc::template invoke<Cutlass3xGemm8>(
          std::forward<Args>(args)...);
102
    } else if (n == 4096) {
103
104
      return DispatchFunc::template invoke<Cutlass3xGemm7>(
          std::forward<Args>(args)...);
105
106
107
108
109
110
    }
  }

  // Otherwise the default heuristic
  if (mp2 <= 64) {
    // n in [1, 64]
111
112
    return DispatchFunc::template invoke<Cutlass3xGemmM64>(
        std::forward<Args>(args)...);
113
114
  } else if (mp2 <= 128) {
    // n in (64, 128]
115
116
    return DispatchFunc::template invoke<Cutlass3xGemmM128>(
        std::forward<Args>(args)...);
117
118
  } else if (mp2 <= 256) {
    // n in (128, 256]
119
120
    return DispatchFunc::template invoke<Cutlass3xGemmM256>(
        std::forward<Args>(args)...);
121
122
  } else {
    // n in (256, inf)
123
124
    return DispatchFunc::template invoke<Cutlass3xGemmM512>(
        std::forward<Args>(args)...);
125
126
127
128
129
  }
}

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
130
131
132
          typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type cutlass_gemm_sm90_16bit_dispatch(
    uint32_t m, uint32_t n, Args&&... args) {
133
134
135
  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;

136
137
  return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
      std::forward<Args>(args)...);
138
139
140
141
}

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
142
143
144
145
          typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type cutlass_gemm_sm90_int8_dispatch(
    uint32_t m, uint32_t n, Args&&... args) {
  static_assert(std::is_same_v<InType, int8_t>);
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  bool const is_small_n = n < 8192;
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
167
168
      return DispatchFunc::template invoke<Cutlass3xGemmM32NSmall>(
          std::forward<Args>(args)...);
169
    } else {
170
171
      return DispatchFunc::template invoke<Cutlass3xGemmM32NBig>(
          std::forward<Args>(args)...);
172
173
174
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
175
176
    return DispatchFunc::template invoke<Cutlass3xGemmM64>(
        std::forward<Args>(args)...);
177
178
  } else if (mp2 <= 128) {
    // m in (64, 128]
179
180
    return DispatchFunc::template invoke<Cutlass3xGemmM128>(
        std::forward<Args>(args)...);
181
182
  } else {
    // m in (128, inf)
183
184
    return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
        std::forward<Args>(args)...);
185
186
187
  }
}

188
// Dispatch to GEMM implementations based on element types
189
190
191
192
193
194
195
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& bt_nzs,
                                            torch::Tensor const& bt_meta,
                                            EpilogueArgs&&... epilogue_args) {
196
197
198
199
  uint32_t const m = out.size(0);
  uint32_t const n = out.size(1);

  // TODO: add dispatch functions to all of these
200
201
202
203
204
205
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
206
207
                                             Epilogue, GemmCallerTraits>(
          m, n, out, a, bt_nzs, bt_meta,
208
209
210
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
211
212
213
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue,
                                             GemmCallerTraits>(
          m, n, out, a, bt_nzs, bt_meta,
214
215
216
217
218
219
220
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else if (a.dtype() == torch::kFloat8_e4m3fn) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
221
222
223
                                            cutlass::bfloat16_t, Epilogue,
                                            GemmCallerTraits>(
          m, n, out, a, bt_nzs, bt_meta,
224
225
226
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
227
228
229
      return cutlass_gemm_sm90_fp8_dispatch<
          cutlass::float_e4m3_t, cutlass::half_t, Epilogue, GemmCallerTraits>(
          m, n, out, a, bt_nzs, bt_meta,
230
231
232
233
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else if (a.dtype() == torch::kFloat16) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
234
    TORCH_CHECK(out.dtype() == torch::kFloat16);
235

236
237
238
239
    return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
                                            Epilogue, GemmCallerTraits>(
        m, n, out, a, bt_nzs, bt_meta,
        std::forward<EpilogueArgs>(epilogue_args)...);
240
241
242
  } else {  // a.dtype() == torch::kBFloat16
    TORCH_CHECK(a.dtype() == torch::kBFloat16);
    TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
243
    TORCH_CHECK(out.dtype() == torch::kBFloat16);
244

245
246
247
248
    return cutlass_gemm_sm90_16bit_dispatch<
        cutlass::bfloat16_t, cutlass::bfloat16_t, Epilogue, GemmCallerTraits>(
        m, n, out, a, bt_nzs, bt_meta,
        std::forward<EpilogueArgs>(epilogue_args)...);
249
250
251
252
253
254
255
256
  }
}

void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& bt_nzs,
                                   torch::Tensor const& bt_meta,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
257
                                   std::optional<torch::Tensor> const& bias) {
258
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
259
260
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
261

262
263
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
264
265
266
267
268
                "CUTLASS scaled_mm bias dtype must match output dtype ",
                out.dtype());
    return cutlass_scaled_sparse_mm_sm90_epilogue<
        c3x::ScaledEpilogueColumnBias>(out, a, bt_nzs, bt_meta, b_scales,
                                       a_scales, *bias);
269
270
271
272
273
274
  } else {
    return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
        out, a, bt_nzs, bt_meta, b_scales, a_scales);
  }
}

275
276
277
278
279
// Runs the sparse compressor on `a` and returns its CompressorResult.
//
// The dtype of `a` selects the element type and the matching dispatcher
// (int8, fp8 e4m3, fp16, or bf16); the dispatcher then calls
// cutlass_sparse_compress through GemmCompressorTraits. For int8 and fp8 the
// output element type is fixed to bf16 and the epilogue is always
// c3x::TrivialEpilogue.
//
// Fails a TORCH_CHECK if `a` has any other dtype.
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a) {
  // These m and n variables are for dispatching to different GEMM algorithms.
  uint32_t const m = 1;  // Set M to 1 for compression
  uint32_t const n = a.size(1);

  // Note: For correctness, the compressed format must be invariant in:
  //  - M, the flattened number of tokens
  //  - Whether output dtype is fp16 or bf16
  //  - CUTLASS epilogues

  if (a.dtype() == torch::kInt8) {
    return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                           c3x::TrivialEpilogue,
                                           GemmCompressorTraits>(m, n, a);
  } else if (a.dtype() == torch::kFloat8_e4m3fn) {
    return cutlass_gemm_sm90_fp8_dispatch<
        cutlass::float_e4m3_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
        GemmCompressorTraits>(m, n, a);
  } else if (a.dtype() == torch::kFloat16) {
    // fp16 tensors use the half_t element type, matching the GEMM path.
    return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
                                            c3x::TrivialEpilogue,
                                            GemmCompressorTraits>(m, n, a);
  } else {
    TORCH_CHECK(a.dtype() == torch::kBFloat16,
                "cutlass_sparse_compress only supports int8, fp8_e4m3, fp16, "
                "and bf16 datatypes");
    // bf16 tensors use the bfloat16_t element type, matching the GEMM path.
    return cutlass_gemm_sm90_16bit_dispatch<
        cutlass::bfloat16_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
        GemmCompressorTraits>(m, n, a);
  }
}

307
#endif