// gemm_sm80.h
#pragma once

#include <cute/algorithm/clear.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/underscore.hpp>

#include "common.h"

namespace cute {

/// Maps an (A, B, C) element-type triple plus a warp grid to an MMA
/// configuration. Only declared here: every supported combination is a
/// partial specialization exposing
///   - MMA:       the MMA_Atom to issue, and
///   - MMA_Group: the Tile handed to TiledMMA as its last argument.
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n>
struct DispatchInstruction;

// Short alias for Underscore; used in the MMA_Group tiles below for the modes
// that are not given an explicit Int<> size.
using _X = Underscore;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> {
21
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
22
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
23
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
26
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
27
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
28
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n> {
32
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
33
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
34
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n> {
38
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
39
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
40
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n> {
43
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
44
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
45
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> {
48
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
49
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
50
51
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
54
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
55
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>;
56
57
58
59
60
61
62
};
#endif

template <int Bits, int N, int K, bool K_inner, typename Enable = void>
struct OperandTraits {
  // Primary template, use padded layout and default copy
  static constexpr int stride = K_inner ? K : N;
63
64
65
66
67
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
68
69
70
71
  using Copy = DefaultCopy;
};

template <int N, int K>
72
73
74
75
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
76
77
78
79
80
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
81
82
83
84
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
85
86
87
88
89
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
90
91
92
93
94
95
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
96
97
98
99
  using Copy = SM75_U16x8_LDSM_T;
};

template <int N, int K>
100
101
102
103
104
105
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
106
107
108
109
  using Copy = SM75_U16x8_LDSM_T;
};

template <int N, int K>
110
111
112
113
struct OperandTraits<32, N, K, true,
                     typename std::enable_if<K % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
114
115
116
117
118
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
119
120
121
122
struct OperandTraits<32, N, K, true,
                     typename std::enable_if<K % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
123
124
125
126
127
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
128
129
130
131
132
133
struct OperandTraits<32, N, K, false,
                     typename std::enable_if<N % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
134
135
136
137
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K>
138
139
140
141
142
143
struct OperandTraits<32, N, K, false,
                     typename std::enable_if<N % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
144
145
146
147
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K>
148
149
150
151
struct OperandTraits<8, N, K, true,
                     typename std::enable_if<K % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
152
153
154
155
156
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
157
158
159
160
struct OperandTraits<8, N, K, true,
                     typename std::enable_if<K % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
161
162
163
164
165
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
166
167
168
169
struct OperandTraits<64, N, K, true,
                     typename std::enable_if<K % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
170
171
172
173
174
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

template <int N, int K>
175
176
177
178
179
180
struct OperandTraits<64, N, K, false,
                     typename std::enable_if<N % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
181
182
183
  using Copy = DefaultCopy;
};

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
185
186
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
187
class GemmTensorOp {
188
189
190
191
192
193
194
public:
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
195
  using C_type = C_type_raw;
196
197
  using Instruction =
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>;
198

199
200
201
202
  using OperandATraits =
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
  using OperandBTraits =
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
203
204
205
206
207
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

208
209
210
  using TileMma = TiledMMA<typename Instruction::MMA,
                           Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
                           typename Instruction::MMA_Group>;
211
212

  template <class... Args>
213
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
214
215
216
217
218
    return layout;
  }
  // In fp16, when layout is KxN and n_warp is 1 and N % 64 == 0
  // the original layout fail to compile, currently using this as a workaround
  template <class... Args>
219
220
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
221
222
223
224
225
226
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

227
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
228
    const int tid = threadIdx.x;
229
230
231
232
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

248
249
250
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
251

252
253
254
    if constexpr (clear_accum) {
      clear(acc);
    }
255
256
    // when layout is KxN and n_warp is 1, there seem to be a bug, use this as a
    // workaround
257
258
259
260
261
262
263
264
265
266
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

267
268
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
269
    const int tid = threadIdx.x;
270
271
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
272
273
274
275
276
277
278
279
280
281
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

282
283
284
285
286
287
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
288

289
290
291
    if constexpr (clear_accum) {
      clear(acc);
    }
292
293
294
295
296
297
298
299
300
301
302
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

303
304
  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
305
    const int tid = threadIdx.x;
306
307
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
308
309
310
311
312
313
314
315
316
317
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

318
319
320
321
322
323
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
324

325
326
327
    if constexpr (clear_accum) {
      clear(acc);
    }
328
329
330
331
332
333
334
335
336
337
338
339
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace cute

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
345
346
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
347
348
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
349
                                 trans_B, clear_accum, A_type, B_type, C_type>;
350
351
352
  MMA::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
354
355
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
356
357
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
358
                                 trans_B, clear_accum, A_type, B_type, C_type>;
359
360
361
  MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
363
364
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
365
366
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
367
                                 trans_B, clear_accum, A_type, B_type, C_type>;
368
369
370
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl