scaled_mm_c2x.cu 8.38 KB
Newer Older
1
#include <stddef.h>
2
#include <torch/all.h>
3
4
#include "cutlass/cutlass.h"

5
#include "scaled_mm_c2x.cuh"
6
#include "scaled_mm_c2x_sm75_dispatch.cuh"
7
#include "scaled_mm_c2x_sm80_dispatch.cuh"
8
9
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
10

11
12
13
14
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"

using namespace vllm;

15
/*
   This file defines quantized GEMM operations using the CUTLASS 2.x API, for
   NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/

20
21
22
23
24
// Validates operand dtypes and forwards to the sm75 CUTLASS GEMM dispatcher
// instantiated for the requested output element type.
//
// Epilogue       - epilogue template handed through to the dispatcher.
// out            - output tensor; must be bfloat16 or float16.
// a, b           - input operands; both must be int8.
// epilogue_args  - perfectly forwarded to the dispatcher unchanged.
//
// Fails via TORCH_CHECK when any dtype requirement is not met.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  // Only int8 inputs are handled on this path.
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  bool const wants_bf16 = (out.dtype() == torch::kBFloat16);
  if (wants_bf16) {
    cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  // Anything that is not bfloat16 has to be float16.
  TORCH_CHECK(out.dtype() == torch::kFloat16);
  cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}

38
// Scaled GEMM entry point for the sm75 path.
//
// Checks that both scale tensors are float32, then selects the epilogue:
// ScaledEpilogueBias when a bias is supplied (its dtype must equal the output
// dtype), ScaledEpilogue otherwise.
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias: plain scaled epilogue.
  if (!bias.has_value()) {
    cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
    return;
  }

  // Bias present: it has to share the output dtype.
  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}

56
57
58
59
60
// Scaled GEMM entry point with an `azp` adjustment for the sm75 path.
//
// Checks that both scale tensors are float32, then selects the epilogue:
// ScaledEpilogueBiasAzpToken when `azp` is supplied (it is passed along with
// `azp_adj`), ScaledEpilogueBiasAzp otherwise (only `azp_adj` is passed).
// The optional `bias` is forwarded as-is in both cases.
// NOTE(review): `azp` presumably denotes an activation zero point; confirm
// against the epilogue definitions.
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // Without an explicit azp tensor only the precomputed adjustment is used.
  if (!azp.has_value()) {
    cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}

75
76
77
78
79
80
81
// Validates operand dtypes and forwards to the sm80 CUTLASS GEMM dispatcher
// instantiated for the requested output element type.
//
// Epilogue       - epilogue template handed through to the dispatcher.
// out            - output tensor; must be bfloat16 or float16.
// a, b           - input operands; both must be int8.
// epilogue_args  - perfectly forwarded to the dispatcher unchanged.
//
// Fails via TORCH_CHECK when any dtype requirement is not met.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  // Only int8 inputs are handled on this path.
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  bool const wants_bf16 = (out.dtype() == torch::kBFloat16);
  if (wants_bf16) {
    cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  // Anything that is not bfloat16 has to be float16.
  TORCH_CHECK(out.dtype() == torch::kFloat16);
  cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}

93
// Scaled GEMM entry point for the sm80 path.
//
// Checks that both scale tensors are float32, then selects the epilogue:
// ScaledEpilogueBias when a bias is supplied (its dtype must equal the output
// dtype), ScaledEpilogue otherwise.
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias: plain scaled epilogue.
  if (!bias.has_value()) {
    cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
    return;
  }

  // Bias present: it has to share the output dtype.
  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}

111
112
113
114
115
// Scaled GEMM entry point with an `azp` adjustment for the sm80 path.
//
// Checks that both scale tensors are float32, then selects the epilogue:
// ScaledEpilogueBiasAzpToken when `azp` is supplied (it is passed along with
// `azp_adj`), ScaledEpilogueBiasAzp otherwise (only `azp_adj` is passed).
// The optional `bias` is forwarded as-is in both cases.
// NOTE(review): `azp` presumably denotes an activation zero point; confirm
// against the epilogue definitions.
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // Without an explicit azp tensor only the precomputed adjustment is used.
  if (!azp.has_value()) {
    cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}

130
131
132
133
134
// Validates operand dtypes and forwards to the matching sm89 CUTLASS GEMM
// dispatcher. Two input families are handled:
//   * int8 x int8            -> cutlass_gemm_sm89_int8_dispatch
//   * fp8 e4m3 x fp8 e4m3    -> cutlass_gemm_sm89_fp8_dispatch
// In both cases the output element type is chosen from `out`, which must be
// bfloat16 or float16.
//
// Epilogue       - epilogue template handed through to the dispatcher.
// out            - output tensor; must be bfloat16 or float16.
// a, b           - input operands; both int8 or both float8_e4m3fn.
// epilogue_args  - perfectly forwarded to the dispatcher unchanged.
//
// Fails via TORCH_CHECK when any dtype requirement is not met.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      // Use TORCH_CHECK rather than assert: assert is compiled out when
      // NDEBUG is defined, which would let an unsupported output dtype reach
      // the half_t kernel unchecked. This also matches the fp8 branch below.
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    // Every non-int8 input must be fp8 e4m3.
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}
163
164
165
166
167

// Scaled GEMM entry point for the sm89 path.
//
// Checks that both scale tensors are float32, then selects the epilogue:
// ScaledEpilogueBias when a bias is supplied (its dtype must equal the output
// dtype), ScaledEpilogue otherwise.
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias: plain scaled epilogue.
  if (!bias.has_value()) {
    cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
    return;
  }

  // Bias present: it has to share the output dtype.
  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
181
182
183
184
185
186

// Scaled GEMM entry point with an `azp` adjustment for the sm89 path.
//
// Checks that both scale tensors are float32, then selects the epilogue:
// ScaledEpilogueBiasAzpToken when `azp` is supplied (it is passed along with
// `azp_adj`), ScaledEpilogueBiasAzp otherwise (only `azp_adj` is passed).
// The optional `bias` is forwarded as-is in both cases.
// NOTE(review): `azp` presumably denotes an activation zero point; confirm
// against the epilogue definitions.
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // Without an explicit azp tensor only the precomputed adjustment is used.
  if (!azp.has_value()) {
    cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}