// gemm_sm89.h
//
// CuTe-based tensor-core GEMM building blocks, including SM89 FP8
// (e4m3 / e5m2) MMA atoms.
#pragma once

#include <cute/algorithm/clear.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/underscore.hpp>

// Non-CuTe headers.
#include "common.h"
#include "cuda_fp8.h"

// The GEMM building blocks below are defined inside namespace cute.
namespace cute {

// Primary template mapping (A, B, C) element types plus the warp tiling to a
// CuTe MMA atom (member `MMA`) and the tile passed to TiledMMA (member
// `MMA_Group`). It is only declared here: just the explicit specializations
// that follow are defined, so an unsupported type combination fails to compile
// with an incomplete-type error.
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n, int N>
struct DispatchInstruction;

// Short alias for CuTe's Underscore, used as a placeholder mode in the
// MMA_Group tiles below.
typedef Underscore _X;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890))

// SM89 tensor-core MMA: m16n8k32 with e4m3 A and B operands and f32
// accumulation (D = A * B + C). Issued warp-wide through `mma.sync.aligned`.
// A is row-major and B column-major, per the `.row.col` PTX qualifiers.
struct SM89_16x8x32_F32F8F8F32_E4M3_TN {
  // Per-thread fragments: A and B hold packed e4m3 bytes, C/D hold f32.
  using ARegisters = uint32_t[4];
  using BRegisters = uint32_t[2];
  using CRegisters = float[4];
  using DRegisters = float[4];

  CUTE_HOST_DEVICE static void
  fma(float &out0, float &out1, float &out2, float &out3, uint32_t const &a_0,
      uint32_t const &a_1, uint32_t const &a_2, uint32_t const &a_3,
      uint32_t const &b_0, uint32_t const &b_1, float const &acc0,
      float const &acc1, float const &acc2, float const &acc3) {
    // Operand numbering: %0-%3 = D, %4-%7 = A, %8-%9 = B, %10-%13 = C.
    asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
                 "{%0,  %1,  %2,  %3},"
                 "{%4,  %5,  %6,  %7},"
                 "{%8,  %9},"
                 "{%10, %11, %12, %13};\n"
                 : "=f"(out0), "=f"(out1), "=f"(out2), "=f"(out3)
                 : "r"(a_0), "r"(a_1), "r"(a_2), "r"(a_3),
                   "r"(b_0), "r"(b_1),
                   "f"(acc0), "f"(acc1), "f"(acc2), "f"(acc3));
  }
};

// SM89 tensor-core MMA: m16n8k32 with e5m2 A and B operands and f32
// accumulation (D = A * B + C). Issued warp-wide through `mma.sync.aligned`.
// A is row-major and B column-major, per the `.row.col` PTX qualifiers.
struct SM89_16x8x32_F32F8F8F32_E5M2_TN {
  // Per-thread fragments: A and B hold packed e5m2 bytes, C/D hold f32.
  using ARegisters = uint32_t[4];
  using BRegisters = uint32_t[2];
  using CRegisters = float[4];
  using DRegisters = float[4];

  CUTE_HOST_DEVICE static void
  fma(float &out0, float &out1, float &out2, float &out3, uint32_t const &a_0,
      uint32_t const &a_1, uint32_t const &a_2, uint32_t const &a_3,
      uint32_t const &b_0, uint32_t const &b_1, float const &acc0,
      float const &acc1, float const &acc2, float const &acc3) {
    // Operand numbering: %0-%3 = D, %4-%7 = A, %8-%9 = B, %10-%13 = C.
    asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
                 "{%0,  %1,  %2,  %3},"
                 "{%4,  %5,  %6,  %7},"
                 "{%8,  %9},"
                 "{%10, %11, %12, %13};\n"
                 : "=f"(out0), "=f"(out1), "=f"(out2), "=f"(out3)
                 : "r"(a_0), "r"(a_1), "r"(a_2), "r"(a_3),
                   "r"(b_0), "r"(b_1),
                   "f"(acc0), "f"(acc1), "f"(acc2), "f"(acc3));
  }
};

// Thread/value layouts for SM80-style accumulator fragments, each annotated as
// (threads, values) -> (rows, cols).
// NOTE(review): CuTe's SM80 MMA traits header may also define aliases with
// these names; if so, both must name the identical type or unqualified uses
// become ambiguous — verify against the CUTLASS version in use.

// (T32,V1) -> (M8,N8)
using SM80_8x4 =
    Layout<Shape<Shape<_4, _8>, _1>,
           Stride<Stride<_8, _1>, _0>>;

// (T32,V2) -> (M8,N8)
using SM80_8x8_Row =
    Layout<Shape<Shape<_4, _8>, _2>,
           Stride<Stride<_16, _1>, _8>>;

// (T32,V4) -> (M8,N16)
using SM80_8x16_Row =
    Layout<Shape<Shape<_4, _8>, _4>,
           Stride<Stride<_32, _1>, _8>>;

// (T32,V4) -> (M16,N8)
using SM80_16x8_Row =
    Layout<Shape<Shape<_4, _8>, Shape<_2, _2>>,
           Stride<Stride<_32, _1>, Stride<_16, _8>>>;

// CuTe traits for the e4m3 FP8 atom: element types, the 16x8x32 MNK shape, and
// the thread/value layouts of the A, B and C fragments.
template <> struct MMA_Traits<SM89_16x8x32_F32F8F8F32_E4M3_TN> {
  using ValTypeA = fp8_e4_t;
  using ValTypeB = fp8_e4_t;
  using ValTypeC = float;
  using ValTypeD = float;

  using Shape_MNK = Shape<_16, _8, _32>;
  // 32 participating threads (one warp).
  using ThrID = Layout<_32>;

  // (T32,V16) -> (M16,K32)
  using ALayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2, _2>>,
                         Stride<Stride<_64, _1>, Stride<_16, _8, _256>>>;
  // (T32,V8) -> (N8,K32)
  using BLayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2>>,
                         Stride<Stride<_32, _1>, Stride<_8, _128>>>;
  // (T32,V4) -> (M16,N8)
  using CLayout = SM80_16x8_Row;
};

// CuTe traits for the e5m2 FP8 atom: element types, the 16x8x32 MNK shape, and
// the thread/value layouts of the A, B and C fragments.
template <> struct MMA_Traits<SM89_16x8x32_F32F8F8F32_E5M2_TN> {
  using ValTypeA = fp8_e5_t;
  using ValTypeB = fp8_e5_t;
  using ValTypeC = float;
  using ValTypeD = float;

  using Shape_MNK = Shape<_16, _8, _32>;
  // 32 participating threads (one warp).
  using ThrID = Layout<_32>;

  // (T32,V16) -> (M16,K32)
  using ALayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2, _2>>,
                         Stride<Stride<_64, _1>, Stride<_16, _8, _256>>>;
  // (T32,V8) -> (N8,K32)
  using BLayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2>>,
                         Stride<Stride<_32, _1>, Stride<_8, _128>>>;
  // (T32,V4) -> (M16,N8)
  using CLayout = SM80_16x8_Row;
};

// FP8 e4m3 x e4m3 -> f32 using the SM89 m16n8k32 atom.
// MMA_Group covers 16 N-columns per warp along N, capped at N.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E4M3_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// FP8 e5m2 x e5m2 -> f32 using the SM89 m16n8k32 atom.
// MMA_Group covers 16 N-columns per warp along N, capped at N.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E5M2_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// f16 x f16 -> f16 using the SM80 m16n8k16 atom.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// f16 x f16 -> f32 using the SM80 m16n8k16 atom.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// bf16 x bf16 -> f32 using the SM80 m16n8k16 atom.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// tf32 x tf32 -> f32 using the SM80 m16n8k8 atom.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// s8 x s8 -> s32 using the SM80 m16n8k32 atom.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

// f64 x f64 -> f64 using the SM80 m8n8k4 atom; the group tile spans 16 rows
// per M-warp and 16 columns per N-warp.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
// Fallback branch (__CUDA_ARCH_LIST__ in [750, 890)): only f16 x f16 -> f32 is
// provided, using the SM75 m16n8k8 atom with a K tile of 16.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
};
#endif

// Selects the shared->register copy atom for an operand whose non-K extent N
// is split across num_warp_n warps.
//
// `remainder` is the per-warp extent (integer division) modulo 16. Only the
// remainders 4, 8 and 0 map to an ldmatrix atom; they pick the x1/x2/x4 width,
// respectively. `transpose == true` selects the U32 LDSM_N atoms; otherwise
// the U16 LDSM_T atoms are used. Any other remainder falls back to
// DefaultCopy.
template <int N, int num_warp_n, bool transpose> struct SelectCopy {
  static constexpr int remainder = (N / num_warp_n) % 16;
  using type = std::conditional_t<
      remainder == 4 || remainder == 8 || remainder == 0,
      std::conditional_t<
          transpose,
          std::conditional_t<
              remainder == 4, SM75_U32x1_LDSM_N,
              std::conditional_t<remainder == 8, SM75_U32x2_LDSM_N,
                                 SM75_U32x4_LDSM_N>>,
          std::conditional_t<
              remainder == 4, SM75_U16x2_LDSM_T,
              std::conditional_t<remainder == 8, SM75_U16x4_LDSM_T,
                                 SM75_U16x8_LDSM_T>>>,
      DefaultCopy>;
};

// Shared-memory layout and copy atom for a GEMM operand viewed as (N, K) with
// Bits-wide elements. K_inner selects whether K is the contiguous mode.
//
// This primary template is the fallback for element widths / extents without
// a swizzled specialization: a plain strided layout with a padded leading
// dimension, copied with DefaultCopy.
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
          typename Enable = void>
struct OperandTraits {
  // Leading dimension: K when K is contiguous, otherwise N.
  static constexpr int stride = K_inner ? K : N;
  // When the leading dimension spans a multiple of 256 bits (32 bytes), add
  // 128 bits (16 bytes) of padding — presumably to stagger consecutive
  // rows/columns across shared-memory banks.
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  // `cute::Layout` is qualified because this member alias is itself named
  // `Layout`; an unqualified use here changes meaning inside the class.
  using Layout = typename std::conditional<
      K_inner, cute::Layout<Shape<Int<N>, Int<K>>, Stride<Int<padded>, _1>>,
      cute::Layout<Shape<Int<N>, Int<K>>, Stride<_1, Int<padded>>>>::type;
  using Copy = DefaultCopy;
};

// 16-bit elements, K contiguous, K % 64 == 32: an 8 x 32-element (64-byte row)
// atom swizzled with Swizzle<2, 3, 3>, tiled over (N, K). The copy atom comes
// from SelectCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
                     typename std::enable_if<K % 64 == 32>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, cute::Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

// 16-bit elements, K contiguous, K % 64 == 0: an 8 x 64-element (128-byte
// row) atom swizzled with Swizzle<3, 3, 3>, tiled over (N, K). The copy atom
// comes from SelectCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
                     typename std::enable_if<K % 64 == 0>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, cute::Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

// 16-bit elements, N contiguous, N % 64 == 32: a 32 x 8-element (64-byte
// column) atom swizzled with Swizzle<2, 3, 3>, tiled over (N, K) with
// Step<_2, _1> ordering. The copy atom comes from SelectCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
                     typename std::enable_if<N % 64 == 32>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, cute::Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};

// 16-bit elements, N contiguous, N % 64 == 0: a 64 x 8-element (128-byte
// column) atom swizzled with Swizzle<3, 3, 3>, tiled over (N, K) with
// Step<_2, _1> ordering. The copy atom comes from SelectCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
                     typename std::enable_if<N % 64 == 0>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, cute::Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};

// 32-bit elements, K contiguous, K % 32 == 0: an 8 x 32-element (128-byte row)
// atom swizzled with Swizzle<3, 2, 3>, tiled over (N, K). The copy atom comes
// from SelectCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
                     typename std::enable_if<K % 32 == 0>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, cute::Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

// 32-bit elements, K contiguous, K % 32 == 16: an 8 x 16-element (64-byte row)
// atom swizzled with Swizzle<2, 2, 3>, tiled over (N, K). The copy atom comes
// from SelectCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
                     typename std::enable_if<K % 32 == 16>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, cute::Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

// 32-bit elements, N contiguous, N % 32 == 0: a 32 x 8-element (128-byte
// column) atom swizzled with Swizzle<3, 2, 3>, tiled over (N, K) with
// Step<_2, _1> ordering. Copies use UniversalCopy<tfloat32_t> rather than
// SelectCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
                     typename std::enable_if<N % 32 == 0>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, cute::Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

// 32-bit elements, N contiguous, N % 32 == 16: a 16 x 8-element (64-byte
// column) atom swizzled with Swizzle<2, 2, 3>, tiled over (N, K) with
// Step<_2, _1> ordering. Copies use UniversalCopy<tfloat32_t> rather than
// SelectCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
                     typename std::enable_if<N % 32 == 16>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, cute::Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

// 8-bit elements, K contiguous, K % 128 == 64: an 8 x 64-element (64-byte row)
// atom swizzled with Swizzle<2, 4, 3>, tiled over (N, K). The copy atom comes
// from SelectCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
                     typename std::enable_if<K % 128 == 64>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, cute::Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

// 8-bit elements, K contiguous, K % 128 == 0: an 8 x 128-element (128-byte
// row) atom swizzled with Swizzle<3, 4, 3>, tiled over (N, K). The copy atom
// comes from SelectCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
                     typename std::enable_if<K % 128 == 0>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, cute::Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

