#pragma once

#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm90.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/collective/collective_builder.hpp>

#include "common.h"

namespace cute {

using namespace SM90;

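// Round a float toward tf32 precision in place. tf32 keeps 10 explicit
// mantissa bits, so the bitcast below drops the low 13 bits of the fp32
// mantissa; adding 0x1000 (half of that 2^13 ulp) first turns the truncation
// into a round-half-up. NaN/Inf bit patterns are passed through unchanged.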
template <typename T> CUTE_HOST_DEVICE static void cast_float_to_tf32(T &a) {
  uint32_t x = reinterpret_cast<uint32_t const &>(a);
  if (std::isfinite(a)) {
    x += 0x1000u;
  }
  a = tfloat32_t::bitcast(x);
}

namespace tl_wgmma {

using namespace cutlass::gemm::collective::detail; // ss_smem_selector

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
                               tfloat32_t, A_type_raw>;
  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
                               tfloat32_t, B_type_raw>;
  using C_type = C_type_raw;

  // A_type/B_type are tfloat32_t when the raw types are float; only in that
  // case do the register fragments need the explicit float -> tf32 rounding.
  static constexpr bool need_tfloat32_cast =
      std::is_same<A_type_raw, float>::value &&
      std::is_same<B_type_raw, float>::value;

  static constexpr GMMA::Major GmmaMajorA =
      trans_A ? GMMA::Major::MN : GMMA::Major::K;
  static constexpr GMMA::Major GmmaMajorB =
      trans_B ? GMMA::Major::K : GMMA::Major::MN;

  using SmemLayoutAtomA =
      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
  using SmemLayoutAtomB =
      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());

  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
      conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
      conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));

  static_assert(num_warp_m % 4 == 0,
                "num_warp_m must be a multiple of 4 for hopper wgmma");

  template <int wg_wait = 0>
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::ss_op_selector<
            A_type, B_type, C_type,
            Shape<Int<4 * M / num_warp_m>, Int<N / num_warp_n>, Int<K>>,
            GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
    Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
    Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)

    Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
    if constexpr (need_tfloat32_cast) {
      cute::for_each(tCrA, cast_float_to_tf32<A_type>);
      cute::for_each(tCrB, cast_float_to_tf32<B_type>);
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
  }
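  // Note: a negative wg_wait skips the warpgroup_wait above, letting the
  // caller overlap the issued wgmma batch with other work and wait later
  // (e.g. via tl::wait_wgmma).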

  template <int wg_wait = 0>
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    // TODO: Move bar.sync out of body_rs
    // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n *
    // 32));
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::rs_op_selector<
            A_type, B_type, C_type,
            Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
            GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
    Tensor tCsB = thr_mma.partition_B(sB);       // (MMA,MMA_N,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    if constexpr (need_tfloat32_cast) {
      cute::for_each(tCrA, cast_float_to_tf32<A_type>);
      cute::for_each(tCrB, cast_float_to_tf32<B_type>);
    }
    warpgroup_fence_operand(tCrA);
    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
    warpgroup_fence_operand(tCrA);
  }
};
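
// Usage sketch (illustrative only; smem_A, smem_B and acc_regs are
// hypothetical names supplied by the surrounding kernel): a 64x128x32 fp16
// tile computed by a single warpgroup with both operands staged in shared
// memory could be issued as
//
//   using MMA = GemmTensorOp<64, 128, 32, /*num_warp_m=*/4, /*num_warp_n=*/1,
//                            /*trans_A=*/false, /*trans_B=*/true,
//                            /*clear_accum=*/true, half_t, half_t, float>;
//   MMA::body</*wg_wait=*/0>(smem_A, smem_B, acc_regs);
//
// where smem_A/smem_B follow SmemLayoutA/SmemLayoutB and acc_regs is the
// caller's register accumulator fragment.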

} // namespace tl_wgmma

namespace tl_mma {

template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n, int N>
struct DispatchInstruction;

using _X = Underscore;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
};
#endif

template <int Bits, int N, int K, bool K_inner, int num_warp_n,
          typename Enable = void>
struct OperandTraits {
  // Primary template, use padded layout and default copy
  static constexpr int stride = K_inner ? K : N;
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Stride<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Stride<_1, Int<padded>>>>::type;
  using Copy = DefaultCopy;
};
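
// Example of the padded fallback above: with 16-bit elements (Bits = 16) and a
// contiguous extent of 64, 64 % (256 / 16) == 0, so the leading stride becomes
// 64 + 128 / 16 = 72 elements, keeping consecutive rows from landing in the
// same shared-memory banks.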

template <int N, int num_warp_n, bool transpose> struct SelectCopy {
  static constexpr int remainder = (N / num_warp_n) % 16;
  using type = std::conditional_t<
      remainder == 4 || remainder == 8 || remainder == 0,
      std::conditional_t<
          transpose,
          std::conditional_t<
              remainder == 4, SM75_U32x1_LDSM_N,
              std::conditional_t<remainder == 8, SM75_U32x2_LDSM_N,
                                 SM75_U32x4_LDSM_N>>,
          std::conditional_t<
              remainder == 4, SM75_U16x2_LDSM_T,
              std::conditional_t<remainder == 8, SM75_U16x4_LDSM_T,
                                 SM75_U16x8_LDSM_T>>>,
      DefaultCopy>;
};
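
// SelectCopy maps the per-warp tile width, (N / num_warp_n) % 16, to an
// ldmatrix width: remainder 4 -> x1, 8 -> x2, 0 -> x4 (the transposing U16
// forms when `transpose` is false); any other width falls back to DefaultCopy.
// For example, N = 64 with num_warp_n = 4 gives remainder 0 and the widest
// variant.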

template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
                     typename std::enable_if<K % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
                     typename std::enable_if<K % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
                     typename std::enable_if<N % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
                     typename std::enable_if<N % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
                     typename std::enable_if<K % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
                     typename std::enable_if<K % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
                     typename std::enable_if<N % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
                     typename std::enable_if<N % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
                     typename std::enable_if<K % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
                     typename std::enable_if<K % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
                     typename std::enable_if<K % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
                     typename std::enable_if<N % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = DefaultCopy;
};

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, B_type_raw>::type;
  using C_type = C_type_raw;

  static constexpr bool need_tfloat32_cast =
      std::is_same<A_type_raw, float>::value &&
      std::is_same<A_type, tfloat32_t>::value &&
      std::is_same<B_type_raw, float>::value &&
      std::is_same<B_type, tfloat32_t>::value;

  using Instruction =
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;

  using OperandATraits =
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
  using OperandBTraits =
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

  using TileMma = TiledMMA<typename Instruction::MMA,
                           Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
                           typename Instruction::MMA_Group>;

  template <class... Args>
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
    return layout;
  }
  // In fp16, when the layout is KxN, n_warp is 1, and N % 64 == 0, the
  // original layout fails to compile; this overload is currently used as a
  // workaround.
  template <class... Args>
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    // When the layout is KxN and n_warp is 1, there seems to be a bug; use
    // this view as a workaround.
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    if constexpr (clear_accum) {
      // SM80 MMA atoms have no ScaleOut accumulate flag; zero the accumulator
      // fragment directly instead.
      clear(acc);
    }
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
      if constexpr (need_tfloat32_cast) {
        cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
        cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    if constexpr (need_tfloat32_cast) {
      cute::for_each(tCrA, cast_float_to_tf32<A_type>);
    }
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    if constexpr (clear_accum) {
      // SM80 MMA atoms have no ScaleOut accumulate flag; zero the accumulator
      // fragment directly instead.
      clear(acc);
    }
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
      if constexpr (need_tfloat32_cast) {
        cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
      }
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
    if constexpr (need_tfloat32_cast) {
      cute::for_each(tCrB, cast_float_to_tf32<B_type>);
    }
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    if constexpr (clear_accum) {
      // SM80 MMA atoms have no ScaleOut accumulate flag; zero the accumulator
      // fragment directly instead.
      clear(acc);
    }
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
      if constexpr (need_tfloat32_cast) {
        cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace tl_mma

} // namespace cute

namespace tl {

namespace tl_mma {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, A_type, B_type, C_type>;
  MMA::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, A_type, B_type, C_type>;
  MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, A_type, B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl_mma

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum = false, bool use_wgmma = true,
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                             trans_A, trans_B, clear_accum,
                                             A_type, B_type, C_type>;
    MMA::body<wg_wait>(pA, pB, accum);
  } else {
    using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                           trans_A, trans_B, clear_accum,
                                           A_type, B_type, C_type>;
    MMA::body(pA, pB, accum);
  }
}
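
// Illustrative dispatch (smem_A, smem_B and acc are hypothetical names from
// the surrounding kernel): two math warpgroups computing a 128x128x64 fp16
// tile through the wgmma path would call
//
//   tl::gemm_ss<128, 128, 64, /*num_warp_m=*/8, /*num_warp_n=*/1,
//               /*trans_A=*/false, /*trans_B=*/true, /*clear_accum=*/false,
//               /*use_wgmma=*/true, /*wg_wait=*/0>(smem_A, smem_B, acc);
//
// With use_wgmma = false the same call falls back to the SM80 tl_mma path.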

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum = false, bool use_wgmma = true,
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                             trans_A, trans_B, clear_accum,
                                             A_type, B_type, C_type>;
    MMA::body_rs<wg_wait>(pA, pB, accum);
  } else {
    using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                           trans_A, trans_B, clear_accum,
                                           A_type, B_type, C_type>;
    MMA::body_rs(pA, pB, accum);
  }
}

template <int num_mma> TL_DEVICE void wait_wgmma() {
  cute::warpgroup_wait<num_mma>();
}
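
// wait_wgmma<N> blocks until at most N previously committed wgmma batches are
// still in flight (N = 0 drains them all); it forwards to the same
// cute::warpgroup_wait primitive that the wg_wait template parameter of the
// wgmma bodies uses.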

template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_sync() {
  cutlass::arch::NamedBarrier::sync(NumMmaThreads,
                                    cutlass::canonical_warp_group_idx() /*id*/);
}

template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_arrive() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  if constexpr (NumMmaThreads == 256) {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads, (1 - cutlass::canonical_warp_group_idx()) /*id*/);
  } else {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads,
        (cutlass::canonical_warp_group_idx() <= 1
             ? cutlass::canonical_warp_group_idx() + 1
             : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads,
        (cutlass::canonical_warp_group_idx() <= 0
             ? cutlass::canonical_warp_group_idx() + 2
             : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
  }
}
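
// warp_scheduler_barrier_sync / _arrive implement a round-robin hand-off
// between math warpgroups: each warpgroup waits on the named barrier matching
// its canonical index and, once done, arrives on the barrier(s) of the other
// warpgroup(s). With 256 MMA threads (2 warpgroups) WG0 and WG1 release each
// other; with 384 threads (3 warpgroups) WG i arrives on barriers (i + 1) % 3
// and (i + 2) % 3. mma_init below has the later warpgroups pre-arrive on the
// earlier ones' barriers so the first sync can complete.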

template <int NumMmaThreads> TL_DEVICE void mma_init() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  if (cutlass::canonical_warp_group_idx() > 0) {
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0);
  }
  if constexpr (NumMmaThreads == 384) {
    if (cutlass::canonical_warp_group_idx() > 1) {
      cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 1 /*id*/);
    }
  }
}
} // namespace tl