gemm_sm90.h 34.2 KB
Newer Older
1
2
#pragma once

3
#include "common.h"
4
#include "gemm_mma.h"
5
#include "intrin.h"
6

7
8
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>
9
#include <cutlass/gemm/collective/collective_builder.hpp>
10
11
12

namespace cute {

13
14
using namespace SM90;

15
namespace tl_wgmma {
16
17

using namespace cutlass::gemm::collective::detail; // ss_smem_selector
18

19
/**
 * Hopper warpgroup-MMA (wgmma) GEMM over a single M x N x K tile.
 *
 * `float` operands are treated as `tfloat32_t`. The GMMA majorness of each
 * operand is derived from trans_A / trans_B, and the shared-memory layouts are
 * obtained by tiling the atom returned by `ss_smem_selector` up to the full
 * tile shape.
 *
 * num_warp_m must be a multiple of 4 (enforced by the static_assert below).
 */
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
                               tfloat32_t, A_type_raw>;
  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
                               tfloat32_t, B_type_raw>;
  using C_type = C_type_raw;

  static constexpr GMMA::Major GmmaMajorA =
      trans_A ? GMMA::Major::MN : GMMA::Major::K;
  static constexpr GMMA::Major GmmaMajorB =
      trans_B ? GMMA::Major::K : GMMA::Major::MN;

  using SmemLayoutAtomA =
      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
  using SmemLayoutAtomB =
      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());

  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
      conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
      conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));

  static_assert(num_warp_m % 4 == 0,
                "num_warp_m must be a multiple of 4 for hopper wgmma");

  /**
   * Shared/shared variant: pA and pB are wrapped as shared-memory tensors and
   * pC as a register-resident accumulator tensor.
   *
   * @tparam wg_wait Forwarded to warpgroup_wait<wg_wait>() after the batch is
   *         committed; a negative value skips the wait entirely.
   * @param pA Base of the A tile, viewed through SmemLayoutA.
   * @param pB Base of the B tile, viewed through SmemLayoutB.
   * @param pC Base of this thread's accumulator fragment.
   */
  template <int wg_wait = 0>
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::ss_op_selector<
            A_type, B_type, C_type,
            Shape<Int<4 * M / num_warp_m>, Int<N / num_warp_n>, Int<K>>,
            GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Partition the shared-memory tiles for this thread and build the
    // corresponding MMA operand fragments.
    Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
    Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)

    Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    if constexpr (clear_accum) {
      // First MMA overwrites the accumulator instead of adding to it.
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      // Every k-block after the first accumulates.
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
  }

  /**
   * Register/shared variant: pA is wrapped as a register-resident A fragment,
   * pB as a shared-memory tensor, and pC as the register accumulator.
   *
   * @tparam wg_wait Forwarded to warpgroup_wait<wg_wait>() after the batch is
   *         committed; a negative value skips the wait entirely.
   * @param pA Base of this thread's A fragment.
   * @param pB Base of the B tile, viewed through SmemLayoutB.
   * @param pC Base of this thread's accumulator fragment.
   */
  template <int wg_wait = 0>
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    // TODO: Move bar.sync out of body_rs
    // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n *
    // 32));
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::rs_op_selector<
            A_type, B_type, C_type,
            Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
            GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // B is partitioned from shared memory; A and the accumulator are views
    // over caller-provided register storage.
    Tensor tCsB = thr_mma.partition_B(sB);       // (MMA,MMA_N,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(tCrA);
    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    if constexpr (clear_accum) {
      // First MMA overwrites the accumulator instead of adding to it.
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      // Every k-block after the first accumulates.
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
    warpgroup_fence_operand(tCrA);
  }
};

145
146
147
148
149
} // namespace tl_wgmma

namespace tl_mma {

// Primary template (declaration only): maps an (A, B, C) element-type triple
// to an MMA atom (`MMA`) and a tile permutation (`MMA_Group`). Only the
// explicit specializations are defined.
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n, int N>
struct DispatchInstruction;

using _X = Underscore;

155
156
157
158
159
160
161
162
163
164
165
166
167
// FP8 E4M3 x E4M3 -> float, using the SM89 16x8x32 atom.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>;
  // N extent of the tile: 16 columns per N-warp, capped at N.
  using MMA_Group =
      Tile<_X, Int<(N < num_warp_n * 16) ? N : num_warp_n * 16>, _X>;
};

