scaled_mm_c2x.cu 6.33 KB
Newer Older
1
#include <stddef.h>
2
#include <torch/all.h>
3
4
#include "cutlass/cutlass.h"

5
6
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
7
8
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
9
10

/*
   This file defines quantized GEMM operations using the CUTLASS 2.x API, for
   NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/

15
16
17
18
19
// Runs the int8 x int8 scaled GEMM through the CUTLASS 2.x kernel that is
// instantiated for cutlass::arch::Sm75 and guarded by
// vllm::enable_sm75_to_sm80.
//
// Template parameters:
//   Epilogue      - epilogue template handed through to vllm::cutlass_2x_gemm.
//   EpilogueArgs  - types of the extra arguments forwarded to the epilogue.
//
// Arguments:
//   out            - output tensor; must be bfloat16 or float16.
//   a, b           - input tensors; both must be int8.
//   epilogue_args  - perfectly forwarded to vllm::cutlass_gemm_caller.
//
// Fails a TORCH_CHECK when any of the dtype requirements above is violated.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  // Fixed tiling configuration shared by both output dtypes.
  using TileDims = cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpDims = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstrDims = cutlass::gemm::GemmShape<8, 8, 16>;

  bool const bf16_output = out.dtype() == torch::kBFloat16;

  if (!bf16_output) {
    // The only other supported output type is float16.
    TORCH_CHECK(out.dtype() == torch::kFloat16);

    using GemmF16 =
        vllm::cutlass_2x_gemm<cutlass::arch::Sm75, vllm::enable_sm75_to_sm80,
                              int8_t, cutlass::half_t, Epilogue, TileDims,
                              WarpDims, InstrDims, 2>;
    return vllm::cutlass_gemm_caller<GemmF16>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }

  using GemmBF16 =
      vllm::cutlass_2x_gemm<cutlass::arch::Sm75, vllm::enable_sm75_to_sm80,
                            int8_t, cutlass::bfloat16_t, Epilogue, TileDims,
                            WarpDims, InstrDims, 2>;
  return vllm::cutlass_gemm_caller<GemmBF16>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}

42
// Entry point for the scaled matrix multiply on the sm75 code path.
//
// Verifies that both scale tensors are float32, then picks the epilogue:
// vllm::ScaledEpilogueBias when `bias` holds a tensor (whose dtype must equal
// the dtype of `out`), vllm::ScaledEpilogue otherwise. The scales (and bias,
// if present) are passed on as epilogue arguments.
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias supplied: plain scaled epilogue.
  if (!bias) {
    return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}

// Hands the int8 x int8 scaled GEMM to vllm::cutlass_gemm_sm80_dispatch,
// choosing the output element type from the dtype of `out`.
//
// Template parameters:
//   Epilogue      - epilogue template passed to the sm80 dispatcher.
//   EpilogueArgs  - types of the extra arguments forwarded to the epilogue.
//
// Arguments:
//   out            - output tensor; must be bfloat16 or float16.
//   a, b           - input tensors; both must be int8.
//   epilogue_args  - perfectly forwarded to the dispatcher.
//
// Fails a TORCH_CHECK when any of the dtype requirements above is violated.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  bool const bf16_output = out.dtype() == torch::kBFloat16;

  if (!bf16_output) {
    // The only other supported output type is float16.
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }

  return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
                                          Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}

79
// Entry point for the scaled matrix multiply on the sm80 code path.
//
// Verifies that both scale tensors are float32, then picks the epilogue:
// vllm::ScaledEpilogueBias when `bias` holds a tensor (whose dtype must equal
// the dtype of `out`), vllm::ScaledEpilogue otherwise. The scales (and bias,
// if present) are passed on as epilogue arguments.
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias supplied: plain scaled epilogue.
  if (!bias) {
    return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}

// Routes a scaled GEMM on the sm89 code path to the int8 or fp8 dispatcher,
// based on the dtype of `a`, and selects the output element type from the
// dtype of `out`.
//
// Template parameters:
//   Epilogue      - epilogue template passed to the selected dispatcher.
//   EpilogueArgs  - types of the extra arguments forwarded to the epilogue.
//
// Arguments:
//   out            - output tensor; must be bfloat16 or float16.
//   a, b           - input tensors; either both int8 or both float8_e4m3fn.
//   epilogue_args  - perfectly forwarded to the dispatcher.
//
// Fails a TORCH_CHECK when any of the dtype requirements above is violated.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                                   Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      // Use TORCH_CHECK rather than assert: assert is compiled out when NDEBUG
      // is defined, which would let an unsupported output dtype fall through
      // to the half_t kernel unchecked in release builds.
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
                                                   Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    // Anything that is not int8 must be fp8 (e4m3) on both operands.
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return vllm::cutlass_gemm_sm89_fp8_dispatch<
          cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                                  cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}
131
132
133
134
135
136
137
138
139
140
141

// Entry point for the scaled matrix multiply on the sm89 code path.
//
// Verifies that both scale tensors are float32, then picks the epilogue:
// vllm::ScaledEpilogueBias when `bias` holds a tensor (whose dtype must equal
// the dtype of `out`), vllm::ScaledEpilogue otherwise. The scales (and bias,
// if present) are passed on as epilogue arguments.
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias supplied: plain scaled epilogue.
  if (!bias) {
    return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}