// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/underscore.hpp>

#include "common.h"

namespace cute {

// Maps an (A, B, C) element-type triple plus the warp grid to
//   MMA       - the MMA atom to issue, and
//   MMA_Group - the tile handed to TiledMMA as its third argument.
// Only the specializations that follow are defined. The primary template is
// declared but never defined, so GemmTensorOp cannot be instantiated for an
// element-type combination that has no specialization.
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n>
struct DispatchInstruction;

// Shorthand for Underscore; the MMA_Group tiles below use it for the modes
// that are not given an explicit extent.
using _X = Underscore;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> {
22
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
23
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
24
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
27
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
28
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
29
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n> {
33
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
34
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
35
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n> {
39
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
40
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
41
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n> {
44
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
45
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
46
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> {
49
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
50
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
51
52
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
55
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
56
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>;
57
58
59
60
61
62
63
};
#endif

template <int Bits, int N, int K, bool K_inner, typename Enable = void>
struct OperandTraits {
  // Primary template, use padded layout and default copy
  static constexpr int stride = K_inner ? K : N;
64
65
66
67
68
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
69
70
71
72
  using Copy = DefaultCopy;
};

template <int N, int K>
73
74
75
76
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
77
78
79
80
81
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
82
83
84
85
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
86
87
88
89
90
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
91
92
93
94
95
96
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
97
98
99
100
  using Copy = SM75_U16x8_LDSM_T;
};

template <int N, int K>
101
102
103
104
105
106
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
107
108
109
110
  using Copy = SM75_U16x8_LDSM_T;
};

template <int N, int K>
111
112
113
114
struct OperandTraits<32, N, K, true,
                     typename std::enable_if<K % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
115
116
117
118
119
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
120
121
122
123
struct OperandTraits<32, N, K, true,
                     typename std::enable_if<K % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
124
125
126
127
128
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
129
130
131
132
133
134
struct OperandTraits<32, N, K, false,
                     typename std::enable_if<N % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
135
136
137
138
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K>
139
140
141
142
143
144
struct OperandTraits<32, N, K, false,
                     typename std::enable_if<N % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
145
146
147
148
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K>
149
150
151
152
struct OperandTraits<8, N, K, true,
                     typename std::enable_if<K % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
153
154
155
156
157
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
158
159
160
161
struct OperandTraits<8, N, K, true,
                     typename std::enable_if<K % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
162
163
164
165
166
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
167
168
169
170
struct OperandTraits<64, N, K, true,
                     typename std::enable_if<K % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
171
172
173
174
175
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

template <int N, int K>
176
177
178
179
180
181
struct OperandTraits<64, N, K, false,
                     typename std::enable_if<N % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
182
183
184
  using Copy = DefaultCopy;
};

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
188
class GemmTensorOp {
189
190
191
192
193
194
195
public:
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
196
  using C_type = C_type_raw;
197
198
  using Instruction =
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>;
199

200
201
202
203
  using OperandATraits =
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
  using OperandBTraits =
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
204
205
206
207
208
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

209
210
211
  using TileMma = TiledMMA<typename Instruction::MMA,
                           Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
                           typename Instruction::MMA_Group>;
212
213

  template <class... Args>
214
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
215
216
217
218
219
    return layout;
  }
  // In fp16, when layout is KxN and n_warp is 1 and N % 64 == 0
  // the original layout fail to compile, currently using this as a workaround
  template <class... Args>
220
221
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
222
223
224
225
226
227
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

228
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC) {
229
    const int tid = threadIdx.x;
230
231
232
233
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

249
250
251
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
252

253
254
    // when layout is KxN and n_warp is 1, there seem to be a bug, use this as a
    // workaround
255
256
257
258
259
260
261
262
263
264
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

265
266
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
267
    const int tid = threadIdx.x;
268
269
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
270
271
272
273
274
275
276
277
278
279
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);

    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

280
281
282
283
284
285
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
286
287
288
289
290
291
292
293
294
295
296
297

    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

298
299
  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
300
    const int tid = threadIdx.x;
301
302
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
303
304
305
306
307
308
309
310
311
312
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);

    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

313
314
315
316
317
318
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
319
320
321
322
323
324
325
326
327
328
329
330
331

    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace cute

namespace tl {

// GEMM entry point for the case where both operands go through
// cute::GemmTensorOp::body, i.e. pA and pB are both read as shared-memory
// tensors. All template arguments are forwarded unchanged.
//   pA, pB - operand storage.
//   accum  - accumulator storage.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using Op = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                trans_B, A_type, B_type, C_type>;
  Op::body(pA, pB, accum);
}

// GEMM entry point that forwards to cute::GemmTensorOp::body_rs: pA is wrapped
// as an rmem tensor and pB is read as a shared-memory tensor. All template
// arguments are forwarded unchanged.
//   pA, pB - operand storage.
//   accum  - accumulator storage.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using Op = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                trans_B, A_type, B_type, C_type>;
  Op::body_rs(pA, pB, accum);
}

// GEMM entry point that forwards to cute::GemmTensorOp::body_sr: pA is read as
// a shared-memory tensor and pB is wrapped as an rmem tensor. All template
// arguments are forwarded unchanged.
//   pA, pB - operand storage.
//   accum  - accumulator storage.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using Op = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                trans_B, A_type, B_type, C_type>;
  Op::body_sr(pA, pB, accum);
}

} // namespace tl