// FP8 E5M2 x E5M2 -> float, using the SM89 16x8x32 atom.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM89_16x8x32_F32E5M2E5M2F32_TN>;
  // N extent of the tile: 16 columns per N-warp, capped at N.
  using MMA_Group =
      Tile<_X, Int<(N < num_warp_n * 16) ? N : num_warp_n * 16>, _X>;
};

168
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
// SM80 atoms. Except for the double-precision case, the N extent of the tile
// is 16 columns per N-warp, capped at N.

// half x half -> half
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
// half x half -> float
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
// bfloat16 x bfloat16 -> float
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
// tfloat32 x tfloat32 -> float
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
// int8 x int8 -> int
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
// double x double -> double; this one fixes both the M and N tile extents.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
// SM75 fallback: only half x half -> float is provided, with a K extent of 16.
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _16>;
};
#endif

209
/**
 * Shared-memory layout and copy atom for one GEMM operand.
 *
 * Primary template: used whenever no specialization matches. It produces a
 * plain strided layout with `DefaultCopy`. When the leading dimension is a
 * multiple of 256 bits' worth of elements, the stride is padded by 128 bits'
 * worth of elements.
 *
 * @tparam Bits        Element width in bits.
 * @tparam N, K        Tile extents.
 * @tparam K_inner     true: K is the unit-stride mode; false: the
 *                     leading_dim mode is unit-stride.
 * @tparam num_warp_n  Unused here; consumed by specializations.
 * @tparam leading_dim Leading dimension in elements.
 */
template <int Bits, int N, int K, bool K_inner, int num_warp_n, int leading_dim,
          typename Enable = void>
struct OperandTraits {
  // Primary template, use padded layout and default copy
  static constexpr int stride = leading_dim;
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  // `cute::Layout` is qualified because this struct declares its own member
  // named `Layout`; an unqualified use would change meaning once the member
  // is in scope.
  using Layout = typename std::conditional<
      K_inner,
      cute::Layout<Shape<Int<N>, Int<leading_dim>>, Stride<Int<padded>, _1>>,
      cute::Layout<Shape<Int<leading_dim>, Int<K>>,
                   Stride<_1, Int<padded>>>>::type;
  using Copy = DefaultCopy;
};

222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
/**
 * Chooses the shared-memory -> register copy atom from the per-warp N extent.
 *
 * With remainder = (N / num_warp_n) % 16:
 *   - 4 selects the narrowest LDSM variant, 8 the middle one, 0 the widest;
 *   - any other value falls back to DefaultCopy.
 * `transpose == true` selects the `_N` family, otherwise the `_T` family.
 */
template <int N, int num_warp_n, bool transpose> struct SelectCopy {
  static constexpr int remainder = (N / num_warp_n) % 16;

private:
  static constexpr bool ldsm_ok =
      remainder == 4 || remainder == 8 || remainder == 0;

  using LdsmN = std::conditional_t<
      remainder == 4, SM75_U32x1_LDSM_N,
      std::conditional_t<remainder == 8, SM75_U32x2_LDSM_N,
                         SM75_U32x4_LDSM_N>>;
  using LdsmT = std::conditional_t<
      remainder == 4, SM75_U16x2_LDSM_T,
      std::conditional_t<remainder == 8, SM75_U16x4_LDSM_T,
                         SM75_U16x8_LDSM_T>>;

public:
  using type =
      std::conditional_t<ldsm_ok, std::conditional_t<transpose, LdsmN, LdsmT>,
                         DefaultCopy>;
};

239
240
241
// ---- 16-bit operands -------------------------------------------------------
// `cute::Layout` is qualified in the atoms below because each specialization
// also declares a member alias named `Layout`.

// 16-bit, K inner, leading_dim % 64 == 32: 8x32 row-major atom, Swizzle<2,3,3>.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, cute::Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

// 16-bit, K inner, leading_dim % 64 == 0: 8x64 row-major atom, Swizzle<3,3,3>.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, cute::Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

// 16-bit, K outer, leading_dim % 64 == 32: 32x8 column-major atom,
// Swizzle<2,3,3>.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, false, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, cute::Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(
      LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};

