// scaled_mm_entry.cu
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass_extensions/common.hpp"

void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
11
                            torch::Tensor const& b_scales,
12
                            std::optional<torch::Tensor> const& bias);
13

14
15
16
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
17
                            torch::Tensor const& b_scales,
18
                            std::optional<torch::Tensor> const& bias);
19

20
21
22
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
23
                            torch::Tensor const& b_scales,
24
                            std::optional<torch::Tensor> const& bias);
25

26
// Declarations below are only visible when ENABLE_SCALED_MM_C3X is defined to
// a non-zero value; the matching call sites in cutlass_scaled_mm() sit behind
// the same guard. Neither definition is visible in this file.
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
// Hopper implementation: cutlass_scaled_mm() dispatches here for SM version
// numbers in [90, 100).
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias);

// cutlass_scaled_mm() dispatches here for SM version numbers >= 100, unless
// the toolkit is known to be older than CUDA 12.8 (CUDA_VERSION < 12080).
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias);
#endif

// Turing implementation of the scaled matmul with azp terms (azp_adj and an
// optional azp tensor). Forward declaration only; the definition is not
// visible in this file.
// cutlass_scaled_mm_azp() forwards its arguments unchanged to this function
// on the Turing path when ENABLE_SCALED_MM_C2X is enabled.
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);

// Ampere implementation of the scaled matmul with azp terms (azp_adj and an
// optional azp tensor). Forward declaration only; the definition is not
// visible in this file.
// cutlass_scaled_mm_azp() forwards its arguments unchanged to this function
// for SM version numbers >= 80 (other than 89) that were not already handled
// by the sm90 branch, when ENABLE_SCALED_MM_C2X is enabled.
void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);

// Ada Lovelace implementation of the scaled matmul with azp terms (azp_adj
// and an optional azp tensor). Forward declaration only; the definition is
// not visible in this file.
// cutlass_scaled_mm_azp() forwards its arguments unchanged to this function
// for SM version number 89 when ENABLE_SCALED_MM_C2X is enabled.
void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);

// sm90 implementation of the scaled matmul with azp terms. Forward
// declaration only; the definition is not visible in this file.
// The guard matches the one around the only call site in
// cutlass_scaled_mm_azp() (ENABLE_SCALED_MM_C3X), so the declaration is
// visible exactly when the call is compiled.
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);
#endif

// Reports whether the CUTLASS FP8 kernels can be used for a device with the
// given capability number (e.g. 89, 90), based on the CUDA toolkit version
// this file was compiled against.
//
// CUTLASS FP8 kernels need at least
//   CUDA 12.0 on SM90 systems (Hopper)
//   CUDA 12.4 on SM89 systems (Lovelace)
//
// Returns false for capabilities below 89, and always returns false when
// CUDA_VERSION is not defined at compile time.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  }
  if (cuda_device_capability >= 89) {
    return CUDA_VERSION >= 12040;
  }
#endif

  return false;
}

// Reports whether the CUTLASS block-quantized FP8 kernels can be used for a
// device with the given capability number.
//
// They need at least CUDA 12.0 and are only reported as supported for SM90
// (Hopper): the capability must be in [90, 100). Capabilities of 100 and
// above return false.
//
// Always returns false when CUDA_VERSION is not defined at compile time.
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
    return CUDA_VERSION >= 12000;
  }
#endif

  return false;
}

void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
105
                       std::optional<torch::Tensor> const& bias) {
106
107
108
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
109
              b.size(1) == c.size(1));
110
111

  // Check for strides and alignment
112
113
114
115
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
116

117
118
119
120
121
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

122
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
123
  int32_t version_num = get_sm_version_num();
124
  // Hopper
125

126
127
  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
128
129
130

  #if defined CUDA_VERSION && CUDA_VERSION < 12080
  if (version_num >= 90 && version_num < 100) {
131
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
132
133
    return;
  }
134
135
136
137
138
139
140
141
142
143
  #else
  if (version_num >= 90 && version_num < 100) {
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  } else if (version_num >= 100) {
    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
    return;
  }
  #endif

144
#endif
145
146
147

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
148
    // Ada Lovelace
149
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
150
151
152
153
    return;
  }

  if (version_num >= 80) {
154
    // Ampere
155
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
156
    return;
157
  }
158

159
160
161
162
163
  if (version_num >= 75) {
    // Turing
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
    return;
  }
164
165
166
167
168
169
170
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm for a compute capability less than "
      "CUDA device capability: ",
      version_num);
171
}
172
173
174
175
176
177

// Validates the operands of a scaled matrix multiply with azp terms and
// dispatches to the implementation compiled for the current device's SM
// version.
//
// Requirements enforced here (TORCH_CHECK fails otherwise):
//   - a, b, c are 2-D with a: (m, k), b: (k, n), c: (m, n)
//   - a_scales has 1 or m elements, b_scales has 1 or n elements; both are
//     contiguous
//   - a and c have unit stride in dim 1 (row-major); b has unit stride in
//     dim 0 (column-major)
//   - c.stride(0) and b.stride(1) are multiples of 16
//   - azp_adj is contiguous int32 with n elements
//   - azp, if present, is contiguous int32 with m elements
//   - bias, if present, is contiguous with n elements and c's dtype
//
// Dispatch order (first match wins):
//   - ENABLE_SCALED_MM_C3X: >= 90 -> sm90
//   - ENABLE_SCALED_MM_C2X: == 89 -> sm89; >= 80 -> sm80; >= 75 -> sm75
// Raises via TORCH_CHECK_NOT_IMPLEMENTED when no compiled branch matches,
// the same way cutlass_scaled_mm() does.
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           std::optional<torch::Tensor> const& azp,
                           std::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // bias, azp, azp_adj are all 1d
  // bias and azp_adj have n elements, azp has m elements
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // azp & bias types
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  // Guard is constructed from a's device and stays alive for the rest of the
  // call.
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
  if (version_num >= 90) {
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  // Conditional rather than an unconditional call + return, so devices below
  // 75 reach the descriptive not-implemented error at the end.
  if (version_num >= 75) {
    // Turing
    cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm_azp for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}