scaled_mm_entry.cu 18.2 KB
Newer Older
1
2
#include <cudaTypedefs.h>

3
#include <c10/cuda/CUDAGuard.h>
4
#include <torch/all.h>
5

6
7
#include "cutlass_extensions/common.hpp"

8
9
10
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
11
                            torch::Tensor const& b_scales,
12
                            std::optional<torch::Tensor> const& bias);
13

14
15
16
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
17
                            torch::Tensor const& b_scales,
18
                            std::optional<torch::Tensor> const& bias);
19

20
21
22
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
23
                            torch::Tensor const& b_scales,
24
                            std::optional<torch::Tensor> const& bias);
25

26
// Hopper scaled-mm entry point. Only declared when the build defines
// ENABLE_SCALED_MM_SM90 to a non-zero value; the definition is not in this
// file.
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias);
#endif
// SM90 grouped (MoE) matmul entry point. Only declared when the build defines
// ENABLE_CUTLASS_MOE_SM90 to a non-zero value; the definition is not in this
// file. cutlass_moe_mm() below forwards its arguments here unchanged.
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void cutlass_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch);
#endif

44
45
46
47
48
49
50
51
52
53
// SM100 grouped (MoE) matmul entry point. Only declared when the build
// defines ENABLE_CUTLASS_MOE_SM100 to a non-zero value; the definition is not
// in this file. cutlass_moe_mm() below forwards its arguments here unchanged.
#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100(torch::Tensor& out_tensors,
                          torch::Tensor const& a_tensors,
                          torch::Tensor const& b_tensors,
                          torch::Tensor const& a_scales,
                          torch::Tensor const& b_scales,
                          torch::Tensor const& expert_offsets,
                          torch::Tensor const& problem_sizes,
                          torch::Tensor const& a_strides,
                          torch::Tensor const& b_strides,
                          torch::Tensor const& c_strides,
                          bool per_act_token,
                          bool per_out_ch);
#endif

54
55
56
57
58
59
60
61
// SM120 scaled-mm entry point. Only declared when the build defines
// ENABLE_SCALED_MM_SM120 to a non-zero value; the definition is not in this
// file. cutlass_scaled_mm() below routes SM versions >= 120 here.
#if defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(
    torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b,
    torch::Tensor const& a_scales, torch::Tensor const& b_scales,
    std::optional<torch::Tensor> const& bias);
#endif

62
// SM100 scaled-mm entry point. Only declared when the build defines
// ENABLE_SCALED_MM_SM100 to a non-zero value; the definition is not in this
// file. cutlass_scaled_mm() below routes SM versions in [100, 120) here.
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias);
#endif
69

70
71
72
// Helper routines that prepare metadata for the grouped (MoE) matmul kernels.
// They are declared only when at least one CUTLASS MoE kernel family
// (SM90, SM100 or SM120) is enabled in the build; the definitions are not in
// this file. Each one is invoked by the same-named wrapper (without the
// `_caller` suffix) further down, which forwards its arguments unchanged.
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
    (defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller(
    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
    const int64_t num_experts, const int64_t n, const int64_t k,
    const std::optional<torch::Tensor>& blockscale_offsets);

void get_cutlass_moe_mm_problem_sizes_caller(
    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
    std::optional<bool> force_swap_ab = std::nullopt);

void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
    const torch::Tensor& expert_first_token_offset,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    const int64_t n, const int64_t k, const bool swap_ab);

void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
                                         torch::Tensor& problem_sizes1,
                                         torch::Tensor& problem_sizes2,
                                         const torch::Tensor& expert_num_tokens,
                                         const int64_t num_local_experts,
                                         const int64_t padded_m,
                                         const int64_t n, const int64_t k);
#endif

100
101
102
103
104
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
105
106
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);
107
108
109
110
111
112

void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
113
114
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);
115
116
117
118
119
120

void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
121
122
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);
123

124
// Hopper scaled-mm entry point with zero-point terms. Only declared when the
// build defines ENABLE_SCALED_MM_SM90 to a non-zero value; the definition is
// not in this file.
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);
#endif

134
135
136
137
138
139
140
141
142
// Reports whether the CUDA toolkit this file was compiled against is recent
// enough for the CUTLASS FP8 kernels on a device of the given capability.
//
// cuda_device_capability: SM version as an integer (e.g. 89 for SM89,
//                         90 for SM90).
//
// Returns true when
//   - capability >= 90 (Hopper and newer) and CUDA_VERSION >= 12.0, or
//   - capability == 89 (Lovelace) and CUDA_VERSION >= 12.4.
// Returns false for any lower capability, and always when CUDA_VERSION is
// not defined at compile time.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  }
  if (cuda_device_capability >= 89) {
    return CUDA_VERSION >= 12040;
  }
#endif

  return false;
}

150
151
152
153
154
// Reports whether the CUDA toolkit this file was compiled against is recent
// enough for the CUTLASS block-quantized FP8 kernels on a device of the given
// capability.
//
// cuda_device_capability: SM version as an integer (e.g. 90 for SM90,
//                         100 for SM100).
//
// Returns true when
//   - capability >= 100 and CUDA_VERSION >= 12.8, or
//   - capability in [90, 100) (Hopper) and CUDA_VERSION >= 12.0.
// Returns false for any capability below 90, and always when CUDA_VERSION is
// not defined at compile time.
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability >= 100) {
    return CUDA_VERSION >= 12080;
  }
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  }
#endif

  return false;
}

165
// Reports whether the CUDA toolkit this file was compiled against is recent
// enough for the CUTLASS grouped FP8 GEMM kernels on a device of the given
// capability.
//
// cuda_device_capability: SM version as an integer (e.g. 90 for SM90,
//                         100 for SM100).
//
// Returns true when
//   - capability >= 100 (Blackwell) and CUDA_VERSION >= 12.8, or
//   - capability in [90, 100) (Hopper) and CUDA_VERSION >= 12.3.
// Returns false for any capability below 90, and always when CUDA_VERSION is
// not defined at compile time.
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability >= 100) {
    return CUDA_VERSION >= 12080;
  }
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12030;
  }
#endif

  return false;
}

181
182
183
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
184
                       std::optional<torch::Tensor> const& bias) {
185
186
187
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
188
              b.size(1) == c.size(1));
189
190

  // Check for strides and alignment
191
192
193
194
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
195

196
197
198
199
200
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

201
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
202
  int32_t version_num = get_sm_version_num();
203

204
205
206
207
208
209
210
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
  if (version_num >= 120) {
    cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

211
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
212
  if (version_num >= 100 && version_num < 120) {
213
    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
214
215
    return;
  }
216
217
218
219
#endif

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
220
  if (version_num >= 90 && version_num < 100) {
221
    // Hopper
222
223
224
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  }
225
#endif
226
227
228

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
229
    // Ada Lovelace
230
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
231
232
233
234
    return;
  }

  if (version_num >= 80) {
235
    // Ampere
236
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
237
    return;
238
  }
239

240
241
242
243
244
  if (version_num >= 75) {
    // Turing
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
    return;
  }
245
246
247
248
249
250
251
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm for a compute capability less than "
      "CUDA device capability: ",
      version_num);
252
}
253

254
255
256
257
258
// Public entry point for the grouped (MoE) matmul. Forwards every argument
// unchanged to the SM100 or SM90 kernel family, whichever matches the
// device's SM version and was enabled at build time.
//
// Dispatch:
//   - SM version in [100, 110) -> cutlass_moe_mm_sm100
//   - SM version in [90, 100)  -> cutlass_moe_mm_sm90
//
// No argument validation happens here. Raises via
// TORCH_CHECK_NOT_IMPLEMENTED when neither branch applies.
void cutlass_moe_mm(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
  if (version_num >= 100 && version_num < 110) {
    cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                         expert_offsets, problem_sizes, a_strides, b_strides,
                         c_strides, per_act_token, per_out_ch);
    return;
  }
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
  if (version_num >= 90 && version_num < 100) {
    cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                        expert_offsets, problem_sizes, a_strides, b_strides,
                        c_strides, per_act_token, per_out_ch);
    return;
  }
#endif
  // The message names this function (it previously named cutlass_scaled_mm,
  // which pointed readers at the wrong entry point).
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_moe_mm for CUDA device capability: ", version_num,
      ". Required capability: 90 or 100");
}

// Public wrapper around get_cutlass_moe_mm_data_caller(). When any CUTLASS
// MoE kernel family (SM90, SM100 or SM120) is enabled in the build, all
// arguments are forwarded unchanged and the function returns. Otherwise it
// raises via TORCH_CHECK_NOT_IMPLEMENTED, reporting the device's SM version.
void get_cutlass_moe_mm_data(
    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
    const int64_t num_experts, const int64_t n, const int64_t k,
    const std::optional<torch::Tensor>& blockscale_offsets) {
  // This function currently gets compiled only if we have a valid cutlass moe
  // mm to run it for.
  int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
  get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
                                 problem_sizes2, input_permutation,
                                 output_permutation, num_experts, n, k,
                                 blockscale_offsets);
  return;
#endif
  // Only reached when no MoE kernel family is compiled in.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
      "CUDA device capability: ",
      version_num, ". Required capability: 90, 100, or 120");
}