// 16-bit, K outer, leading_dim % 64 == 0: 64x8 column-major atom,
// Swizzle<3,3,3>.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, false, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, cute::Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(
      LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};

279
280
281
// ---- 32-bit operands -------------------------------------------------------
// `cute::Layout` is qualified in the atoms below because each specialization
// also declares a member alias named `Layout`.

// 32-bit, K inner, leading_dim % 32 == 0: 8x32 row-major atom, Swizzle<3,2,3>.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, cute::Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

// 32-bit, K inner, leading_dim % 32 == 16: 8x16 row-major atom, Swizzle<2,2,3>.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, cute::Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

// 32-bit, K outer, leading_dim % 32 == 0: 32x8 column-major atom,
// Swizzle<3,2,3>; copied element-wise as tfloat32_t.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, false, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, cute::Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(
      LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

// 32-bit, K outer, leading_dim % 32 == 16: 16x8 column-major atom,
// Swizzle<2,2,3>; copied element-wise as tfloat32_t.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, false, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, cute::Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(
      LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

319
320
321
// ---- 8-bit operands (K inner only) -----------------------------------------
// `cute::Layout` is qualified in the atoms below because each specialization
// also declares a member alias named `Layout`.

// 8-bit, K inner, leading_dim % 128 == 64: 8x64 row-major atom, Swizzle<2,4,3>.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<8, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, cute::Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
  // Use the x2 LDSM variant when each N-warp covers exactly 8 rows, else x4.
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
};

// 8-bit, K inner, leading_dim % 128 == 0: 8x128 row-major atom,
// Swizzle<3,4,3>.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<8, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, cute::Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
  // Use the x2 LDSM variant when each N-warp covers exactly 8 rows, else x4.
  using Copy = typename std::conditional<N == 8 * num_warp_n, SM75_U32x2_LDSM_N,
                                         SM75_U32x4_LDSM_N>::type;
};

341
342
343
// ---- 64-bit operands -------------------------------------------------------
// `cute::Layout` is qualified in the atoms below because each specialization
// also declares a member alias named `Layout`.

// 64-bit, K inner, leading_dim % 16 == 0: 4x16 row-major atom, Swizzle<2,0,4>.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<64, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, cute::Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<leading_dim>>{}));
  using Copy = DefaultCopy;
};

// 64-bit, K outer, leading_dim % 16 == 0: 16x4 column-major atom,
// Swizzle<2,2,2>.
template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<64, N, K, false, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, cute::Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(
      LayoutAtom{}, Shape<Int<leading_dim>, Int<K>>{}, Step<_2, _1>{}));
  using Copy = DefaultCopy;
};

