// gemm_sm80.h -- SM80 (Ampere) tensor-core GEMM primitives built on CuTe.
#pragma once

#include <cute/algorithm/clear.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/underscore.hpp>

#include "common.h"

namespace cute {

12
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
13
          int num_warp_n, int N>
14
15
struct DispatchInstruction;

16
17
using _X = Underscore;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
// SM80 (Ampere) dispatch. The N extent of the MMA_Group is capped at the
// tile width N so narrow tiles do not request a wider group than they cover.

// fp16 inputs, fp16 accumulate.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// fp16 inputs, fp32 accumulate.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// bf16 inputs, fp32 accumulate.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// tf32 inputs, fp32 accumulate (K depth is 8 for tf32).
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// int8 inputs, int32 accumulate.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// fp64: the 8x8x4 atom is small, so the group spans the full warp tile in
// both M and N.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
// SM75 (Turing) fallback: fp16 inputs, fp32 accumulate, K tiled by 16 to
// cover the 16x8x8 atom twice along K.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
};
#endif

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
template <int N, int num_warp_n, bool transpose> struct SelectCopy {
  static constexpr int remainder = (N / num_warp_n) % 16;
  using type = std::conditional_t<
      remainder == 4 || remainder == 8 || remainder == 0,
      std::conditional_t<
          transpose,
          std::conditional_t<
              remainder == 4, SM75_U32x1_LDSM_N,
              std::conditional_t<remainder == 8, SM75_U32x2_LDSM_N,
                                 SM75_U32x4_LDSM_N>>,
          std::conditional_t<
              remainder == 4, SM75_U16x2_LDSM_T,
              std::conditional_t<remainder == 8, SM75_U16x4_LDSM_T,
                                 SM75_U16x8_LDSM_T>>>,
      DefaultCopy>;
};

76
77
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
          typename Enable = void>
78
79
80
struct OperandTraits {
  // Primary template, use padded layout and default copy
  static constexpr int stride = K_inner ? K : N;
81
82
83
84
85
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
86
87
88
  using Copy = DefaultCopy;
};

89
90
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
91
92
93
                     typename std::enable_if<K % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
94
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
95
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
96
97
};

98
99
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
100
101
102
                     typename std::enable_if<K % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
103
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
104
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
105
106
};

107
108
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
109
110
111
112
113
                     typename std::enable_if<N % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
114
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
115
116
};

117
118
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
119
120
121
122
123
                     typename std::enable_if<N % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
124
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
125
126
};

127
128
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
129
130
131
                     typename std::enable_if<K % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
132
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
133
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
134
135
};

136
137
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
138
139
140
                     typename std::enable_if<K % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
141
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
142
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
143
144
};

145
146
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
147
148
149
150
151
                     typename std::enable_if<N % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
152
153
154
  using Copy = UniversalCopy<tfloat32_t>;
};

155
156
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
157
158
159
160
161
                     typename std::enable_if<N % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
162
163
164
  using Copy = UniversalCopy<tfloat32_t>;
};

165
166
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
167
168
169
                     typename std::enable_if<K % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
170
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
171
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
172
173
};

174
175
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
176
177
178
                     typename std::enable_if<K % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
179
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
180
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
181
182
};

183
184
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
185
186
187
                     typename std::enable_if<K % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
188
189
190
191
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

192
193
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
194
195
196
197
198
                     typename std::enable_if<N % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
199
200
201
  using Copy = DefaultCopy;
};

202
203
204
205
206
207
208
209
template <typename T> CUTE_HOST_DEVICE static void cast_float_to_tf32(T &a) {
  uint32_t x = reinterpret_cast<uint32_t const &>(a);
  if (std::isfinite(a)) {
    x += 0x1000u;
  }
  a = tfloat32_t::bitcast(x);
};

210
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
211
212
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
213
class GemmTensorOp {
214
215
216
217
218
219
220
public:
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
221
  using C_type = C_type_raw;
222
223
224
225
226

  static constexpr bool need_tfloat32_cast =
      std::is_same<A_type_raw, float>::value &&
      std::is_same<B_type_raw, float>::value;

227
  using Instruction =
228
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
229

230
  using OperandATraits =
231
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
232
  using OperandBTraits =
233
234
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;

235
236
237
238
239
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

240
241
242
  using TileMma = TiledMMA<typename Instruction::MMA,
                           Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
                           typename Instruction::MMA_Group>;
243
244

  template <class... Args>
245
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
246
247
248
249
250
    return layout;
  }
  // In fp16, when layout is KxN and n_warp is 1 and N % 64 == 0
  // the original layout fail to compile, currently using this as a workaround
  template <class... Args>
251
252
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
253
254
255
256
257
258
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

259
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
260
    const int tid = threadIdx.x;
261
262
263
264
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

280
281
282
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
283

284
285
286
    if constexpr (clear_accum) {
      clear(acc);
    }
287
288
    // when layout is KxN and n_warp is 1, there seem to be a bug, use this as a
    // workaround
289
290
291
292
293
294
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
295
296
297
298
      if constexpr (need_tfloat32_cast) {
        cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
        cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
      }
299
300
301
302
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

303
304
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
305
    const int tid = threadIdx.x;
306
307
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
308
309
310
311
312
313
314
315
316
317
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

318
319
320
321
322
323
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
324
325
326
    if constexpr (need_tfloat32_cast) {
      cute::for_each(tCrA, cast_float_to_tf32<A_type>);
    }
327
328
329
    if constexpr (clear_accum) {
      clear(acc);
    }
330
331
332
333
334
335
336
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
337
338
339
      if constexpr (need_tfloat32_cast) {
        cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
      }
340
341
342
343
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

344
345
  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
346
    const int tid = threadIdx.x;
347
348
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
349
350
351
352
353
354
355
356
357
358
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

359
360
361
362
363
364
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
365
366
367
    if constexpr (need_tfloat32_cast) {
      cute::for_each(tCrB, cast_float_to_tf32<B_type>);
    }
368
369
370
    if constexpr (clear_accum) {
      clear(acc);
    }
371
372
373
374
375
376
377
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
378
379
380
      if constexpr (need_tfloat32_cast) {
        cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
      }
381
382
383
384
385
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace cute

namespace tl {

390
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
391
392
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
393
394
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
395
                                 trans_B, clear_accum, A_type, B_type, C_type>;
396
397
398
  MMA::body(pA, pB, accum);
}

399
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
400
401
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
402
403
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
404
                                 trans_B, clear_accum, A_type, B_type, C_type>;
405
406
407
  MMA::body_rs(pA, pB, accum);
}

408
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
409
410
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
411
412
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
413
                                 trans_B, clear_accum, A_type, B_type, C_type>;
414
415
416
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl