machete_mm_kernel.cuh 9.58 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on

#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
#include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh"
#include "machete_interleaving_utils.cuh"

namespace machete {

using namespace cute;

// NOTE This kernel computes D = alpha * A * B + beta * C by computing
//   D^t = alpha * B^t * A^t + beta * C^t. This is because the wgmma
//   instructions only support sourcing from registers for the left-hand
//   operand, and we want to upconvert/decompress the quantized operand in
//   registers. Since the primary use case we want to support is Y = XW^t,
//   where W is quantized, our right-hand operand is the quantized one, so
//   we compute the transpose to move it to the left-hand side.
// Bundles every compile-time type needed to instantiate one Machete GEMM
// (mainloop, epilogue, kernel and device adapter) together with the host-side
// helpers used to build its arguments and launch it.
//
// Template parameters:
//   ElementA_        element type of A; also used as the MMA type
//   ElementB_        element type of the prepacked B operand
//   ElementD_        element type of the output D (and of C when with_C)
//   AccumulatorT     accumulation type, also used as the epilogue compute type
//   ScaleT / ZeroT   element types of the scales / zero-points
//   KernelSchedule   schedule forwarded to the mainloop collective builder
//   ScheduleConfig   supplies TileShapeNM, ClusterShape, EpilogueSchedule,
//                    EpilogueTileType and TileScheduler
//   with_C           whether a C operand is supported
//   with_scales      whether scales are passed alongside B
//   with_zeropoints  whether zero-points are passed (requires with_scales)
template <typename ElementA_, typename ElementB_, typename ElementD_,
          typename AccumulatorT, typename ScaleT, typename ZeroT,
          class KernelSchedule, typename ScheduleConfig, bool with_C,
          bool with_scales, bool with_zeropoints>
struct MacheteKernelTemplate {
  using MmaType = ElementA_;
  using ElementA = ElementA_;
  using ElementB = ElementB_;
  using ElementD = ElementD_;
  // C shares D's element type; `void` when there is no C operand.
  using ElementC = cute::conditional_t<with_C, ElementD, void>;
  using ElementZ = ZeroT;
  using ElementS = ScaleT;

  using ElementAccumulator =
      AccumulatorT;  // Element type for internal accumulation
  using ElementCompute = AccumulatorT;  // For Epilogue

  // B is described to the collective builder either as a bare element type or
  // as a tuple that additionally carries the scale (and zero-point) types.
  using BTypeTuple = cute::conditional_t<
      with_scales,
      cute::conditional_t<with_zeropoints,
                          cute::tuple<ElementB, ElementS, ElementZ>,
                          cute::tuple<ElementB, ElementS>>,
      ElementB>;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = LayoutC;
  using LayoutScale = cutlass::layout::RowMajor;
  // not actually used since B has the prepacked layout, but required by cutlass
  using _LayoutB = cutlass::layout::ColumnMajor;

  // Interface strides expected by create_arguments (will get transposed)
  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
  using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
  using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
  using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
  using StrideZ = StrideS;

  // Layout tags of the transposed operands handed to the collective builders.
  using LayoutA_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  using LayoutC_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutC>::type;
  using LayoutD_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutD>::type;

  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  using PrepackedLayoutB =
      PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT,
                               LayoutA_Transpose, KernelSchedule>;

  // K extent of a tile: the number of MmaType elements in 128 bytes.
  static int constexpr TileShapeK =
      128 * 8 / cutlass::sizeof_bits<MmaType>::value;
  // Alignments are expressed in elements: how many elements fit in 128 bits.
  static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
  static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
  static int constexpr AlignmentC =
      (with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
  static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;

  // Full tile shape: the schedule's (N, M) tile with TileShapeK appended.
  using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
                                    cute::Int<TileShapeK>{}));
  using ClusterShape = typename ScheduleConfig::ClusterShape;
  using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
  using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
  using TileScheduler = typename ScheduleConfig::TileScheduler;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
          ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose,
          AlignmentC, ElementD, LayoutD_Transpose, AlignmentD,
          EpilogueSchedule>::CollectiveOp;

  // The (prepacked) B operand is passed in the builder's first operand slot
  // and A^t in the second, matching the transposed formulation of the
  // problem. The stage count is chosen automatically after carving out the
  // epilogue's shared storage.
  using CollectiveMainloop =
      typename cutlass::gemm::collective::VLLMCollectiveBuilder<
          cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
          BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
          AlignmentA, ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // stride_B is unused (since B is prepacked), but still required by cutlass
  using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;

  using Arguments = typename Gemm::Arguments;
  using MainloopArguments = typename GemmKernel::MainloopArguments;
  using EpilogueArguments = typename GemmKernel::EpilogueArguments;

