#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

// clang-format off
// The cutlass include order matters (annoyingly)

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on

#include "cutlass_extensions/cute_utils.cuh"
#include "machete_collective_builder.cuh"
#include "machete_interleaving_utils.cuh"

namespace machete {

using namespace cute;

struct IlvBlkLayoutAuto {};

// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the
// same shape as the prepacked block, all the data for a given thread is
// contiguous in memory. This allows us to use wider shared memory loads when
// loading B from shared memory. The values within a thread are also
// potentially interleaved in order to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel (this is also why the other element types are
// required along with the kernel schedule).
// clang-format off
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
          typename AccumulatorT, class LayoutB, class KernelSchedule,
          typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
struct PrepackedLayoutBTemplate {
  using MmaType = ElementA_;
  using ElementA = ElementA_;
  using ElementB = ElementB_;
  using ElementAccumulator = AccumulatorT;
  using ElementMma = MmaType;

  // Interleave for 4-bit types when we are not upconverting to fp8 or int8;
  // in those cases we use a LUT via prmt instructions to upconvert, which is
  // more efficient if the data is not interleaved. For 8-bit and wider types,
  // prmt instructions make non-interleaved layouts efficient enough that we
  // don't need interleaved layouts (and can reuse more of the existing
  // cutlass converters).
  static constexpr bool should_interleave =
      sizeof_bits_v<ElementB> <= 4 &&
      !std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
      !std::is_same_v<ElementConvert_, int8_t>;
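
  // E.g. (illustrative): a 4-bit ElementB upconverted to half_t would be
  // interleaved, while the same ElementB upconverted via float_e4m3_t or
  // int8_t (the prmt/LUT path) would not be.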

  // Only use interleaved layouts for subbyte weights; otherwise IlvdBlkLayout
  // is void.
  using IlvdBlkLayout = std::conditional_t<
      std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
      std::conditional_t<
          should_interleave,
          decltype(get_interleaved_blk_layout<
                   ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
          void>,
      IlvBlkLayout_>;

  // TODO (LucasWilkinson): compare the performance for other sizes
  // Prepacked block shape: the smallest layout atom for loading into
  //   registers (can contain multiple wgmma instructions' worth of data in
  //   one block).
  // We ideally want this configured such that a thread can perform 128-bit
  // loads, i.e. the amount of data associated with each thread within a
  // prepacked block is a multiple of 128 bits. When using a cooperative
  // schedule we have 256 threads working on a single block at a time, so each
  // thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data;
  // for a 4-bit type this is 128 bits.
  using PPBlockShape_NK = Shape<_128, _64>;
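
  // Worked example (illustrative): a (128, 64) block holds 128 * 64 = 8192
  // elements. For a 4-bit ElementB with 256 cooperative threads this gives
  //   4 * 8192 / 256 = 128 bits per thread (one 128-bit load);
  // an 8-bit ElementB would give 256 bits per thread (two 128-bit loads).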

  // Create the shape of the tile anticipated to be used by the GEMM kernel;
  //  when the kernel executes we will compute `Ct = Bt * At` since the
  //  quantized weights (B) must be the lhs operand so that they flow through
  //  registers.
  // The _128 here doesn't actually impact the shape of the stored tile
  //  directly, but may impact the op selected by rs_op_selector.
  using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
                                            size<1>(PPBlockShape_NK{})));

  static constexpr cute::GMMA::Major GmmaMajorB =
      gmma_rs_tag_to_major_B<LayoutB>();

  // For coop schedules we have two warp groups cooperatively issuing wgmma
  // instructions so we use 2 atoms along the M dim (one for each warpgroup)
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
                                 GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
      AtomLayoutMNK{}));
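
  // Note: with the cooperative schedule the (2, 1, 1) atom layout spans two
  // warpgroups (2 * 128 = 256 threads), matching the 256-thread assumption
  // used to size PPBlockShape_NK above.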

  // Prepacked block, (athrid, val) -> (N,K)
  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
    return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
  }

  // Prepacked block, (N,K) -> (athrid, val)
  // i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
  CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
    return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
  }
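
  // Development aid (hypothetical, not used by the kernel): these layouts can
  // be inspected on the host with cute::print, assuming a concrete
  // instantiation `PPLayout` of this template:
  //
  //   cute::print(PPLayout::ppblock_TV_to_NK());  // (athrid, val) -> (N, K)
  //   cute::print(PPLayout::ppblock_NK_to_TV());  // (N, K) -> (athrid, val)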

  // Prepacked block, (athrid, val) -> (storage_offset)
  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
    // Return a compact layout in which each thread's values are contiguous in
    // storage (enabling the wider shared memory loads described above)
    return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
  }

  // Prepacked block, (athrid, val) -> (storage_offset)
  // i.e. ((ThrV,(ThrN,ThrK)),(IlvdFrgV,(RestN,RestK,...))) -> (storage_idx)
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
    auto layout_no_interleave =
        make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});

    if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
      return layout_no_interleave;
    } else {
      // interleave by transforming FrgV into interleaved blocks where each
      // block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is
      // (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4)
      //   if FrgV is {A, B, C, D, E, F, G, H}
      //   then ((IlvBlk), FrgV) is {A, C, B, D, E, G, F, H}
      auto frgV = get<1, 0>(layout_no_interleave);
      auto ilvdBlk = IlvdBlkLayout{};
      static_assert(size(frgV) % size(ilvdBlk) == 0,
                    "FrgV must be divisible by size(ilvdBlk)");
      auto ilvd_FrgV = make_layout(
          make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
          make_stride(stride(ilvdBlk), size(ilvdBlk)));

      // Return the interleaved layout
      return make_layout(
          get<0>(layout_no_interleave),
          make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
    }
  }

  // Prepacked block, (N,K) -> (storage_offset)
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
    // do (N,K) -> (athrid, val) -> (storage_idx)
    return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
  }

  // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
      Shape_NKL shape_mkl) {
    constexpr auto block_layout = ppblock_TV_to_offset();

    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((athrid, val), (BlocksN, BlocksK, L))
    //   => ((athrid, val), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }
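
  // Worked example (illustrative): for shape_mkl = (256, 128, 1) with
  // PPBlockShape_NK = (128, 64), blocks_shape = (2, 2, 1); each prepacked
  // block occupies size(block_layout) = 8192 contiguous element slots, with
  // blocks ordered column-major over (BlocksN, BlocksK, L).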

  // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
      Shape_NKL shape_mkl) {
    auto layout = TVbNbKL_to_offset(shape_mkl);
    // for 4-bit elements, having >= 64 values per column
    // allows TMA to load full 32-byte sectors
    auto inner_layout =
        make_layout(make_shape(_256{}, size<0>(layout) / _256{}));

    return make_layout(inner_layout, get<1>(layout), get<2>(layout));
  }
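
  // Arithmetic note (illustrative): 64 4-bit values = 32 bytes, exactly one
  // TMA sector, so the inner mode of 256 values (128 bytes at 4 bits per
  // element) always covers a whole number of sectors.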

  // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
      Shape_NKL shape_mkl) {
    constexpr auto block_layout = ppblock_ilvd_NK_to_offset();

    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (storage_idx)
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((BlockN, BlockK), (BlocksN, BlocksK, L))
    //   => ((BlockN, BlockK), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }

  // (BlocksN, BlocksK, L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });
    auto stride = size(PPBlockShape_NK{});

    // (BlocksN, BlocksK, L) -> (storage_idx)
    return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
  }
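
  // E.g. (illustrative): with PPBlockShape_NK = (128, 64) the stride is
  // size(PPBlockShape_NK{}) = 8192, i.e. each block coordinate maps to the
  // starting element offset of its prepacked block.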

  // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
  template <class Shape_NKL>
  CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
    auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
                          make_layout(size<1>(PPBlockShape_NK{})));

    // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
    auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
    return tiled_A.compose(ppblock_TV_to_NK(), _);
  }

  // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
  template <class Shape_NKL>
  CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
    auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
    return blocked_product(ppblock_NK_to_TV(),
                           make_layout(shape<1>(TVbNbK_to_NKL_layout)));
  }
};
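
// A rough host-side sketch (hypothetical, for illustration only; everything
// except the member functions above is an assumption) of how these layouts
// compose when prepacking B, given a concrete instantiation
// `PrepackedLayoutB` of the template:
//
//   auto shape_NKL = cute::make_shape(N, K, L);
//   // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
//   auto nkl_to_tv = PrepackedLayoutB::NKL_to_TVbNbK(shape_NKL);
//   // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
//   auto tv_to_off = PrepackedLayoutB::TVbNbKL_to_offset(shape_NKL);
//   // so, elementwise: B_prepacked[tv_to_off(nkl_to_tv(coord))] = B(coord)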

}  // namespace machete