// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/gemm/warp/mma_tensor_op_sm70.h>

#include "common.h"

using cutlass::gemm::GemmShape;

// Primary template
// Add 128 bits padding when the last dim is a multiple of 256 bits
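// For example, with T = half_t and Dim = 64: 64 * 2 bytes = 128 bytes is a
// multiple of 32 bytes (256 bits), so Stride = 64 + 16 / 2 = 72 elements,
// i.e. 128 bits (8 half_t) of padding per row, presumably to avoid
// shared-memory bank conflicts.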
template <typename T, bool transpose, int M, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutA {
  using Layout =
      typename std::conditional<transpose, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;
  static int constexpr Dim = transpose ? M : K;
  static int constexpr Stride =
      (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};
template <typename T, bool transpose, int N, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutB {
  using Layout =
      typename std::conditional<transpose, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;
  static int constexpr Dim = transpose ? K : N;
  static int constexpr Stride =
      (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};

// Partial specialization for half_t
template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, true, M, K,
                                   typename std::enable_if<M % 64 == 0>::type> {
  using Layout =
      cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous<16>;
  static int constexpr Stride = M;
};

template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, false, M, K> {
  using Layout =
      cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
  static int constexpr Stride = M;
};

template <int N, int K> struct DispatchSharedMemoryLayoutB<half_t, true, N, K> {
  using Layout =
      cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
  static int constexpr Stride = N;
};

template <int N, int K>
struct DispatchSharedMemoryLayoutB<half_t, false, N, K,
                                   typename std::enable_if<N % 64 == 0>::type> {
  using Layout =
      cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous<16>;
  static int constexpr Stride = N;
};

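// Warp-level GEMM for SM70 (Volta) tensor cores. Shape is the block-level
// tile (kM x kN x kK), partitioned across num_warp_m x num_warp_n warps;
// trans_A / trans_B select column-major (transposed) or row-major operands.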
template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
public:
  using A_type = A_type_raw;
  using B_type = B_type_raw;
  using C_type = C_type_raw;
  using InstructionShape = GemmShape<16, 16, 4>;
  using SMemLayoutA =
      typename DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
                                           Shape::kK>::Layout;
  using SMemLayoutB =
      typename DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
                                           Shape::kK>::Layout;
  static constexpr int stride_A =
      DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
                                  Shape::kK>::Stride;
  static constexpr int stride_B =
      DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
                                  Shape::kK>::Stride;

  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
      cutlass::arch::Mma<
          InstructionShape, 32, A_type,
          typename std::conditional<trans_A, cutlass::layout::ColumnMajor,
                                    cutlass::layout::RowMajor>::type,
          B_type,
          typename std::conditional<trans_B, cutlass::layout::ColumnMajor,
                                    cutlass::layout::RowMajor>::type,
          C_type, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>,
      cutlass::MatrixShape<1, 1>>;

  static_assert(Shape::kM % num_warp_m == 0);
  static_assert(Shape::kN % num_warp_n == 0);

  using MmaWarp = typename cutlass::gemm::warp::MmaVoltaTensorOp<
      GemmShape<Shape::kM / num_warp_m, Shape::kN / num_warp_n,
                InstructionShape::kK>,
      A_type, SMemLayoutA, B_type, SMemLayoutB, C_type,
      cutlass::layout::RowMajor, Policy>;

  using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
  using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
  using FragmentA = typename MmaWarp::FragmentA;
  using FragmentB = typename MmaWarp::FragmentB;
  using FragmentC = typename MmaWarp::FragmentC;
  using IteratorA = typename MmaWarp::IteratorA;
  using IteratorB = typename MmaWarp::IteratorB;

  static_assert(Shape::kK % InstructionShape::kK == 0);
  static int constexpr kKgroups = Shape::kK / InstructionShape::kK;

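  // Computes accum += A * B with both operands staged in shared memory:
  // each warp steps through kKgroups k-slices, loading an A and a B fragment
  // and issuing one warp-level MMA per slice.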
  static CUTLASS_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
                                  FragmentC &accum, const int warp_idx_m,
                                  const int warp_idx_n, const int lane_id) {
    MmaWarp mma_op;
    FragmentA frag_A;
    FragmentB frag_B;
    const TensorRefA ref_A((A_type *)pA, stride_A);
    const TensorRefB ref_B((B_type *)pB, stride_B);
    IteratorA iter_A(ref_A, lane_id);
    IteratorB iter_B(ref_B, lane_id);
    iter_A.add_tile_offset({warp_idx_m, 0});
    iter_B.add_tile_offset({0, warp_idx_n});
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < kKgroups; ++k) {
      iter_A.load(frag_A);
      iter_B.load(frag_B);
      ++iter_A;
      ++iter_B;
      mma_op(accum, frag_A, frag_B, accum);
    }
  }

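  // Same as body(), but the A fragments are already resident in registers
  // (one FragmentA per k-group); only B is loaded from shared memory.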
  static CUTLASS_DEVICE void body_rs(const FragmentA *frag_A, B_type_raw *pB,
                                     FragmentC &accum, const int warp_idx_n,
                                     const int lane_id) {
    MmaWarp mma_op;
    FragmentB frag_B;
    const TensorRefB ref_B((B_type *)pB, stride_B);
    IteratorB iter_B(ref_B, lane_id);
    iter_B.add_tile_offset({0, warp_idx_n});
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < kKgroups; ++k) {
      iter_B.load(frag_B);
      ++iter_B;
      mma_op(accum, frag_A[k], frag_B, accum);
    }
  }
};

namespace tl {

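// Shared x shared GEMM: every warp in the block reads its A and B tiles from
// shared memory and accumulates into the register fragment `accum` points to.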
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                           trans_B, A_type, B_type, C_type>;
  using FragmentC = typename MMA::FragmentC;
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  MMA::body(pA, pB, *(FragmentC *)(accum), warp_id / num_warp_n,
            warp_id % num_warp_n, lane_id);
}

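// Register x shared GEMM: `pA` is reinterpreted as the per-k-group A fragments
// already held in registers; only B is read from shared memory.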
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                           trans_B, A_type, B_type, C_type>;
  using FragmentA = typename MMA::FragmentA;
  using FragmentC = typename MMA::FragmentC;
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  MMA::body_rs((const FragmentA *)(pA), pB, *(FragmentC *)(accum),
               warp_id % num_warp_n, lane_id);
}

} // namespace tl
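
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only). The kernel below is not part of this
// header: the tile sizes, buffer names, and the omitted global->shared copies
// are assumptions chosen to show how tl::gemm_ss is typically invoked. The
// accumulator is a plain register array sized from MMA::FragmentC, which
// gemm_ss reinterprets as a fragment.
//
//   #include "gemm_sm70.h"
//
//   constexpr int M = 64, N = 64, K = 32;         // block-level tile
//   constexpr int num_warp_m = 2, num_warp_n = 2; // 4 warps = 128 threads
//
//   __global__ void example_kernel(float *gC) {
//     using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n,
//                              /*trans_A=*/false, /*trans_B=*/false,
//                              half_t, half_t, float>;
//     __shared__ half_t smem_A[M * K]; // sizing ignores any layout padding
//     __shared__ half_t smem_B[N * K];
//     float accum[MMA::FragmentC::kElements] = {0.f};
//     // ... stage the current A/B tiles into smem_A / smem_B, then
//     // __syncthreads() ...
//     tl::gemm_ss<M, N, K, num_warp_m, num_warp_n, false, false,
//                 half_t, half_t, float>(smem_A, smem_B, accum);
//     // ... write accum back to gC according to the warp/lane layout ...
//   }
// ---------------------------------------------------------------------------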