  // Validates the operands and packs them into the CUTLASS `Arguments` for
  // this kernel.
  //
  // Parameters:
  //   stream            not referenced by this function
  //   A_ptr, layout_A   A, an MxK matrix
  //   B_ptr             B, a KxN prepacked matrix (no layout is needed)
  //   D_ptr, layout_D   D, an MxN matrix
  //   C_ptr, layout_C   C, an MxN matrix; required iff with_C
  //   S_ptr, layout_S   scales, a scale_k x N matrix; required iff with_scales
  //   Z_ptr, layout_Z   zero-points, a scale_k x N matrix with the same layout
  //                     as the scales; required iff with_zeropoints
  //   alpha, beta       scalars forwarded to the epilogue
  //   maybe_group_size  quantization group size along K; -1 or nullopt selects
  //                     a single group spanning all of K
  //
  // Here scale_k = ceil(K / group_size).
  //
  // Raises (via TORCH_CHECK) if the group size is not positive, if shapes are
  // inconsistent, or if an operand is supplied/missing contrary to the
  // with_C / with_scales / with_zeropoints configuration.
  template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
            typename ShapeZ>
  static Arguments create_arguments(
      cudaStream_t stream,
      ElementA const* A_ptr,  // A is an MxK matrix
      Layout<ShapeA, StrideA> const& layout_A,
      ElementB const* B_ptr,  // B is an KxN prepacked matrix
      ElementD* D_ptr,        // D is an MxN matrix
      Layout<ShapeD, StrideD> const& layout_D,
      ElementC const* C_ptr,  // C is an MxN matrix
      std::optional<Layout<ShapeC, StrideC>> const& layout_C,
      ElementS const* S_ptr,  // S is an scale_KxN matrix
      std::optional<Layout<ShapeS, StrideS>> const& layout_S,
      ElementZ const* Z_ptr,  // Z is an scale_KxN matrix
      std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
      ElementCompute alpha, ElementCompute beta,
      std::optional<int> maybe_group_size) {
    static_assert(!with_zeropoints || with_scales);

    int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);

    // -1 and nullopt both mean "one group covering the whole K dimension".
    int const group_size =
        maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
    // Guard the ceil-division below: group_size is 0 if the caller passes 0,
    // or if K == 0 and the group size defaults to K.
    TORCH_CHECK(group_size > 0,
                "Machete: group_size must be positive (or -1 / unset to span "
                "all of K), got group_size = ",
                group_size, " with K = ", K);
    int const scale_k = (K + group_size - 1) / group_size;

    // M and K are read from layout_A and N from layout_D, so the effective
    // constraint here is that layout_D has M rows.
    TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
    TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);

    if constexpr (with_C) {
      TORCH_CHECK(C_ptr && layout_C);
    } else {
      TORCH_CHECK(!C_ptr, "C not supported");
    }

    if constexpr (with_scales) {
      TORCH_CHECK(S_ptr && layout_S);
      TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
    } else {
      TORCH_CHECK(!S_ptr, "Scales not supported");
    }

    if constexpr (with_zeropoints) {
      TORCH_CHECK(Z_ptr && layout_Z);
      TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
      TORCH_CHECK(layout_S && *layout_Z == *layout_S,
                  "Scales and zeros must have the same layout");
    } else {
      TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
    }

    // Transpose A and D
    // A doesn't need to be transposed since cutlass expects a NxK matrix
    //  for B (which is At)
    auto stride_At = layout_A.stride();
    auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
    // Without a C layout, fall back to D's transposed stride.
    auto stride_Ct = stride_Dt;
    if (layout_C) {
      stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
    }

    MainloopArguments mainloop_arguments{};
    EpilogueArguments epilogue_arguments{
        {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};

    // The mainloop argument list grows with the enabled quantization
    // parameters: scales add (S_ptr, stride_S, group_size), zero-points add
    // Z_ptr on top of that.
    if constexpr (with_scales && with_zeropoints) {
      auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
      mainloop_arguments =
          MainloopArguments{B_ptr, _StrideB{}, A_ptr,      stride_At,
                            S_ptr, stride_S,   group_size, Z_ptr};
    } else if constexpr (with_scales) {
      auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
      mainloop_arguments = MainloopArguments{
          B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
    } else {
      mainloop_arguments =
          MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
    }

    // The problem shape is given as (N, M, K, 1): M and N are swapped because
    // the kernel computes the transposed problem.
    return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
                     {N, M, K, 1},
                     mainloop_arguments,
                     epilogue_arguments};
  }

  // Returns the workspace size, as reported by the CUTLASS device adapter,
  // for launching with `args`.
  static size_t get_workspace_size(Arguments const& args) {
    return Gemm::get_workspace_size(args);
  }

  // Returns true iff the CUTLASS device adapter reports that `args` can be
  // run by this kernel instantiation.
  static bool can_implement(Arguments const& args) {
    return Gemm::can_implement(args) == cutlass::Status::kSuccess;
  }

  // Initializes the GEMM with `args` and `workspace`, then runs it on
  // `stream`. Raises (via TORCH_CHECK) if either step does not return
  // cutlass::Status::kSuccess; the CUTLASS status string is appended to the
  // error message.
  static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
    Gemm gemm_op;

    cutlass::Status status = gemm_op.initialize(args, workspace, stream);
    TORCH_CHECK(status == cutlass::Status::kSuccess,
                "Machete kernel failed to initialize workspace: ",
                cutlassGetStatusString(status));

    status = gemm_op.run(stream);
    TORCH_CHECK(status == cutlass::Status::kSuccess,
                "Machete kernel failed: ", cutlassGetStatusString(status));
  }
};

};  // namespace machete