// gemm_sm90.h
#pragma once

3
#include <cute/arch/mma_sm80.hpp>
4
5
#include <cute/arch/mma_sm90.hpp>
#include <cute/atom/mma_atom.hpp>
6
7
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>
8
#include <cutlass/gemm/collective/collective_builder.hpp>
9
10
11
12
13

#include "common.h"

namespace cute {

14
15
using namespace SM90;

16
namespace tl_wgmma {
17
18

using namespace cutlass::gemm::collective::detail; // ss_smem_selector
19

20
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
21
22
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
23
class GemmTensorOp {
24
25
26
27
28
public:
  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
                               tfloat32_t, A_type_raw>;
  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
                               tfloat32_t, B_type_raw>;
29
30
  using C_type = C_type_raw;

31
32
33
34
  static constexpr GMMA::Major GmmaMajorA =
      trans_A ? GMMA::Major::MN : GMMA::Major::K;
  static constexpr GMMA::Major GmmaMajorB =
      trans_B ? GMMA::Major::K : GMMA::Major::MN;
35

36
  using SmemLayoutAtomA =
37
      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
38
  using SmemLayoutAtomB =
39
      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());
40

41
42
43
44
45
46
  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
      conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
      conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));
47

48
49
  static_assert(num_warp_m % 4 == 0,
                "num_warp_m must be a multiple of 4 for hopper wgmma");
50

51
52
  template <int wg_wait = 0>
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
53
    const int tid = threadIdx.x;
54
55
56
57
58
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
59
60
        GMMA::ss_op_selector<
            A_type, B_type, C_type,
61
            Shape<Int<4 * M / num_warp_m>, Int<N / num_warp_n>, Int<K>>,
62
            GmmaMajorA, GmmaMajorB>(),
63
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
64
65
66
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
67
68
    Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
69

70
71
    Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE)
72

73
74
75
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
76
77
78

    warpgroup_fence_operand(acc);
    warpgroup_arrive();
79
80
81
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
82
83
84
85
86
87
88
89
90
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    warpgroup_commit_batch();
91
92
93
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
94
95
96
    warpgroup_fence_operand(acc);
  }

97
98
99
  template <int wg_wait = 0>
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
100
    // TODO: Move bar.sync out of body_rs
101
102
    // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n *
    // 32));
103
    const int tid = threadIdx.x;
104
105
106
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
107
108
109
110
        GMMA::rs_op_selector<
            A_type, B_type, C_type,
            Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
            GmmaMajorA, GmmaMajorB>(),
111
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
112
113
114
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
115
116
117
118
119
120
121
122
    Tensor tCsB = thr_mma.partition_B(sB);       // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE)
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
123
124
125
126

    warpgroup_fence_operand(tCrA);
    warpgroup_fence_operand(acc);
    warpgroup_arrive();
127
128
129
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
130
131
132
133
134
135
136
137
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();
138
139
140
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
141
142
143
144
145
146
147
148
149
    warpgroup_fence_operand(acc);
    warpgroup_fence_operand(tCrA);

    // warpgroup_fence_operand(acc);
    // warpgroup_arrive();

    // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);

    // warpgroup_commit_batch();
150

151
152
153
154
155
    // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    // warpgroup_fence_operand(acc);
  }
};

156
157
158
159
160
} // namespace tl_wgmma

namespace tl_mma {

/// Maps an (A, B, C) element-type triple to an MMA atom and its tiling.
/// Only the specializations are defined; each provides:
///   MMA        the MMA_Atom to use;
///   MMA_Group  the Tile passed as the third TiledMMA argument.
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n, int N>
struct DispatchInstruction;

// Shorthand for cute::Underscore, used as a placeholder mode in the
// DispatchInstruction::MMA_Group tiles below.
using _X = Underscore;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
167
168
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
169
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
170
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
171
};
172
173
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
174
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
175
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
176
};
177
template <int num_warp_m, int num_warp_n, int N>
178
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
179
                           num_warp_n, N> {
180
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
181
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
182
};
183
template <int num_warp_m, int num_warp_n, int N>
184
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
185
                           num_warp_n, N> {
186
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
187
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
188
};
189
190
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
191
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
192
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
193
};
194
195
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
196
197
198
199
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
200
201
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
202
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
203
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
204
205
206
};
#endif

207
208
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
          typename Enable = void>
209
210
211
212
213
214
215
216
217
218
219
struct OperandTraits {
  // Primary template, use padded layout and default copy
  static constexpr int stride = K_inner ? K : N;
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
  using Copy = DefaultCopy;
};

220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
template <int N, int num_warp_n, bool transpose> struct SelectCopy {
  static constexpr int remainder = (N / num_warp_n) % 16;
  using type = std::conditional_t<
      remainder == 4 || remainder == 8 || remainder == 0,
      std::conditional_t<
          transpose,
          std::conditional_t<
              remainder == 4, SM75_U32x1_LDSM_N,
              std::conditional_t<remainder == 8, SM75_U32x2_LDSM_N,
                                 SM75_U32x4_LDSM_N>>,
          std::conditional_t<
              remainder == 4, SM75_U16x2_LDSM_T,
              std::conditional_t<remainder == 8, SM75_U16x4_LDSM_T,
                                 SM75_U16x8_LDSM_T>>>,
      DefaultCopy>;
};

237
238
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
239
240
241
242
                     typename std::enable_if<K % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
243
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
244
245
};

