#pragma once

#include "common.h"
#include "gemm_mma.h"
#include "intrin.h"

#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/collective/collective_builder.hpp>

namespace cute {

using namespace SM90;

namespace tl_wgmma {

using namespace cutlass::gemm::collective::detail; // ss_smem_selector

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
  using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
                               tfloat32_t, A_type_cute>;
  using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
                               tfloat32_t, B_type_cute>;
  using C_type = C_type_raw;

  static constexpr GMMA::Major GmmaMajorA =
      trans_A ? GMMA::Major::MN : GMMA::Major::K;
  static constexpr GMMA::Major GmmaMajorB =
      trans_B ? GMMA::Major::K : GMMA::Major::MN;

  using SmemLayoutAtomA =
      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
  using SmemLayoutAtomB =
      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());

  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
      conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
      conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));

  static_assert(num_warp_m % 4 == 0,
                "num_warp_m must be a multiple of 4 for hopper wgmma");

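  // Warp-group MMA with both operands staged in shared memory (wgmma "SS"
  // form). pA/pB point to the shared-memory tiles, pC to the per-thread
  // register accumulator fragment. wg_wait selects how many wgmma commit
  // groups may still be outstanding when the call returns; a negative value
  // skips the trailing wait entirely.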
  template <int wg_wait = 0>
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::ss_op_selector<
            A_type, B_type, C_type,
            Shape<Int<4 * M / num_warp_m>, Int<N / num_warp_n>, Int<K>>,
            GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Partition the shared-memory tiles and build the matching wgmma fragments
    Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
    Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)

    Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
  }

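  // Warp-group MMA with A held in registers and B staged in shared memory
  // (wgmma "RS" form). pA points to the register fragment holding A, pB to
  // the shared-memory tile for B, pC to the register accumulator fragment.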
  template <int wg_wait = 0>
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    // TODO: Move bar.sync out of body_rs
    // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n *
    // 32));
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::rs_op_selector<
            A_type, B_type, C_type,
            Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
            GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Partition B from shared memory; A and the accumulator live in registers
    Tensor tCsB = thr_mma.partition_B(sB);       // (MMA,MMA_N,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(tCrA);
    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
    warpgroup_fence_operand(tCrA);
  }
};
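
// Illustrative sketch (not part of the original interface): sizing the
// shared-memory staging buffers from the layouts exposed by GemmTensorOp.
// The tile shape, element types, and constant names below are assumptions.
//
//   using Op = cute::tl_wgmma::GemmTensorOp<
//       /*M=*/64, /*N=*/128, /*K=*/64, /*num_warp_m=*/4, /*num_warp_n=*/1,
//       /*trans_A=*/false, /*trans_B=*/false, /*clear_accum=*/true,
//       cutlass::half_t, cutlass::half_t, float>;
//   constexpr int kSmemBytesA =
//       int(cute::cosize_v<typename Op::SmemLayoutA>) * sizeof(typename Op::A_type);
//   constexpr int kSmemBytesB =
//       int(cute::cosize_v<typename Op::SmemLayoutB>) * sizeof(typename Op::B_type);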

} // namespace tl_wgmma

} // namespace cute
/**
 * Execute a tiled GEMM where A is provided as a register fragment and B is
 * staged in shared memory.
 *
 * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_rs to perform the
 * computation.
 *
 * @param pA Pointer to the A operand register fragment.
 * @param pB Pointer to the B tile staged in shared memory.
 * @param accum Pointer to the accumulator/output tile region (device memory).
 */
/**
 * Execute a tiled GEMM where A is staged in shared memory and B is provided
 * as a register fragment.
 *
 * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr to perform the
 * computation.
 *
 * @param pA Pointer to the A tile staged in shared memory.
 * @param pB Pointer to the B operand register fragment.
 * @param accum Pointer to the accumulator/output tile region (device memory).
 */
/**
 * Perform a tiled GEMM with both operands staged in shared memory and write
 * the result to accum.
 *
 * If use_wgmma is true, validates wgmma constraints (strides and offsets) and
 * dispatches to the Hopper wgmma implementation; otherwise dispatches to the
 * tl_mma implementation.
 *
 * @param pA Pointer to the A tile region (device memory).
 * @param pB Pointer to the B tile region (device memory).
 * @param accum Pointer to the accumulator/output tile region (device memory).
 */
/**
 * Perform a tiled GEMM with A provided as a register fragment and B staged in
 * shared memory.
 *
 * If use_wgmma is true, validates wgmma constraints (strides and offsets) and
 * dispatches to the Hopper wgmma read-share implementation; otherwise
 * dispatches to the tl_mma read-share implementation.
 *
 * @param pA Pointer to the A operand register fragment.
 * @param pB Pointer to the B tile staged in shared memory.
 * @param accum Pointer to the accumulator/output tile region (device memory).
 */
/**
 * Perform a tiled GEMM with A staged in shared memory and B provided as a
 * register fragment (tl_mma only).
 *
 * wgmma does not support this variant; the caller must set use_wgmma == false.
 * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr.
 *
 * @param pA Pointer to the A tile staged in shared memory.
 * @param pB Pointer to the B operand register fragment.
 * @param accum Pointer to the accumulator/output tile region (device memory).
 */
/**
 * Wait for outstanding wgmma operations issued by the current warp-group.
 *
 * Wrapper around cute::warpgroup_wait; waits until at most `num_mma` wgmma
 * commit groups remain outstanding.
 */
/**
 * Synchronize a named barrier across NumMmaThreads MMA threads.
 *
 * Calls cutlass::arch::NamedBarrier::sync with the canonical warp-group id.
 */
/**
 * Arrive at a named barrier for NumMmaThreads MMA threads using
 * architecture-aware mapping.
 *
 * Supported NumMmaThreads values: 256 or 384. The function issues one or two
 * barrier arrives depending on the thread-group topology to ensure proper
 * rendezvous ordering.
 */
/**
 * Initialize named-barrier state for multi-warp MMA execution.
 *
 * For NumMmaThreads == 256 or 384, performs the required initial barrier
 * arrivals for non-zero canonical warp-group indices to set up subsequent
 * barrier synchronization.
 */

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
    static_assert((trans_A && lda == M) || (!trans_A && lda == K),
                  "Hopper wgmma doesn't support custom stride for A");
    static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
                  "Hopper wgmma doesn't support custom stride for B");
    static_assert(offset_a == 0 && offset_b == 0,
                  "offset_a and offset_b must be zero for wgmma");
    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                             trans_A, trans_B, clear_accum,
                                             A_type, B_type, C_type>;
    MMA::body<wg_wait>(pA, pB, accum);
  } else {
    using MMA =
        cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                   trans_B, clear_accum, lda, ldb, offset_a,
                                   offset_b, A_type, B_type, C_type>;
    MMA::body(pA, pB, accum);
  }
}
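
// Illustrative usage (hypothetical caller, not part of this header): one MMA
// warp-group (128 threads), half-precision tiles already staged in shared
// memory, fp32 accumulation. With trans_A == false / trans_B == false, wgmma
// requires lda == K and ldb == N; `smem_A`, `smem_B` and `acc_frag` are
// assumed names for the shared-memory tiles and the per-thread accumulator
// fragment.
//
//   tl::gemm_ss</*M=*/64, /*N=*/128, /*K=*/64, /*num_warp_m=*/4,
//               /*num_warp_n=*/1, /*trans_A=*/false, /*trans_B=*/false,
//               /*clear_accum=*/true, /*lda=*/64, /*ldb=*/128,
//               /*offset_a=*/0, /*offset_b=*/0, /*use_wgmma=*/true,
//               /*wg_wait=*/0>(smem_A, smem_B, acc_frag);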

/**
 * Perform a read-share (A in registers, B in shared memory) tiled GEMM and
 * accumulate into `accum`.
 *
 * Dispatches at compile time to either the Hopper wgmma implementation or the
 * fallback MMA implementation depending on `use_wgmma`. The selected
 * GemmTensorOp::body_rs performs the tiled GEMM loop and updates the
 * accumulator in-place.
 *
 * When `use_wgmma == true`, this function enforces wgmma constraints at
 * compile time:
 * - A's leading dimension must equal (trans_A ? M : K)
 * - B's leading dimension must equal (trans_B ? K : N)
 * - offset_a and offset_b must be zero
 *
 * @param pA Pointer to the A operand register fragment. Layout/stride
 * expectations depend on template parameters.
 * @param pB Pointer to the B operand tile staged in shared memory.
 * Layout/stride expectations depend on template parameters.
 * @param accum Pointer to the accumulator/output C fragment updated in-place.
 */
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
    static_assert((trans_A && lda == M) || (!trans_A && lda == K),
                  "Hopper wgmma doesn't support custom stride for A");
    static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
                  "Hopper wgmma doesn't support custom stride for B");
    static_assert(offset_a == 0 && offset_b == 0,
                  "offset_a and offset_b must be zero for wgmma");
    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
                                             trans_A, trans_B, clear_accum,
                                             A_type, B_type, C_type>;
    MMA::body_rs<wg_wait>(pA, pB, accum);
  } else {
    using MMA =
        cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                   trans_B, clear_accum, lda, ldb, offset_a,
                                   offset_b, A_type, B_type, C_type>;
    MMA::body_rs(pA, pB, accum);
  }
}
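
// Illustrative usage (hypothetical caller): A is already held in registers
// (for example, the output of a previous wgmma converted to the A operand
// type), while B is staged in shared memory. wg_wait = -1 skips the trailing
// warpgroup_wait so the caller can overlap independent work before draining;
// `rA_frag`, `smem_B` and `acc_frag` are assumed names.
//
//   tl::gemm_rs</*M=*/64, /*N=*/128, /*K=*/64, /*num_warp_m=*/4,
//               /*num_warp_n=*/1, /*trans_A=*/false, /*trans_B=*/false,
//               /*clear_accum=*/false, /*lda=*/64, /*ldb=*/128,
//               /*offset_a=*/0, /*offset_b=*/0, /*use_wgmma=*/true,
//               /*wg_wait=*/-1>(rA_frag, smem_B, acc_frag);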

/**
 * Perform a non-wgmma tiled GEMM where A is staged in shared memory and B is
 * provided as a register fragment, accumulating into `accum`.
 *
 * This overload dispatches to the tl_mma::GemmTensorOp::body_sr
 * implementation. Must be instantiated with `use_wgmma = false` (enforced via
 * static_assert).
 *
 * @param pA Pointer to the A operand tile staged in shared memory.
 * @param pB Pointer to the B operand register fragment.
 * @param accum Pointer to the accumulator/output C fragment updated in-place.
 */
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, lda, ldb, offset_a,
                                 offset_b, A_type, B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
}
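
// Illustrative usage (hypothetical caller): the shared-register form is only
// available through the tl_mma backend, so use_wgmma must be false. lda/ldb
// and the offsets are left at their defaults; `smem_A`, `rB_frag` and
// `acc_frag` are assumed names.
//
//   tl::gemm_sr</*M=*/64, /*N=*/64, /*K=*/32, /*num_warp_m=*/2,
//               /*num_warp_n=*/2, /*trans_A=*/false, /*trans_B=*/false,
//               /*clear_accum=*/true, /*lda=*/0, /*ldb=*/0,
//               /*offset_a=*/0, /*offset_b=*/0,
//               /*use_wgmma=*/false>(smem_A, rB_frag, acc_frag);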

/**
 * Wait until at most `num_mma` wgmma commit groups remain outstanding for the
 * calling warp-group.
 *
 * Thin wrapper around cute::warpgroup_wait<num_mma>().
 */
template <int num_mma> TL_DEVICE void wait_wgmma() {
  cute::warpgroup_wait<num_mma>();
}

template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_sync() {
  cutlass::arch::NamedBarrier::sync(NumMmaThreads,
                                    cutlass::canonical_warp_group_idx() /*id*/);
}

template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_arrive() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  if constexpr (NumMmaThreads == 256) {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads, (1 - cutlass::canonical_warp_group_idx()) /*id*/);
  } else {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads,
        (cutlass::canonical_warp_group_idx() <= 1
             ? cutlass::canonical_warp_group_idx() + 1
             : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads,
        (cutlass::canonical_warp_group_idx() <= 0
             ? cutlass::canonical_warp_group_idx() + 2
             : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
  }
}

template <int NumMmaThreads> TL_DEVICE void mma_init() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  if (cutlass::canonical_warp_group_idx() > 0) {
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0);
  }
  if constexpr (NumMmaThreads == 384) {
    if (cutlass::canonical_warp_group_idx() > 1) {
      cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 1 /*id*/);
    }
  }
}
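
// Illustrative sketch of the two-warp-group "ping-pong" schedule these helpers
// support (hypothetical mainloop; assumes 256 MMA threads, i.e. two consumer
// warp-groups, and a caller-defined K_ITERS):
//
//   tl::mma_init<256>();
//   for (int k_iter = 0; k_iter < K_ITERS; ++k_iter) {
//     tl::warp_scheduler_barrier_sync<256>();     // wait for this group's turn
//     tl::gemm_ss<...>(smem_A, smem_B, acc_frag); // issue the wgmma batch
//     tl::warp_scheduler_barrier_arrive<256>();   // let the other group issue
//     tl::wait_wgmma<0>();                        // drain this group's wgmma
//   }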
} // namespace tl