scaled_mm_entry.cu 8.97 KB
Newer Older
1
2
#include <cudaTypedefs.h>

3
#include <c10/cuda/CUDAGuard.h>
4
#include <torch/all.h>
5

6
7
#include "cutlass_extensions/common.hpp"

8
9
10
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
11
                            torch::Tensor const& b_scales,
12
                            std::optional<torch::Tensor> const& bias);
13

14
15
16
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
17
                            torch::Tensor const& b_scales,
18
                            std::optional<torch::Tensor> const& bias);
19

20
21
22
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
23
                            torch::Tensor const& b_scales,
24
                            std::optional<torch::Tensor> const& bias);
25

26
// Forward declarations of the SM90 / SM100 scaled-mm entry points. Each one
// is only declared when the matching ENABLE_SCALED_MM_* macro is defined to a
// non-zero value, and cutlass_scaled_mm() guards its calls with the same
// conditions.

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
// Used for SM versions in [90, 100) ("Hopper" in the dispatcher).
void cutlass_scaled_mm_sm90(
    torch::Tensor& c,
    torch::Tensor const& a,
    torch::Tensor const& b,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    std::optional<torch::Tensor> const& bias);
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
// Used for SM versions >= 100.
void cutlass_scaled_mm_sm100(
    torch::Tensor& c,
    torch::Tensor const& a,
    torch::Tensor const& b,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    std::optional<torch::Tensor> const& bias);
#endif
41

42
43
44
45
46
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
47
48
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);
49
50
51
52
53
54

void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
55
56
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);
57
58
59
60
61
62

void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
63
64
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);
65

66
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
67
68
69
70
71
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
72
73
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);
74
75
#endif

76
77
78
79
80
81
82
83
84
// Reports whether the CUTLASS FP8 kernels can be used on a device with the
// given compute capability (encoded as major * 10 + minor, e.g. 89 or 90),
// given the CUDA toolkit this file was compiled against.
//
// Minimum toolkit per capability:
//   capability >= 90 (Hopper)          : CUDA 12.0
//   capability == 89 (Lovelace)        : CUDA 12.4
//   anything lower                     : never supported
//
// Always returns false when CUDA_VERSION is not defined at compile time.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  // Nothing below SM89 has an FP8 path.
  if (cuda_device_capability < 89) {
    return false;
  }

  // SM89 needs a newer toolkit than SM90 and above.
  int64_t const min_cuda_version =
      (cuda_device_capability >= 90) ? 12000 : 12040;
  return CUDA_VERSION >= min_cuda_version;
#else
  return false;
#endif
}

92
93
94
95
96
// Reports whether the CUTLASS block-quantized FP8 kernels can be used on a
// device with the given compute capability (encoded as major * 10 + minor).
//
// Support is limited to capabilities in [90, 100) and additionally requires
// that this file was compiled against CUDA 12.0 or newer. Always returns
// false when CUDA_VERSION is not defined at compile time.
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  bool const capability_in_range =
      cuda_device_capability >= 90 && cuda_device_capability < 100;
  return capability_in_range && CUDA_VERSION >= 12000;
#else
  return false;
#endif
}

105
106
107
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
108
                       std::optional<torch::Tensor> const& bias) {
109
110
111
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
112
              b.size(1) == c.size(1));
113
114

  // Check for strides and alignment
115
116
117
118
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
119

120
121
122
123
124
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

125
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
126
  int32_t version_num = get_sm_version_num();
127

128
129
130
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
  if (version_num >= 100) {
    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
131
132
    return;
  }
133
134
135
136
#endif

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
137
  if (version_num >= 90 && version_num < 100) {
138
    // Hopper
139
140
141
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  }
142
#endif
143
144
145

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
146
    // Ada Lovelace
147
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
148
149
150
151
    return;
  }

  if (version_num >= 80) {
152
    // Ampere
153
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
154
    return;
155
  }
156

157
158
159
160
161
  if (version_num >= 75) {
    // Turing
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
    return;
  }
162
163
164
165
166
167
168
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm for a compute capability less than "
      "CUDA device capability: ",
      version_num);
169
}
170
171
172
173
174
175

// Validates the operands of a scaled matrix multiply with an asymmetric
// zero-point correction and forwards them to the cutlass_scaled_mm_azp_smXX
// entry point that matches the SM version of the current device.
//
// Preconditions enforced here (each via TORCH_CHECK):
//   - a, b and c are 2-D with conforming sizes:
//       c is (a.size(0) x b.size(1)) and a.size(1) == b.size(0)
//   - a_scales has 1 or a.size(0) elements; b_scales has 1 or b.size(1)
//     elements; both are contiguous
//   - a and c have unit stride along dim 1; b has unit stride along dim 0
//   - c.stride(0) and b.stride(1) are multiples of 16
//   - azp_adj is int32, contiguous, with b.size(1) elements
//   - azp, when present, is int32, contiguous, with a.size(0) elements
//   - bias, when present, is contiguous, has b.size(1) elements and the same
//     dtype as c
//
// Raises through TORCH_CHECK_NOT_IMPLEMENTED when none of the kernel families
// compiled into this build covers the device.
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           std::optional<torch::Tensor> const& azp,
                           std::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // bias, azp, azp_adj are all 1d
  // bias and azp_adj have n elements, azp has m elements
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // azp & bias types
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  // Make a's device current before querying its SM version.
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  // Hopper only. The range is bounded above, exactly as in
  // cutlass_scaled_mm, so that SM100+ devices are not handed the SM90
  // kernel; they fall through to the branches below instead.
  if (version_num >= 90 && version_num < 100) {
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  // Conditional (rather than an unconditional check followed by return) so
  // that pre-Turing devices reach the descriptive not-implemented error
  // below, matching cutlass_scaled_mm.
  if (version_num >= 75) {
    // Turing
    cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm_azp for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}