// gemm_sm90.h
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include <cute/algorithm/copy.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>

#include "common.h"

namespace cute {

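// Select the shared-memory layout atom for a WGMMA operand: try the 128-byte,
// 64-byte, and 32-byte swizzled atoms in decreasing order and return the
// first whose extent evenly divides the block tile (BLK_MN for MN-major,
// BLK_K for K-major), falling back to the interleaved (non-swizzled) atom.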
template <GMMA::Major major, class ElementType, class BLK_MN, class BLK_K>
CUTE_HOST_DEVICE constexpr auto ss_smem_selector() {
  auto BLK_MN0 = size<0>(BLK_MN{});
  auto BLK_K0 = size<0>(BLK_K{});

  static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8.");
  static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8.");

  if constexpr (major == GMMA::Major::MN) {
    if constexpr (BLK_MN0 %
                      size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) ==
                  0) {
      return GMMA::Layout_MN_SW128_Atom<ElementType>{};
    } else if constexpr (BLK_MN0 %
                             size<0>(
                                 GMMA::Layout_MN_SW64_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_MN_SW64_Atom<ElementType>{};
    } else if constexpr (BLK_MN0 %
                             size<0>(
                                 GMMA::Layout_MN_SW32_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_MN_SW32_Atom<ElementType>{};
    } else if constexpr (BLK_MN0 %
                             size<0>(
                                 GMMA::Layout_MN_INTER_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_MN_INTER_Atom<ElementType>{};
    } else {
      static_assert(
          BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0,
          "BLK_MN0 must be a multiple of "
          "size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})");
    }
  } else if constexpr (major == GMMA::Major::K) {
    if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom<ElementType>{}) ==
                  0) {
      return GMMA::Layout_K_SW128_Atom<ElementType>{};
    } else if constexpr (BLK_K0 %
                             size<1>(GMMA::Layout_K_SW64_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_K_SW64_Atom<ElementType>{};
    } else if constexpr (BLK_K0 %
                             size<1>(GMMA::Layout_K_SW32_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_K_SW32_Atom<ElementType>{};
    } else if constexpr (BLK_K0 %
                             size<1>(
                                 GMMA::Layout_K_INTER_Atom<ElementType>{}) ==
                         0) {
      return GMMA::Layout_K_INTER_Atom<ElementType>{};
    } else {
      static_assert(
          BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0,
          "BLK_K0 must be a multiple of "
          "size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{})");
    }
  }
}

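// Issues Hopper warp-group MMAs (WGMMA) for a single M x N x K tile.
// body() reads both operands from shared memory (SS); body_rs() reads A from
// registers and B from shared memory (RS). float operands are reinterpreted
// as tfloat32_t, while the accumulator keeps its raw element type.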
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
public:
  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
                               tfloat32_t, A_type_raw>;
  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
                               tfloat32_t, B_type_raw>;
  using C_type = C_type_raw;

  static constexpr GMMA::Major GmmaMajorA =
      trans_A ? GMMA::Major::MN : GMMA::Major::K;
  static constexpr GMMA::Major GmmaMajorB =
      trans_B ? GMMA::Major::K : GMMA::Major::MN;

  using SmemLayoutAtomA =
      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
  using SmemLayoutAtomB =
      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());

  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
      conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
      conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));

  // static_assert(num_warp_n == 1);
  static_assert(num_warp_m % 4 == 0,
                "num_warp_m must be a multiple of 4 (one warp group is 4 warps)");

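  // Accumulate the register tile pointed to by pC with pA * pB, reading both
  // operand tiles from shared memory (SS WGMMA). wg_wait >= 0 waits until at
  // most `wg_wait` committed WGMMA batches remain in flight; a negative value
  // leaves the batch pending for the caller to wait on.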
  template <int wg_wait = 0>
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::ss_op_selector<A_type, B_type, C_type,
                             Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
                             GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
    Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
    Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)

    Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
    // warpgroup_fence_operand(acc);
    // warpgroup_arrive();

    // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);

    // warpgroup_commit_batch();
    // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    // warpgroup_fence_operand(acc);
  }

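  // Same as body(), but operand A is an in-register fragment (pA points to
  // registers) and only B is read from shared memory (RS WGMMA).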
  template <int wg_wait = 0>
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    // TODO: Move bar.sync out of body_rs
    // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n *
    // 32));
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    auto tiled_mma = make_tiled_mma(
        GMMA::rs_op_selector<A_type, B_type, C_type,
                             Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
                             GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
    Tensor tCsB = thr_mma.partition_B(sB);       // (MMA,MMA_N,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(tCrA);
    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
    warpgroup_fence_operand(tCrA);

    // warpgroup_fence_operand(acc);
    // warpgroup_arrive();

    // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);

    // warpgroup_commit_batch();

    // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    // warpgroup_fence_operand(acc);
  }
};

} // namespace cute

namespace tl {

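// Tile-level GEMM with both operand tiles in shared memory (SS WGMMA):
// accum += pA * pB for an M x N x K tile, computed cooperatively by
// num_warp_m * num_warp_n warps.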
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
          typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body<wg_wait>(pA, pB, accum);
}

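// Tile-level GEMM with the A tile already in registers (RS WGMMA); only the
// B tile is read from shared memory.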
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, int wg_wait = 0, typename A_type, typename B_type,
          typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body_rs<wg_wait>(pA, pB, accum);
}

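// Block until at most `num_mma` committed WGMMA batches are still in flight.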
template <int num_mma> TL_DEVICE void wait_wgmma() {
  cute::warpgroup_wait<num_mma>();
}

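// Wait on the named barrier owned by the current warp group, so that warp
// groups take turns issuing their MMAs.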
template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_sync() {
  cutlass::arch::NamedBarrier::sync(NumMmaThreads,
                                    cutlass::canonical_warp_group_idx() /*id*/);
}

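// Signal the next warp group(s) in the round-robin schedule: with 2 warp
// groups (256 threads) arrive on the other group's barrier; with 3 warp
// groups (384 threads) arrive on the barriers of the following two groups.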
template <int NumMmaThreads> TL_DEVICE void warp_scheduler_barrier_arrive() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384,
                "NumMmaThreads must be 256 or 384 (2 or 3 warp groups)");
  if constexpr (NumMmaThreads == 256) {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads, (1 - cutlass::canonical_warp_group_idx()) /*id*/);
  } else {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads,
        (cutlass::canonical_warp_group_idx() <= 1
             ? cutlass::canonical_warp_group_idx() + 1
             : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads,
        (cutlass::canonical_warp_group_idx() <= 0
             ? cutlass::canonical_warp_group_idx() + 2
             : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
  }
}

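// Pre-arrive on the scheduler barriers so that warp group 0 (and, with 3 warp
// groups, warp group 1) is not blocked by its first barrier sync.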
template <int NumMmaThreads> TL_DEVICE void mma_init() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384,
                "NumMmaThreads must be 256 or 384 (2 or 3 warp groups)");
  if (cutlass::canonical_warp_group_idx() > 0) {
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0);
  }
  if constexpr (NumMmaThreads == 384) {
    if (cutlass::canonical_warp_group_idx() > 1) {
      cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 1 /*id*/);
    }
  }
}
} // namespace tl
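// Illustrative call site (a minimal sketch, not part of this header): a
// TileLang-generated kernel is expected to pass shared-memory operand tiles
// and a per-thread register accumulator. Buffer names and tile sizes below
// are hypothetical.
//
//   __shared__ cutlass::half_t smem_A[64 * 64];
//   __shared__ cutlass::half_t smem_B[64 * 64];
//   float acc_regs[/* accumulator fragment size */];
//   // 64x64x64 tile, 4 warps along M (one warp group), 1 along N, neither
//   // operand transposed; wg_wait = 0 waits for all issued WGMMAs.
//   tl::gemm_ss<64, 64, 64, /*num_warp_m=*/4, /*num_warp_n=*/1,
//               /*trans_A=*/false, /*trans_B=*/false, /*wg_wait=*/0>(
//       smem_A, smem_B, acc_regs);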