/**
 * Warp-level MMA GEMM over one M x N x K tile, with optional sub-region
 * selection inside larger shared-memory buffers.
 *
 * `float` operands are treated as `tfloat32_t`. The MMA atom comes from
 * DispatchInstruction and the shared-memory layout / copy atom of each operand
 * from OperandTraits, keyed on element width, transposition and leading
 * dimension (lda / ldb). offset_a / offset_b select the tile inside the full
 * buffer (see get_region_tensor).
 *
 * Three entry points differ only in where the operands live:
 *   body    - A and B in shared memory,
 *   body_rs - A in registers, B in shared memory,
 *   body_sr - A in shared memory, B in registers.
 * In all of them the accumulator is a register fragment at pC, cleared first
 * when clear_accum is true.
 */
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
public:
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  // Falls back to B_type_raw (not A_type_raw) so that B keeps its own element
  // type when A and B differ.
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, B_type_raw>::type;
  using C_type = C_type_raw;

  using Instruction =
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;

  using OperandATraits = OperandTraits<sizeof_bits<A_type>::value, M, K,
                                       !trans_A, num_warp_m, lda>;
  using OperandBTraits =
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n, ldb>;

  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

  using TileMma = TiledMMA<typename Instruction::MMA,
                           Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
                           typename Instruction::MMA_Group>;

  // Plain layouts carry no swizzle and are returned unchanged.
  template <class... Args>
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
    return layout;
  }
  // In fp16, when layout is KxN and n_warp is 1 and N % 64 == 0
  // the original layout fail to compile, currently using this as a workaround
  // NOTE(review): the width test uses A_type even when this is applied to a
  // B-operand layout - confirm that is intended.
  template <class... Args>
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

  /**
   * Returns the NN x KK sub-tensor of `sa` that this GEMM operates on.
   *
   * With offset == 0 the tensor is composed with an NN x KK layout whose
   * second-mode stride is NN (trans) or lddim (!trans). Otherwise `sa` is
   * split into NN x KK tiles and the tile at index offset / KK along the
   * second mode (trans) or offset / NN along the first mode (!trans) is
   * returned; the offset must divide evenly.
   */
  template <int offset, int NN, int KK, bool trans, int lddim, typename Engine0,
            typename Layout0>
  static CUTE_DEVICE auto get_region_tensor(Tensor<Engine0, Layout0> &sa) {
    if constexpr (offset == 0) {
      return composition(
          sa,
          Layout<Shape<Int<NN>, Int<KK>>,
                 Stride<_1, typename std::conditional<trans, Int<NN>,
                                                      Int<lddim>>::type>>{});
    } else {
      if constexpr (trans) {
        static_assert(offset % KK == 0, "Offset must be a multiple of K");
        constexpr int offset_n = offset / KK;
        return flat_divide(sa, Shape<Int<NN>, Int<KK>>{})(_, _, _0{},
                                                          Int<offset_n>{});
      } else {
        static_assert(offset % NN == 0, "Offset must be a multiple of N");
        constexpr int offset_n = offset / NN;
        return flat_divide(sa, Shape<Int<NN>, Int<KK>>{})(_, _, Int<offset_n>{},
                                                          _0{});
      }
    }
  }

  /**
   * Shared/shared GEMM: for every k-block, copies the A and B slices from
   * shared memory into register fragments and issues the MMA.
   *
   * @param pA Base of the shared-memory buffer holding A.
   * @param pB Base of the shared-memory buffer holding B.
   * @param pC Base of this thread's accumulator fragment (registers).
   */
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                                SmemLayoutA{});
    Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                                SmemLayoutB{});
    Tensor sA = get_region_tensor<offset_a, M, K, !trans_A, lda>(sA_all);
    Tensor sB = get_region_tensor<offset_b, N, K, trans_B, ldb>(sB_all);
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    // when layout is KxN and n_warp is 1, there seem to be a bug, use this as a
    // workaround
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    if constexpr (clear_accum) {
      clear(acc);
    }
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  /**
   * Register/shared GEMM: A is consumed directly from the register fragment
   * at pA; B slices are copied from shared memory one k-block ahead of the
   * MMA that uses them.
   *
   * @param pA Base of this thread's A fragment (registers).
   * @param pB Base of the shared-memory buffer holding B.
   * @param pC Base of this thread's accumulator fragment (registers).
   */
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                                SmemLayoutB{});
    Tensor sB = get_region_tensor<offset_b, N, K, trans_B, ldb>(sB_all);
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    if constexpr (clear_accum) {
      clear(acc);
    }
    // Prefetch the first B slice, then keep one slice ahead of the MMA.
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  /**
   * Shared/register GEMM: B is consumed directly from the register fragment
   * at pB; A slices are copied from shared memory one k-block ahead of the
   * MMA that uses them.
   *
   * @param pA Base of the shared-memory buffer holding A.
   * @param pB Base of this thread's B fragment (registers).
   * @param pC Base of this thread's accumulator fragment (registers).
   */
  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                                SmemLayoutA{});
    Tensor sA = get_region_tensor<offset_a, M, K, !trans_A, lda>(sA_all);
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    if constexpr (clear_accum) {
      clear(acc);
    }
    // Prefetch the first A slice, then keep one slice ahead of the MMA.
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace tl_mma

548
} // namespace cute
549
/**
 * Execute a tiled GEMM where A is already held in registers and B is staged in
 * shared memory.
 *
 * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_rs to perform the
 * computation.
 *
 * @param pA Pointer to the A fragment (register storage).
 * @param pB Pointer to the B tile region (shared memory).
 * @param accum Pointer to the accumulator/output fragment (register storage).
 */