// 64-bit elements, K contiguous, K % 16 == 0: a 4 x 16-element (128-byte row)
// atom swizzled with Swizzle<2, 0, 4>, tiled over (N, K); copied with
// DefaultCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
                     typename std::enable_if<K % 16 == 0>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, cute::Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

// 64-bit elements, N contiguous, N % 16 == 0: a 16 x 4-element (128-byte
// column) atom swizzled with Swizzle<2, 2, 2>, tiled over (N, K) with
// Step<_2, _1> ordering; copied with DefaultCopy.
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
                     typename std::enable_if<N % 16 == 0>::type> {
  // Qualified so the later member alias `Layout` does not change its meaning.
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, cute::Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = DefaultCopy;
};

// CuTe tensor-core GEMM on an M x N x K tile.
//
// Template parameters:
//   M, N, K                - tile extents. A is viewed as (M, K), B as (N, K)
//                            and the accumulator as (M, N).
//   num_warp_m, num_warp_n - warp arrangement of the TiledMMA along M and N.
//   trans_A, trans_B       - operand orientation. A uses a K-contiguous
//                            shared-memory layout when !trans_A; B uses one
//                            when trans_B.
//   clear_accum            - zero the accumulator fragment before use.
//   *_type_raw             - element types as passed by the caller. float A/B
//                            operands are processed as tfloat32_t.
//
// Entry points (all accumulate into the per-thread register fragment pC):
//   body    - A and B in shared memory.
//   body_rs - A already in registers, B in shared memory.
//   body_sr - A in shared memory, B already in registers.
//
// NOTE(review): each thread's slice is selected by threadIdx.x alone. The
// block is therefore presumably 1-D with 32 * num_warp_m * num_warp_n
// participating threads — confirm at the call sites.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
  // float operands are fed to the MMA as TF32.
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  // Falls back to B_type_raw, not A_type_raw, so that B keeps its own element
  // type. Otherwise B would be silently reinterpreted as A's type whenever the
  // two raw types differ.
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, B_type_raw>::type;
  using C_type = C_type_raw;
  using Instruction =
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;

  // A is (M, K) and is K-contiguous unless trans_A.
  using OperandATraits =
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
  // B is (N, K) and is K-contiguous when trans_B.
  using OperandBTraits =
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;

  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

  using TileMma = TiledMMA<typename Instruction::MMA,
                           Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
                           typename Instruction::MMA_Group>;

  // Plain layouts carry no swizzle; return them unchanged.
  template <class... Args>
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
    return layout;
  }
  // Workaround: for 16-bit elements, when the layout is K x N, n_warp is 1 and
  // N % 64 == 0, the original swizzled layout fails to compile. The
  // ComposedLayout is therefore replaced by its layout_b() component.
  // NOTE(review): the element-size test always uses A_type, even when called
  // on B fragments. This is harmless while A_type and B_type have equal size.
  template <class... Args>
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

  // Both operands in shared memory. For each K-block, copy the A and B slices
  // into registers, then issue the MMAs for that slice.
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    // Views of the register fragments in the copy atoms' partitioning.
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    if constexpr (clear_accum) {
      clear(acc);
    }
    // Workaround for a failure seen when the layout is KxN and n_warp is 1:
    // the MMA reads the fragments through swizzle-free views
    // (see remove_swizzle).
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  // A already in registers (pA), B in shared memory. The B load runs one
  // slice ahead: slice k + 1 is copied before the MMAs for slice k are issued.
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));

    if constexpr (clear_accum) {
      clear(acc);
    }
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  // A in shared memory, B already in registers (pB). The A load runs one
  // slice ahead: slice k + 1 is copied before the MMAs for slice k are issued.
  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));

    if constexpr (clear_accum) {
      clear(acc);
    }
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace cute

namespace tl {

// Device entry points over cute::GemmTensorOp. The suffix gives where each
// operand lives, A first and B second: s = shared memory, r = registers.
// Every call accumulates into the register fragment `accum`, which
// GemmTensorOp zeroes first when clear_accum is set.

// A and B both in shared memory.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
                     clear_accum, A_type, B_type, C_type>::body(pA, pB, accum);
}

// A in registers, B in shared memory.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
                     clear_accum, A_type, B_type,
                     C_type>::body_rs(pA, pB, accum);
}

// A in shared memory, B in registers.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
                     clear_accum, A_type, B_type,
                     C_type>::body_sr(pA, pB, accum);
}

} // namespace tl