scaled_mm_c2x.cu 8.47 KB
Newer Older
1
#include <stddef.h>
2
#include <torch/all.h>
3
4
#include "cutlass/cutlass.h"

5
#include "scaled_mm_c2x.cuh"
6
#include "scaled_mm_c2x_sm75_dispatch.cuh"
7
#include "scaled_mm_c2x_sm80_dispatch.cuh"
8
9
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
10
11

/*
   This file defines quantized GEMM operations using the CUTLASS 2.x API, for
   NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/

16
17
18
19
20
// Picks the sm75 CUTLASS GEMM instantiation that matches the dtype of `out`
// and forwards `epilogue_args` untouched to the dispatcher.
//
// Preconditions (enforced with TORCH_CHECK):
//   - `a` and `b` are int8.
//   - `out` is bfloat16 or float16.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  bool const out_is_bf16 = out.dtype() == torch::kBFloat16;
  if (out_is_bf16) {
    vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  // The only other accepted output type is float16.
  TORCH_CHECK(out.dtype() == torch::kFloat16);
  vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}

35
// Scaled GEMM entry point for sm75.
//
// Verifies that both scale tensors are float32, then runs the sm75 epilogue
// dispatcher with ScaledEpilogueBias when `bias` is present (its dtype must
// equal the dtype of `out`) or with ScaledEpilogue otherwise.
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias: plain scaled epilogue.
  if (!bias) {
    cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogue>(out, a, b, a_scales,
                                                          b_scales);
    return;
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}

53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Scaled GEMM entry point for sm75 that additionally takes `azp_adj` and an
// optional `azp` tensor.
//
// Verifies that both scale tensors are float32. When `azp` is present the
// ScaledEpilogueBiasAzpToken epilogue receives (a_scales, b_scales, azp_adj,
// *azp, bias); otherwise ScaledEpilogueBiasAzp receives (a_scales, b_scales,
// azp_adj, bias). `bias` is forwarded as an optional and is not validated
// here.
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp) {
    cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}

72
73
74
75
76
77
78
// Picks the sm80 CUTLASS GEMM instantiation that matches the dtype of `out`
// and forwards `epilogue_args` untouched to the dispatcher.
//
// Preconditions (enforced with TORCH_CHECK):
//   - `a` and `b` are int8.
//   - `out` is bfloat16 or float16.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  bool const out_is_bf16 = out.dtype() == torch::kBFloat16;
  if (out_is_bf16) {
    vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  // The only other accepted output type is float16.
  TORCH_CHECK(out.dtype() == torch::kFloat16);
  vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}

91
// Scaled GEMM entry point for sm80.
//
// Verifies that both scale tensors are float32, then runs the sm80 epilogue
// dispatcher with ScaledEpilogueBias when `bias` is present (its dtype must
// equal the dtype of `out`) or with ScaledEpilogue otherwise.
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias: plain scaled epilogue.
  if (!bias) {
    cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogue>(out, a, b, a_scales,
                                                          b_scales);
    return;
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}

109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Scaled GEMM entry point for sm80 that additionally takes `azp_adj` and an
// optional `azp` tensor.
//
// Verifies that both scale tensors are float32. When `azp` is present the
// ScaledEpilogueBiasAzpToken epilogue receives (a_scales, b_scales, azp_adj,
// *azp, bias); otherwise ScaledEpilogueBiasAzp receives (a_scales, b_scales,
// azp_adj, bias). `bias` is forwarded as an optional and is not validated
// here.
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp) {
    cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}

128
129
130
131
132
// Picks the sm89 CUTLASS GEMM instantiation that matches the dtypes of `a`,
// `b` and `out`, and forwards `epilogue_args` untouched to the dispatcher.
//
// Two input families are handled:
//   - `a` is int8: `b` must also be int8; uses the sm89 int8 dispatcher.
//   - otherwise: `a` and `b` must both be float8_e4m3fn; uses the sm89 fp8
//     dispatcher.
// In both families `out` must be bfloat16 or float16. Every unsupported
// combination fails a TORCH_CHECK.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                                   Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      // Use TORCH_CHECK rather than assert: assert is compiled out when
      // NDEBUG is defined, which would let an unsupported output dtype reach
      // the half_t kernel unchecked. This also matches the other branches.
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
                                                   Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return vllm::cutlass_gemm_sm89_fp8_dispatch<
          cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                                  cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}
162
163
164
165
166
167
168
169
170
171
172

// Scaled GEMM entry point for sm89.
//
// Verifies that both scale tensors are float32, then runs the sm89 epilogue
// dispatcher with ScaledEpilogueBias when `bias` is present (its dtype must
// equal the dtype of `out`) or with ScaledEpilogue otherwise.
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias: plain scaled epilogue.
  if (!bias) {
    cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>(out, a, b, a_scales,
                                                          b_scales);
    return;
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198

// Scaled GEMM entry point for sm89 that additionally takes `azp_adj` and an
// optional `azp` tensor.
//
// Verifies that both scale tensors are float32. When `azp` is present the
// ScaledEpilogueBiasAzpToken epilogue receives (a_scales, b_scales, azp_adj,
// *azp, bias); otherwise ScaledEpilogueBiasAzp receives (a_scales, b_scales,
// azp_adj, bias). `bias` is forwarded as an optional and is not validated
// here.
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp) {
    cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}