// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include <cute/arch/mma_sm90.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>

#include "common.h"

namespace cute {

using namespace SM90;

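// Selects the widest GMMA shared-memory swizzle atom (SW128, then SW64, then
// SW32, falling back to the interleaved atom) whose extent evenly divides the
// block size along the major mode.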
template <GMMA::Major major, class ElementType, class BLK_MN, class BLK_K>
CUTE_HOST_DEVICE constexpr auto ss_smem_selector() {
  auto BLK_MN0 = size<0>(BLK_MN{});
  auto BLK_K0 = size<0>(BLK_K{});

  static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8.");
  static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8.");

  if constexpr (major == GMMA::Major::MN) {
    if constexpr (BLK_MN0 %
                      size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) ==
                  0) {
      return GMMA::Layout_MN_SW128_Atom<ElementType>{};
    } else if constexpr (BLK_MN0 %
                             size<0>(
                                 GMMA::Layout_MN_SW64_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_MN_SW64_Atom<ElementType>{};
    } else if constexpr (BLK_MN0 %
                             size<0>(
                                 GMMA::Layout_MN_SW32_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_MN_SW32_Atom<ElementType>{};
    } else if constexpr (BLK_MN0 %
                             size<0>(
                                 GMMA::Layout_MN_INTER_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_MN_INTER_Atom<ElementType>{};
    } else {
      static_assert(
          BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0,
          "BLK_MN0 must be a multiple of "
          "size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})");
    }
  } else if constexpr (major == GMMA::Major::K) {
    if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom<ElementType>{}) ==
                  0) {
      return GMMA::Layout_K_SW128_Atom<ElementType>{};
    } else if constexpr (BLK_K0 %
                             size<1>(GMMA::Layout_K_SW64_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_K_SW64_Atom<ElementType>{};
    } else if constexpr (BLK_K0 %
                             size<1>(GMMA::Layout_K_SW32_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_K_SW32_Atom<ElementType>{};
    } else if constexpr (BLK_K0 %
                             size<1>(
                                 GMMA::Layout_K_INTER_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_K_INTER_Atom<ElementType>{};
    } else {
      static_assert(
          BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0,
          "BLK_K0 must be a multiple of "
          "size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{})");
    }
  }
}
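
// A minimal usage sketch (illustrative, assuming the MN-major SW128 atom for
// half_t spans 64 elements in its MN mode, so a 128-element block extent is
// divisible by it and the widest swizzle is chosen):
//
//   using AtomA = decltype(ss_smem_selector<GMMA::Major::MN, half_t,
//                                           Int<128>, Int<64>>());
//   // Expected: GMMA::Layout_MN_SW128_Atom<half_t>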

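// GemmTensorOp: SM90 warpgroup (WGMMA) GEMM on an M x N x K tile computed by
// num_warp_m x num_warp_n warps. body() reads both operands from shared
// memory (SS); body_rs() reads A from registers and B from shared memory (RS).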
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
public:
  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
                               tfloat32_t, A_type_raw>;
  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
                               tfloat32_t, B_type_raw>;
  using C_type = C_type_raw;

  static constexpr GMMA::Major GmmaMajorA =
      trans_A ? GMMA::Major::MN : GMMA::Major::K;
  static constexpr GMMA::Major GmmaMajorB =
      trans_B ? GMMA::Major::K : GMMA::Major::MN;

  using SmemLayoutAtomA =
      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
  using SmemLayoutAtomB =
      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());

  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
      conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
      conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));

  // static_assert(num_warp_n == 1);
  static_assert(num_warp_m % 4 == 0,
                "num_warp_m must be a multiple of 4 (whole warpgroups).");

  template <int wg_wait = 0>
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::ss_op_selector<A_type, B_type, C_type,
                             Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
                             GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
    Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
    Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)

    Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
    // warpgroup_fence_operand(acc);
    // warpgroup_arrive();

    // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);

    // warpgroup_commit_batch();
    // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    // warpgroup_fence_operand(acc);
  }

  template <int wg_wait = 0>
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    // TODO: Move bar.sync out of body_rs
    // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n *
    // 32));
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::rs_op_selector<A_type, B_type, C_type,
                             Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
                             GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
    Tensor tCsB = thr_mma.partition_B(sB);       // (MMA,MMA_N,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(tCrA);
    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
    warpgroup_fence_operand(tCrA);

    // warpgroup_fence_operand(acc);
    // warpgroup_arrive();

    // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);

    // warpgroup_commit_batch();

    // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    // warpgroup_fence_operand(acc);
  }
};

} // namespace cute

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
          typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body<wg_wait>(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
          typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body_rs<wg_wait>(pA, pB, accum);
}
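
// Usage sketch (illustrative only; buffer names and sizes are hypothetical):
// a single 128-thread math warpgroup (num_warp_m = 4, num_warp_n = 1) issuing
// a 64x64x64 SS GEMM from shared-memory tiles into per-thread accumulator
// registers might look like:
//
//   __shared__ cute::half_t smem_A[64 * 64];
//   __shared__ cute::half_t smem_B[64 * 64];
//   float acc[32] = {0.f}; // 64*64 accumulator split across 128 threads
//   ...
//   tl::gemm_ss<64, 64, 64, /*num_warp_m=*/4, /*num_warp_n=*/1,
//               /*trans_A=*/false, /*trans_B=*/false, /*wg_wait=*/0>(
//       smem_A, smem_B, acc);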

template <int num_mma> TL_DEVICE void wait_wgmma() {
  cute::warpgroup_wait<num_mma>();
}

template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_sync() {
  cutlass::arch::NamedBarrier::sync(NumMmaThreads,
                                    cutlass::canonical_warp_group_idx() /*id*/);
}

template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_arrive() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  if constexpr (NumMmaThreads == 256) {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads, (1 - cutlass::canonical_warp_group_idx()) /*id*/);
  } else {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads,
        (cutlass::canonical_warp_group_idx() <= 1
             ? cutlass::canonical_warp_group_idx() + 1
             : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads,
        (cutlass::canonical_warp_group_idx() <= 0
             ? cutlass::canonical_warp_group_idx() + 2
             : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
  }
}

template <int NumMmaThreads> TL_DEVICE void mma_init() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  if (cutlass::canonical_warp_group_idx() > 0) {
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0);
  }
  if constexpr (NumMmaThreads == 384) {
    if (cutlass::canonical_warp_group_idx() > 1) {
      cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 1 /*id*/);
    }
  }
}
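
// Typical warp-scheduler usage (illustrative; loop structure and names are
// hypothetical): with two math warpgroups (NumMmaThreads = 256), each
// warpgroup brackets its WGMMA issue with the named barrier so the two
// warpgroups take turns issuing, e.g.
//
//   tl::mma_init<256>();
//   for (int k_tile = 0; k_tile < num_k_tiles; ++k_tile) {
//     tl::warp_scheduler_barrier_sync<256>();
//     tl::gemm_ss<...>(pA, pB, acc);
//     tl::warp_scheduler_barrier_arrive<256>();
//     tl::wait_wgmma<0>();
//   }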
} // namespace tl