309
310
311
// Public wrapper around get_cutlass_moe_mm_problem_sizes_caller(). When any
// CUTLASS MoE kernel family (SM90, SM100 or SM120) is enabled in the build,
// all arguments are forwarded unchanged and the function returns. Otherwise
// it raises via TORCH_CHECK_NOT_IMPLEMENTED, reporting the device's SM
// version.
//
// force_swap_ab defaults to std::nullopt and is passed through as-is.
void get_cutlass_moe_mm_problem_sizes(
    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
    std::optional<bool> force_swap_ab = std::nullopt) {
  int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
  get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
                                          problem_sizes2, num_experts, n, k,
                                          blockscale_offsets, force_swap_ab);
  return;
#endif
  // Only reached when no MoE kernel family is compiled in.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
      "kernel for CUDA device capability: ",
      version_num, ". Required capability: 90, 100, or 120");
}

330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
// Public wrapper around
// get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(). With any
// CUTLASS MoE kernel family (SM90, SM100 or SM120) enabled in the build, the
// arguments are handed over unchanged and the function returns; without one
// it raises via TORCH_CHECK_NOT_IMPLEMENTED, reporting the device's SM
// version.
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
    const torch::Tensor& expert_first_token_offset,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    const int64_t n, const int64_t k, const bool swap_ab) {
  int32_t version_num = get_sm_version_num();

#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
    (defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
  get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
      expert_first_token_offset,
      problem_sizes1,
      problem_sizes2,
      n,
      k,
      swap_ab);
  return;
#endif

  // Fallback for builds that contain no MoE kernel family at all.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
      "no cutlass_scaled_mm kernel for CUDA device capability: ",
      version_num, ". Required capability: 90, 100, or 120");
}

349
350
351
352
353
354
355
356
357
358
// Public wrapper around get_cutlass_pplx_moe_mm_data_caller(). When any
// CUTLASS MoE kernel family (SM90, SM100 or SM120) is enabled in the build,
// all arguments are forwarded unchanged and the function returns. Otherwise
// it raises via TORCH_CHECK_NOT_IMPLEMENTED, reporting the device's SM
// version.
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                  torch::Tensor& problem_sizes1,
                                  torch::Tensor& problem_sizes2,
                                  const torch::Tensor& expert_num_tokens,
                                  const int64_t num_local_experts,
                                  const int64_t padded_m, const int64_t n,
                                  const int64_t k) {
  // This function currently gets compiled only if we have a valid cutlass moe
  // mm to run it for.
  int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
  get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
                                      problem_sizes2, expert_num_tokens,
                                      num_local_experts, padded_m, n, k);
  return;
#endif
  // Only reached when no MoE kernel family is compiled in.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
      "for CUDA device capability: ",
      version_num, ". Required capability: 90, 100, or 120");
}

374
375
376
377
378
// Public entry point for the scaled matrix multiply with zero-point terms.
// Validates operand layout, scale shapes and zero-point/bias dtypes, makes
// a's device current, and forwards to the kernel family that matches the
// device's SM version.
//
// c:        2-D output; c.size(0) == a.size(0), c.size(1) == b.size(1);
//           innermost stride must be 1 and c.stride(0) a multiple of 16.
// a:        2-D left operand with innermost stride 1 (row-major).
// b:        2-D right operand with stride(0) == 1 (column-major) and
//           b.stride(1) a multiple of 16; a.size(1) == b.size(0).
// a_scales: contiguous; either 1 element or a.size(0) elements.
// b_scales: contiguous; either 1 element or b.size(1) elements.
// azp_adj:  contiguous int32 with b.size(1) elements.
// azp:      optional; contiguous int32 with a.size(0) elements.
// bias:     optional; contiguous with b.size(1) elements and the same dtype
//           as c.
//
// Raises via TORCH_CHECK when a requirement is violated, and via
// TORCH_CHECK_NOT_IMPLEMENTED when no kernel compiled into this build covers
// the device.
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           std::optional<torch::Tensor> const& azp,
                           std::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // bias, azp, azp_adj are all 1d
  // bias and azp_adj have n elements, azp has m elements
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // azp & bias types
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  // Make a's device the current one for the rest of this call.
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  // NOTE(review): unlike cutlass_scaled_mm, this branch has no upper bound,
  // so devices newer than Hopper are also routed to the sm90 kernel --
  // confirm that is intended.
  if (version_num >= 90) {
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  // Guarded like the Turing branch of cutlass_scaled_mm, so a device below
  // SM75 falls through to the descriptive not-implemented error instead of
  // failing a bare, message-less check.
  if (version_num >= 75) {
    // Turing
    cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif

  // Nothing compiled into this build can serve the device.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm_azp for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}