// gemm_sm90.h
#pragma once

3
#include <cute/arch/mma_sm80.hpp>
4
5
#include <cute/arch/mma_sm90.hpp>
#include <cute/atom/mma_atom.hpp>
6
7
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>
8
#include <cutlass/gemm/collective/collective_builder.hpp>
9
10
11
12
13

#include "common.h"

namespace cute {

14
15
using namespace SM90;

16
namespace tl_wgmma {
17
18

using namespace cutlass::gemm::collective::detail; // ss_smem_selector
19

20
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
21
22
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
23
class GemmTensorOp {
24
25
26
27
28
public:
  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
                               tfloat32_t, A_type_raw>;
  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
                               tfloat32_t, B_type_raw>;
29
30
  using C_type = C_type_raw;

31
32
33
34
  static constexpr GMMA::Major GmmaMajorA =
      trans_A ? GMMA::Major::MN : GMMA::Major::K;
  static constexpr GMMA::Major GmmaMajorB =
      trans_B ? GMMA::Major::K : GMMA::Major::MN;
35

36
  using SmemLayoutAtomA =
37
38
      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M / (num_warp_m / 4)>,
                                Int<K>>());
39
  using SmemLayoutAtomB =
40
41
      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N / num_warp_n>,
                                Int<K>>());
42

43
44
45
46
47
48
  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
      conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
      conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));
49

50
  static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");
51

52
53
  template <int wg_wait = 0>
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
54
    const int tid = threadIdx.x;
55
56
57
58
59
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
60
61
62
63
        GMMA::ss_op_selector<
            A_type, B_type, C_type,
            Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
            GmmaMajorA, GmmaMajorB>(),
64
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
65
66
67
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
68
69
    Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
70

71
72
    Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE)
73

74
75
76
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
77
78
79

    warpgroup_fence_operand(acc);
    warpgroup_arrive();
80
81
82
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
83
84
85
86
87
88
89
90
91
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    warpgroup_commit_batch();
92
93
94
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
95
96
97
98
99
100
101
102
103
104
105
    warpgroup_fence_operand(acc);
    // warpgroup_fence_operand(acc);
    // warpgroup_arrive();

    // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);

    // warpgroup_commit_batch();
    // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    // warpgroup_fence_operand(acc);
  }

106
107
108
  template <int wg_wait = 0>
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
109
    // TODO: Move bar.sync out of body_rs
110
111
    // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n *
    // 32));
112
    const int tid = threadIdx.x;
113
114
115
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
116
117
118
119
        GMMA::rs_op_selector<
            A_type, B_type, C_type,
            Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
            GmmaMajorA, GmmaMajorB>(),
120
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
121
122
123
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
124
125
126
127
128
129
130
131
    Tensor tCsB = thr_mma.partition_B(sB);       // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE)
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
132
133
134
135

    warpgroup_fence_operand(tCrA);
    warpgroup_fence_operand(acc);
    warpgroup_arrive();
136
137
138
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
139
140
141
142
143
144
145
146
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();
147
148
149
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
150
151
152
153
154
155
156
157
158
    warpgroup_fence_operand(acc);
    warpgroup_fence_operand(tCrA);

    // warpgroup_fence_operand(acc);
    // warpgroup_arrive();

    // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);

    // warpgroup_commit_batch();
159

160
161
162
163
164
    // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    // warpgroup_fence_operand(acc);
  }
};

165
166
167
168
169
} // namespace tl_wgmma

namespace tl_mma {

// Compile-time table that maps an (A, B, C) element-type triple to
//   MMA       - the MMA atom to issue, and
//   MMA_Group - the tile handed to TiledMMA as its third template argument.
// Only the declaration lives here; each supported combination is a partial
// specialization below, selected by the target architecture list.
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n, int N>
struct DispatchInstruction;

// Shorthand for "leave this mode untouched" in the Tile<> arguments.
using _X = Underscore;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))

// fp16 x fp16 -> fp16
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
};

// fp16 x fp16 -> fp32
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
};

// bf16 x bf16 -> fp32
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
};

// tf32 x tf32 -> fp32
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
};

// int8 x int8 -> int32
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
};

// fp64 x fp64 -> fp64; unlike the others this one also fixes the M extent of
// the group tile.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
};

#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))

// fp16 x fp16 -> fp32 using the SM75 atom; the K extent of the group tile is
// pinned to 16.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
};

#endif

216
217
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
          typename Enable = void>
218
219
220
221
222
223
224
225
226
227
228
struct OperandTraits {
  // Primary template, use padded layout and default copy
  static constexpr int stride = K_inner ? K : N;
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
  using Copy = DefaultCopy;
};

229
230
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
231
232
233
234
                     typename std::enable_if<K % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
235
236
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
237
238
};

239
240
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
241
242
243
244
                     typename std::enable_if<K % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
245
246
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
247
248
};

249
250
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
251
252
253
254
255
                     typename std::enable_if<N % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
256
257
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
                                         SM75_U16x8_LDSM_T>::type;
258
259
};

260
261
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
262
263
264
265
266
                     typename std::enable_if<N % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
267
268
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U16x4_LDSM_T,
                                         SM75_U16x8_LDSM_T>::type;
269
270
};

271
272
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
273
274
275
276
                     typename std::enable_if<K % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