/**
 * Execute a tiled GEMM where A is staged in shared memory and B is already
 * held in registers.
 *
 * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr to perform the
 * computation.
 *
 * @param pA Pointer to the A tile region (shared memory).
 * @param pB Pointer to the B fragment (register storage).
 * @param accum Pointer to the accumulator/output fragment (register storage).
 */
/**
572
573
 * Perform a tiled GEMM (both operands in shared memory or selected backend) and
 * write to accum.
574
 *
575
576
577
 * If use_wgmma is true, validates wgmma constraints (strides and offsets) and
 * dispatches to the Hopper wgmma implementation; otherwise dispatches to the
 * tl_mma implementation.
578
579
580
581
582
583
 *
 * @param pA Pointer to the A tile region (device memory).
 * @param pB Pointer to the B tile region (device memory).
 * @param accum Pointer to the accumulator/output tile region (device memory).
 */
/**
 * Perform a tiled GEMM with A held in registers and B in shared memory (on
 * the selected backend).
 *
 * If use_wgmma is true, validates wgmma constraints (strides and offsets) and
 * dispatches to the Hopper wgmma register/shared (body_rs) implementation;
 * otherwise dispatches to the tl_mma register/shared implementation.
 *
 * @param pA Pointer to the A fragment (register storage).
 * @param pB Pointer to the B tile region (shared memory).
 * @param accum Pointer to the accumulator/output fragment (register storage).
 */
/**
 * Perform a tiled GEMM with A staged in shared memory and B held in registers
 * (tl_mma only).
 *
 * wgmma does not support this variant; caller must set use_wgmma == false.
 * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr.
 *
 * @param pA Pointer to the A tile region (shared memory).
 * @param pB Pointer to the B fragment (register storage).
 * @param accum Pointer to the accumulator/output fragment (register storage).
 */
/**
 * Wait for a warp-group of WMMA/MMA warps to complete.
 *
 * Wrapper around cute::warpgroup_wait for the specified number of MMA warps.
 */
/**
 * Synchronize a named barrier across NumMmaThreads MMA threads.
 *
 * Calls cutlass::arch::NamedBarrier::sync with the canonical warp-group id.
 */
/**
617
618
 * Arrive at a named barrier for NumMmaThreads MMA threads using
 * architecture-aware mapping.
619
 *
620
621
622
 * Supported NumMmaThreads values: 256 or 384. The function issues one or two
 * barrier arrives depending on the thread-group topology to ensure proper
 * rendezvous ordering.
623
624
625
626
 */
/**
 * Initialize named-barrier state for multi-warp MMA execution.
 *
627
628
629
 * For NumMmaThreads == 256 or 384, performs the required initial barrier
 * arrivals for non-zero canonical warp-group indices to set up subsequent
 * barrier synchronization.
630
 */
631
632
633

namespace tl {

634
635
namespace tl_mma {

636
/**
 * Warp-MMA GEMM with both operands in shared memory.
 *
 * Thin wrapper that instantiates cute::tl_mma::GemmTensorOp with the given
 * template arguments and calls its `body`.
 *
 * @param pA    Base of the A buffer, forwarded unchanged.
 * @param pB    Base of the B buffer, forwarded unchanged.
 * @param accum Accumulator pointer, forwarded unchanged.
 */
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, lda, ldb, offset_a,
                                 offset_b, A_type, B_type, C_type>;
  MMA::body(pA, pB, accum);
}

647
/**
 * Warp-MMA GEMM, register/shared variant.
 *
 * Thin wrapper that instantiates cute::tl_mma::GemmTensorOp with the given
 * template arguments and calls its `body_rs`.
 *
 * @param pA    Base of the A operand, forwarded unchanged.
 * @param pB    Base of the B operand, forwarded unchanged.
 * @param accum Accumulator pointer, forwarded unchanged.
 */
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, lda, ldb, offset_a,
                                 offset_b, A_type, B_type, C_type>;
  MMA::body_rs(pA, pB, accum);
}

/**
 * Warp-MMA GEMM, shared/register variant.
 *
 * Thin wrapper that instantiates cute::tl_mma::GemmTensorOp with the given
 * template arguments and calls its `body_sr`.
 *
 * @param pA    Base of the A operand, forwarded unchanged.
 * @param pB    Base of the B operand, forwarded unchanged.
 * @param accum Accumulator pointer, forwarded unchanged.
 */
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, lda, ldb, offset_a,
                                 offset_b, A_type, B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl_mma

