// scaled_mm_entry.cu
#include <cudaTypedefs.h>

3
#include <c10/cuda/CUDAGuard.h>
4
#include <torch/all.h>
5

6
7
8
// Forward declaration of the SM75 scaled-mm entry point; its definition is
// outside the visible part of this file. cutlass_scaled_mm() dispatches here
// (the "Turing" branch) when the computed capability is below 80, after
// checking that it is at least 75. The parameter list mirrors
// cutlass_scaled_mm().
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);
11

12
13
14
// Forward declaration of the SM80 scaled-mm entry point; its definition is
// outside the visible part of this file. cutlass_scaled_mm() dispatches here
// (the "Ampere" branch) for capabilities >= 80 that are neither 89 nor >= 90,
// and also uses it as the fallback for capability >= 90 when the build's
// CUDA_VERSION is below 12000. The parameter list mirrors cutlass_scaled_mm().
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);
17

18
19
20
// Forward declaration of the SM89 scaled-mm entry point; its definition is
// outside the visible part of this file. cutlass_scaled_mm() dispatches here
// (the "Ada Lovelace" branch) when the computed capability is exactly 89.
// The parameter list mirrors cutlass_scaled_mm().
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);
23

24
// Forward declaration of the SM90 scaled-mm entry point; its definition is
// outside the visible part of this file. It is declared only when the build's
// CUDA_VERSION is at least 12000; cutlass_scaled_mm() guards the matching call
// (the "Hopper" branch, capability >= 90) with the same preprocessor
// condition. The parameter list mirrors cutlass_scaled_mm().
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);
#endif
31

32
33
34
35
36
37
38
39
40
// Reports whether the CUTLASS FP8 scaled-mm path should be used for a device
// with the given compute capability.
//
// cuda_device_capability: integer capability that is compared against 90 and
//   89 below (presumably major * 10 + minor, the same encoding
//   cutlass_scaled_mm() builds for its dispatch -- confirm against callers).
//
// Returns true only when CUDA_VERSION is defined and >= 12000 and the
// capability is >= 90. Every other case, including builds where CUDA_VERSION
// is not defined, returns false.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
  // CUTLASS FP8 kernels need at least
  //   CUDA 12.0 on SM90 systems (Hopper)
  //   CUDA 12.4 on SM89 systems (Lovelace)

#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  }

  if (cuda_device_capability >= 89) {
    // CUTLASS kernels have not been tuned for Ada Lovelace systems and are
    // slower than torch.mm, so FP8 support is reported as unavailable
    // unconditionally in this case.
    //
    // Once the CUTLASS kernels have been optimized for Lovelace systems,
    // use the following check instead:
    // return CUDA_VERSION >= 12040;
    return false;
  }
#endif

  return false;
}

54
55
// Validates the operands of a scaled matrix multiply and forwards them to the
// CUTLASS entry point that matches the compute capability of the device
// holding `a`.
//
// Preconditions enforced below with TORCH_CHECK:
//   c        - 2-D, stride(1) == 1, sizes [a.size(0), b.size(1)],
//              stride(0) a multiple of 16
//   a        - 2-D, stride(1) == 1
//   b        - 2-D, stride(0) == 1, size(0) == a.size(1),
//              stride(1) a multiple of 16
//   a_scales - contiguous, with 1 or a.size(0) elements
//   b_scales - contiguous, with 1 or b.size(1) elements
//   bias     - optional; when present, contiguous 1-D with b.size(1) elements
//
// A TORCH_CHECK failure is also raised when the device capability cannot be
// queried or is below 75.
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  // NOTE(review): stride() is presumably counted in elements rather than
  // bytes, so this is a 16-byte guarantee only for 1-byte element types --
  // confirm that this is what is intended for `c`.
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  // Query the capability of the device that holds `a`, not a fixed device
  // ordinal, so the dispatch stays correct on hosts with mixed GPU
  // architectures. The return codes are checked so a failed query cannot
  // leave the capability values unset.
  int const device_index = static_cast<int>(a.get_device());
  int32_t major_capability = 0;
  int32_t minor_capability = 0;
  cudaError_t const major_status = cudaDeviceGetAttribute(
      &major_capability, cudaDevAttrComputeCapabilityMajor, device_index);
  TORCH_CHECK(major_status == cudaSuccess,
              "cudaDeviceGetAttribute(ComputeCapabilityMajor) failed: ",
              cudaGetErrorString(major_status));
  cudaError_t const minor_status = cudaDeviceGetAttribute(
      &minor_capability, cudaDevAttrComputeCapabilityMinor, device_index);
  TORCH_CHECK(minor_status == cudaSuccess,
              "cudaDeviceGetAttribute(ComputeCapabilityMinor) failed: ",
              cudaGetErrorString(minor_status));
  int32_t const version_num = major_capability * 10 + minor_capability;

  if (version_num >= 90) {
    // Hopper

    // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
#else
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
#endif
  } else if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
  } else if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
  } else {
    // Turing
    TORCH_CHECK(version_num >= 75,
                "cutlass_scaled_mm requires compute capability >= 75, got ",
                version_num);
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
  }
}