277
278
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
279
280
};

281
282
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
283
284
285
286
                     typename std::enable_if<K % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
287
288
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
289
290
};

291
292
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
293
294
295
296
297
298
299
300
                     typename std::enable_if<N % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

301
302
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
303
304
305
306
307
308
309
310
                     typename std::enable_if<N % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

311
312
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
313
314
315
316
                     typename std::enable_if<K % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
317
318
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
319
320
};

321
322
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
323
324
325
326
                     typename std::enable_if<K % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
327
328
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
329
330
};

331
332
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
333
334
335
336
337
338
339
                     typename std::enable_if<K % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

340
341
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
342
343
344
345
346
347
348
349
350
                     typename std::enable_if<N % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = DefaultCopy;
};

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
351
352
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
353
354
355
356
357
358
359
360
361
362
class GemmTensorOp {
public:
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using C_type = C_type_raw;
  using Instruction =
363
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
364
365

  using OperandATraits =
366
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
367
  using OperandBTraits =
368
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

  using TileMma = TiledMMA<typename Instruction::MMA,
                           Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
                           typename Instruction::MMA_Group>;

  template <class... Args>
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
    return layout;
  }
  // In fp16, when layout is KxN and n_warp is 1 and N % 64 == 0
  // the original layout fail to compile, currently using this as a workaround
  template <class... Args>
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    // when layout is KxN and n_warp is 1, there seem to be a bug, use this as a
    // workaround
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
422
423
424
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));

    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
456
457
458
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));

    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
492
493
494
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
495
496
497
498
499
500
501
502
503
504
505
506
507
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace tl_mma

508
} // namespace cute
509
510
511

namespace tl {

512
513
namespace tl_mma {

514
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
515
516
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
517
518
519
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
520
                                 trans_B, clear_accum, A_type, B_type, C_type>;
521
  MMA::body(pA, pB, accum);
522
523
}

524
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
525
526
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
527
528
529
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
530
                                 trans_B, clear_accum, A_type, B_type, C_type>;
531
532
533
534
  MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
535
536
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
537
538
539
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
540
                                 trans_B, clear_accum, A_type, B_type, C_type>;
541
542
543
544
545
546
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl_mma

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
547
548
          bool trans_B, bool clear_accum = false, bool use_wgmma = true,
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
549
550
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
551
552
553
    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                             trans_A, trans_B, clear_accum,
                                             A_type, B_type, C_type>;
554
555
    MMA::body<wg_wait>(pA, pB, accum);
  } else {
556
557
558
    using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                           trans_A, trans_B, clear_accum,
                                           A_type, B_type, C_type>;
559
560
561
562
563
    MMA::body(pA, pB, accum);
  }
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
564
565
          bool trans_B, bool clear_accum = false, bool use_wgmma = true,
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
566
567
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
568
569
570
    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                             trans_A, trans_B, clear_accum,
                                             A_type, B_type, C_type>;
571
572
    MMA::body_rs<wg_wait>(pA, pB, accum);
  } else {
573
574
575
    using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                           trans_A, trans_B, clear_accum,
                                           A_type, B_type, C_type>;
576
577
    MMA::body_rs(pA, pB, accum);
  }
578
579
}

580
// Thin wrapper that forwards num_mma to cute::warpgroup_wait<>.
template <int num_mma>
TL_DEVICE void wait_wgmma() {
  cute::warpgroup_wait<num_mma>();
}

584
585
586
// Syncs on the named barrier whose id is the caller's canonical warp-group
// index, using NumMmaThreads as the barrier's thread count.
template <int NumMmaThreads>
TL_DEVICE void warp_scheduler_barrier_sync() {
  const auto barrier_id = cutlass::canonical_warp_group_idx();
  cutlass::arch::NamedBarrier::sync(NumMmaThreads, barrier_id /*id*/);
}

589
// Arrives on the named barrier(s) belonging to the other warp group(s).
//
// With 256 threads the caller arrives on barrier (1 - own index). With 384
// threads it arrives on two barriers: own index + 1 and own index + 2, each
// reduced by 3 once it passes 2.
template <int NumMmaThreads>
TL_DEVICE void warp_scheduler_barrier_arrive() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  const auto wg_idx = cutlass::canonical_warp_group_idx();
  if constexpr (NumMmaThreads == 384) {
    const auto first_id = wg_idx <= 1 ? wg_idx + 1 : wg_idx + 1 - 3;
    const auto second_id = wg_idx <= 0 ? wg_idx + 2 : wg_idx + 2 - 3;
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, first_id /*id*/);
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, second_id /*id*/);
  } else {
    // The static_assert above leaves 256 as the only other possibility.
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, (1 - wg_idx) /*id*/);
  }
}

608
// Initial arrivals on the warp-scheduler named barriers: every warp group
// with index > 0 arrives on barrier 0, and with 384 threads every warp group
// with index > 1 additionally arrives on barrier 1.
template <int NumMmaThreads>
TL_DEVICE void mma_init() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  const auto wg_idx = cutlass::canonical_warp_group_idx();
  if (wg_idx > 0) {
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0);
  }
  if constexpr (NumMmaThreads == 384) {
    if (wg_idx > 1) {
      cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 1 /*id*/);
    }
  }
}
619
} // namespace tl