/**
 * Tiled GEMM with both operands in shared memory, dispatched at compile time.
 *
 * When use_wgmma is true the Hopper wgmma implementation
 * (cute::tl_wgmma::GemmTensorOp::body) is used, and the following are
 * enforced by static_assert:
 *   - lda == (trans_A ? M : K)
 *   - ldb == (trans_B ? K : N)
 *   - offset_a == 0 and offset_b == 0
 * Otherwise the call goes to cute::tl_mma::GemmTensorOp::body, which honours
 * lda / ldb / offset_a / offset_b.
 *
 * @tparam wg_wait Forwarded to the wgmma body only; ignored on the tl_mma
 *         path.
 * @param pA    Base of the A buffer, forwarded unchanged.
 * @param pB    Base of the B buffer, forwarded unchanged.
 * @param accum Accumulator pointer, forwarded unchanged.
 */
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
    static_assert((trans_A && lda == M) || (!trans_A && lda == K),
                  "Hopper wgmma doesn't support custom stride for A");
    static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
                  "Hopper wgmma doesn't support custom stride for B");
    static_assert(offset_a == 0 && offset_b == 0,
                  "offset_a and offset_b must be zero for wgmma");
    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                             trans_A, trans_B, clear_accum,
                                             A_type, B_type, C_type>;
    // MMA is a dependent type, so the member template needs the `template`
    // disambiguator to be parsed as a template-id.
    MMA::template body<wg_wait>(pA, pB, accum);
  } else {
    using MMA =
        cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                   trans_B, clear_accum, lda, ldb, offset_a,
                                   offset_b, A_type, B_type, C_type>;
    MMA::body(pA, pB, accum);
  }
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
697
698
          bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
699
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
700
TL_DEVICE /**
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
           * Perform a read-share (B in shared memory, A in global) tiled GEMM
           * and accumulate into `accum`.
           *
           * Dispatches at compile time to either the Hopper wgmma
           * implementation or the fallback MMA implementation depending on
           * `use_wgmma`. The selected GemmTensorOp::body_rs performs the
           * region-tiled GEMM loop and updates the accumulator in-place.
           *
           * When `use_wgmma == true`, this function enforces wgmma constraints
           * at compile time:
           * - A's leading dimension must equal (trans_A ? M : K)
           * - B's leading dimension must equal (trans_B ? K : N)
           * - offset_a and offset_b must be zero
           *
           * @param pA Pointer to operand A (global memory). Layout/stride
           * expectations depend on template parameters.
           * @param pB Pointer to operand B (base for shared-memory staging).
           * Layout/stride expectations depend on template parameters.
           * @param accum Pointer to the accumulator/output C buffer updated
           * in-place.
           */
    void
    gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
724
  if constexpr (use_wgmma) {
725
726
727
728
729
730
    static_assert((trans_A && lda == M) || (!trans_A && lda == K),
                  "Hopper wgmma doesn't support custom stride for A");
    static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
                  "Hopper wgmma doesn't support custom stride for B");
    static_assert(offset_a == 0 && offset_b == 0,
                  "offset_a and offset_b must be zero for wgmma");
731
732
733
    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                             trans_A, trans_B, clear_accum,
                                             A_type, B_type, C_type>;
734
735
    MMA::body_rs<wg_wait>(pA, pB, accum);
  } else {
736
737
738
739
    using MMA =
        cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                   trans_B, clear_accum, lda, ldb, offset_a,
                                   offset_b, A_type, B_type, C_type>;
740
741
    MMA::body_rs(pA, pB, accum);
  }
742
743
}

744
745
746
747
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
748
TL_DEVICE /**
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
           * Perform a non-wgmma tiled GEMM where A regions are staged into
           * shared memory and B is read directly from global memory,
           * accumulating into `accum`.
           *
           * This overload dispatches to the tl_mma::GemmTensorOp::body_sr
           * implementation. Must be instantiated with `use_wgmma = false`
           * (enforced via static_assert).
           *
           * @param pA Pointer to the A operand in global memory (source that
           * will be staged to shared memory).
           * @param pB Pointer to the B operand in global memory (read
           * directly).
           * @param accum Pointer to the output accumulator matrix in global
           * memory.
           */
    void
    gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
