#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/gemm/warp/mma_tensor_op_sm70.h>

#include <type_traits>

#include "common.h"

using cutlass::gemm::GemmShape;

// Primary template.
// Adds 128 bits of padding to the stride when the last dim is a multiple of
// 256 bits, to avoid shared-memory bank conflicts.
template <typename T, bool transpose, int M, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutA {
  using Layout =
      typename std::conditional<transpose, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;
  static int constexpr Dim = transpose ? M : K;
  static int constexpr Stride =
      (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};
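// Analogous dispatcher for the shared-memory layout of the B operand tile.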
template <typename T, bool transpose, int N, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutB {
  using Layout =
      typename std::conditional<transpose, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;
  static int constexpr Dim = transpose ? K : N;
  static int constexpr Stride =
      (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};

// Partial specialization for half_t
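// These pick the SM70 tensor-op swizzled shared-memory layouts: the
// *Congruous variants when the contiguous dimension is M (for A) or N (for
// B), the *Crosswise variants when the contiguous dimension is K.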
template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, true, M, K,
                                   typename std::enable_if<M % 64 == 0>::type> {
  using Layout =
      cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous<16>;
  static int constexpr Stride = M;
};

template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, false, M, K> {
  using Layout =
      cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
  static int constexpr Stride = M;
};

template <int N, int K>
struct DispatchSharedMemoryLayoutB<half_t, true, N, K> {
  using Layout =
      cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
  static int constexpr Stride = N;
};

template <int N, int K>
struct DispatchSharedMemoryLayoutB<half_t, false, N, K,
                                   typename std::enable_if<N % 64 == 0>::type> {
  using Layout =
      cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous<16>;
  static int constexpr Stride = N;
};

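// Warp-level GEMM on SM70 (Volta) tensor cores. A and B tiles are read from
// shared memory through the layouts dispatched above, and the result is
// accumulated in registers. Shape is the per-block tile, partitioned across a
// num_warp_m x num_warp_n grid of warps.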
template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
  using A_type = A_type_raw;
  using B_type = B_type_raw;
  using C_type = C_type_raw;
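  // Per-instruction tile shape of the SM70 tensor-core MMA exposed by CUTLASS.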
  using InstructionShape = GemmShape<16, 16, 4>;
  using SMemLayoutA =
      typename DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
                                           Shape::kK>::Layout;
  using SMemLayoutB =
      typename DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
                                           Shape::kK>::Layout;
  static constexpr int stride_A =
      DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
                                  Shape::kK>::Stride;
  static constexpr int stride_B =
      DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
                                  Shape::kK>::Stride;

  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
      cutlass::arch::Mma<
          InstructionShape, 32, A_type,
          typename std::conditional<trans_A, cutlass::layout::ColumnMajor,
                                    cutlass::layout::RowMajor>::type,
          B_type,
          typename std::conditional<trans_B, cutlass::layout::ColumnMajor,
                                    cutlass::layout::RowMajor>::type,
          C_type, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>,
      cutlass::MatrixShape<1, 1>>;

  static_assert(Shape::kM % num_warp_m == 0);
  static_assert(Shape::kN % num_warp_n == 0);

  using MmaWarp = typename cutlass::gemm::warp::MmaVoltaTensorOp<
      GemmShape<Shape::kM / num_warp_m, Shape::kN / num_warp_n,
                InstructionShape::kK>,
      A_type, SMemLayoutA, B_type, SMemLayoutB, C_type,
      cutlass::layout::RowMajor, Policy>;

  using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
  using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
  using FragmentA = typename MmaWarp::FragmentA;
  using FragmentB = typename MmaWarp::FragmentB;
  using FragmentC = typename MmaWarp::FragmentC;
  using IteratorA = typename MmaWarp::IteratorA;
  using IteratorB = typename MmaWarp::IteratorB;

  static_assert(Shape::kK % InstructionShape::kK == 0);
  static int constexpr kKgroups = Shape::kK / InstructionShape::kK;

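  // Shared x shared GEMM: steps through the K dimension in kKgroups chunks,
  // loading one A and one B fragment from shared memory per step and issuing a
  // warp-level MMA. warp_idx_m / warp_idx_n select this warp's tile offsets.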
  static CUTLASS_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
                                  FragmentC &accum, const int warp_idx_m,
                                  const int warp_idx_n, const int lane_id) {
    MmaWarp mma_op;
    FragmentA frag_A;
    FragmentB frag_B;
    const TensorRefA ref_A((A_type *)pA, stride_A);
    const TensorRefB ref_B((B_type *)pB, stride_B);
    IteratorA iter_A(ref_A, lane_id);
    IteratorB iter_B(ref_B, lane_id);
    iter_A.add_tile_offset({warp_idx_m, 0});
    iter_B.add_tile_offset({0, warp_idx_n});
    if constexpr (clear_accum) {
      accum.clear();
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < kKgroups; ++k) {
      iter_A.load(frag_A);
      iter_B.load(frag_B);
      ++iter_A;
      ++iter_B;
      mma_op(accum, frag_A, frag_B, accum);
    }
  }

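  // Register x shared variant: the A fragments are already resident in
  // registers (one per K group), so only B is loaded from shared memory.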
  static CUTLASS_DEVICE void body_rs(const FragmentA *frag_A, B_type_raw *pB,
                                     FragmentC &accum, const int warp_idx_n,
                                     const int lane_id) {
    MmaWarp mma_op;
    FragmentB frag_B;
    const TensorRefB ref_B((B_type *)pB, stride_B);
    IteratorB iter_B(ref_B, lane_id);
    iter_B.add_tile_offset({0, warp_idx_n});
    if constexpr (clear_accum) {
      accum.clear();
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < kKgroups; ++k) {
      iter_B.load(frag_B);
      ++iter_B;
      mma_op(accum, frag_A[k], frag_B, accum);
    }
  }
};

namespace tl {

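// Shared-shared entry point: derives the warp and lane indices from
// threadIdx.x (warps laid out row-major over the num_warp_m x num_warp_n
// grid) and runs the warp-level GEMM on the shared-memory tiles pA and pB.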
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                           trans_B, clear_accum, A_type, B_type, C_type>;
  using FragmentC = typename MMA::FragmentC;
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  MMA::body(pA, pB, *(FragmentC *)(accum), warp_id / num_warp_n,
            warp_id % num_warp_n, lane_id);
}

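// Register-shared entry point: pA points to A fragments already held in
// registers; only B is read from shared memory.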
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                           trans_B, clear_accum, A_type, B_type, C_type>;
  using FragmentA = typename MMA::FragmentA;
  using FragmentC = typename MMA::FragmentC;
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  MMA::body_rs((const FragmentA *)(pA), pB, *(FragmentC *)(accum),
               warp_id % num_warp_n, lane_id);
}

} // namespace tl
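
// Usage sketch (illustrative only: the tile sizes, warp layout, and buffer
// names below are assumptions, not part of this header). A 64x64x32 fp16 tile
// with fp32 accumulation, computed by a 2x2 grid of warps (128 threads):
//
//   using MMA = GemmTensorOp<GemmShape<64, 64, 32>, 2, 2,
//                            /*trans_A=*/false, /*trans_B=*/false,
//                            /*clear_accum=*/true, half_t, half_t, float>;
//
//   __shared__ half_t smem_A[64 * 32]; // staged 64x32 A tile (row-major source)
//   __shared__ half_t smem_B[32 * 64]; // staged 32x64 B tile (row-major source)
//   float acc[MMA::FragmentC::kElements];
//   // ... fill smem_A / smem_B, __syncthreads(), then:
//   tl::gemm_ss<64, 64, 32, 2, 2, false, false, true>(smem_A, smem_B, acc);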