246
247
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
248
249
250
251
                     typename std::enable_if<K % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
252
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
253
254
};

255
256
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
257
258
259
260
261
                     typename std::enable_if<N % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
262
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
263
264
};

265
266
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
267
268
269
270
271
                     typename std::enable_if<N % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
272
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
273
274
};

275
276
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
277
278
279
280
                     typename std::enable_if<K % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
281
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
282
283
};

284
285
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
286
287
288
289
                     typename std::enable_if<K % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
290
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
291
292
};

293
294
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
295
296
297
298
299
300
301
302
                     typename std::enable_if<N % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

303
304
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
305
306
307
308
309
310
311
312
                     typename std::enable_if<N % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

313
314
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
315
316
317
318
                     typename std::enable_if<K % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
319
320
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
321
322
};

323
324
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
325
326
327
328
                     typename std::enable_if<K % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
329
330
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
331
332
};

333
334
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
335
336
337
338
339
340
341
                     typename std::enable_if<K % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

342
343
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
344
345
346
347
348
349
350
351
352
                     typename std::enable_if<N % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = DefaultCopy;
};

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
353
354
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
355
356
357
358
359
360
361
362
363
364
class GemmTensorOp {
public:
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using C_type = C_type_raw;
  using Instruction =
365
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
366
367

  using OperandATraits =
368
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
369
  using OperandBTraits =
370
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

  using TileMma = TiledMMA<typename Instruction::MMA,
                           Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
                           typename Instruction::MMA_Group>;

  template <class... Args>
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
    return layout;
  }
  // In fp16, when layout is KxN and n_warp is 1 and N % 64 == 0
  // the original layout fail to compile, currently using this as a workaround
  template <class... Args>
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    // when layout is KxN and n_warp is 1, there seem to be a bug, use this as a
    // workaround
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
424
425
426
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));

    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
458
459
460
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));

    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
494
495
496
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
497
498
499
500
501
502
503
504
505
506
507
508
509
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace tl_mma

510
} // namespace cute
511
512
513

namespace tl {

514
515
namespace tl_mma {

516
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
517
518
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
519
520
521
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
522
                                 trans_B, clear_accum, A_type, B_type, C_type>;
523
  MMA::body(pA, pB, accum);
524
525
}

526
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
527
528
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
529
530
531
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
532
                                 trans_B, clear_accum, A_type, B_type, C_type>;
533
534
535
536
  MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
537
538
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
539
540
541
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
542
                                 trans_B, clear_accum, A_type, B_type, C_type>;
543
544
545
546
547
548
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl_mma

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
549
550
          bool trans_B, bool clear_accum = false, bool use_wgmma = true,
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
551
552
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
553
554
555
    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                             trans_A, trans_B, clear_accum,
                                             A_type, B_type, C_type>;
556
557
    MMA::body<wg_wait>(pA, pB, accum);
  } else {
558
559
560
    using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                           trans_A, trans_B, clear_accum,
                                           A_type, B_type, C_type>;
561
562
563
564
565
    MMA::body(pA, pB, accum);
  }
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
566
567
          bool trans_B, bool clear_accum = false, bool use_wgmma = true,
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
568
569
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
570
571
572
    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                             trans_A, trans_B, clear_accum,
                                             A_type, B_type, C_type>;
573
574
    MMA::body_rs<wg_wait>(pA, pB, accum);
  } else {
575
576
577
    using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                           trans_A, trans_B, clear_accum,
                                           A_type, B_type, C_type>;
578
579
    MMA::body_rs(pA, pB, accum);
  }
580
581
}

582
/// Thin wrapper over cute::warpgroup_wait<num_mma>().
template <int num_mma> TL_DEVICE void wait_wgmma() {
  cute::warpgroup_wait<num_mma>();
}

586
587
588
/// Synchronizes on the named barrier whose id is the calling thread's
/// canonical warp-group index, with NumMmaThreads participating threads.
template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_sync() {
  cutlass::arch::NamedBarrier::sync(NumMmaThreads,
                                    cutlass::canonical_warp_group_idx() /*id*/);
}

591
/// Arrives on the named barrier(s) belonging to the other warp group(s).
///
/// NumMmaThreads == 256: arrive on barrier (1 - wg).
/// NumMmaThreads == 384: arrive on the two barriers that follow wg in the
///                       cycle 0 -> 1 -> 2 -> 0.
/// where wg is the calling thread's canonical warp-group index.
template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_arrive() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  // Computed once; the original re-evaluated it for every use.
  const auto wg = cutlass::canonical_warp_group_idx();
  if constexpr (NumMmaThreads == 256) {
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, (1 - wg) /*id*/);
  } else {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads, (wg <= 1 ? wg + 1 : wg + 1 - 3) /*id*/);
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads, (wg <= 0 ? wg + 2 : wg + 2 - 3) /*id*/);
  }
}

610
/// Initial arrivals on the warp-scheduler named barriers.
///
/// Warp groups with index > 0 arrive on barrier 0. With NumMmaThreads == 384,
/// warp groups with index > 1 additionally arrive on barrier 1.
template <int NumMmaThreads> TL_DEVICE void mma_init() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  // Computed once; the original re-evaluated it for every use.
  const auto wg = cutlass::canonical_warp_group_idx();
  if (wg > 0) {
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0);
  }
  if constexpr (NumMmaThreads == 384) {
    if (wg > 1) {
      cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 1 /*id*/);
    }
  }
}
621
} // namespace tl