766
767
768
769
770
771
772
773
  static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, lda, ldb, offset_a,
                                 offset_b, A_type, B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
}

774
775
776
777
778
779
780
781
782
783
784
/**
 * Thin wrapper around cute::warpgroup_wait<num_mma>().
 *
 * NOTE(review): in CuTe, warpgroup_wait<N> corresponds to PTX
 * `wgmma.wait_group N`: it blocks the calling warpgroup until at most N of
 * its most recently committed wgmma groups are still pending (N == 0 waits
 * for all of them). It is a completion wait on asynchronous wgmma work, not a
 * rendezvous between warps -- confirm against the CUTLASS version in use.
 *
 * @tparam num_mma Forwarded unchanged as the template argument of
 *                 cute::warpgroup_wait.
 */
template <int num_mma> TL_DEVICE void wait_wgmma() {
  cute::warpgroup_wait<num_mma>();
}

788
789
790
/**
 * Synchronize on the named barrier owned by the calling warp group.
 *
 * Calls cutlass::arch::NamedBarrier::sync with a participant count of
 * `NumMmaThreads` and a barrier id equal to the caller's
 * cutlass::canonical_warp_group_idx().
 *
 * NOTE(review): presumably paired with warp_scheduler_barrier_arrive() from
 * the other warp group(s), whose arrivals make up the rest of the
 * `NumMmaThreads` participants -- confirm against the call sites.
 *
 * @tparam NumMmaThreads Thread count passed to the named barrier.
 */
template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_sync() {
  cutlass::arch::NamedBarrier::sync(NumMmaThreads,
                                    cutlass::canonical_warp_group_idx() /*id*/);
}

793
/**
 * Signal (without waiting on) the named barriers of the other MMA warp
 * group(s).
 *
 * Only `NumMmaThreads == 256` and `NumMmaThreads == 384` are accepted
 * (static_assert).
 *
 * - 256: arrives on barrier id `1 - wg`, i.e. the id of the other warp group
 *   when `wg` is 0 or 1.
 * - 384: arrives on two barriers, ids `wg + 1` and `wg + 2`, each reduced by 3
 *   once it would exceed 2 for `wg` in {0, 1, 2} -- i.e. the ids of the two
 *   other warp groups.
 *
 * Here `wg` is cutlass::canonical_warp_group_idx() of the caller. It is read
 * once into a local instead of being re-evaluated in every sub-expression.
 *
 * @tparam NumMmaThreads Thread count passed to the named barrier.
 */
template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_arrive() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  const auto wg = cutlass::canonical_warp_group_idx();
  if constexpr (NumMmaThreads == 256) {
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, (1 - wg) /*id*/);
  } else {
    // Next warp group in round-robin order.
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads, (wg <= 1 ? wg + 1 : wg + 1 - 3) /*id*/);
    // The one after that.
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads, (wg <= 0 ? wg + 2 : wg + 2 - 3) /*id*/);
  }
}

812
/**
 * Issue the initial named-barrier arrivals for the warp-scheduler barriers.
 *
 * Only `NumMmaThreads == 256` and `NumMmaThreads == 384` are accepted
 * (static_assert).
 *
 * - Every warp group whose canonical index is > 0 arrives on barrier id 0.
 * - With 384 threads, every warp group whose canonical index is > 1
 *   additionally arrives on barrier id 1.
 *
 * NOTE(review): presumably this pre-signals the lower-numbered barriers so
 * that the first warp_scheduler_barrier_sync() of warp group 0 (and then 1)
 * can complete without a preceding warp_scheduler_barrier_arrive() --
 * confirm against the call sites.
 *
 * @tparam NumMmaThreads Thread count passed to the named barrier.
 */
template <int NumMmaThreads> TL_DEVICE void mma_init() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  // Read the warp-group index once; it is needed by both checks below.
  const auto wg = cutlass::canonical_warp_group_idx();
  if (wg > 0) {
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0);
  }
  if constexpr (NumMmaThreads == 384) {
    if (wg > 1) {
      cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 1 /*id*/);
    }
  }
}
823
} // namespace tl