// gemm.h
#pragma once

#include "common.h"

namespace tl {

// The meaning of kPack follows bitblas/tl/mfma_macro_generator.py::kPack.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA,
          bool TransposeB, int kPack, typename A_type, typename B_type,
          typename C_type, typename AccDataType = float>
11
class GemmTensorOp {
12
public:
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
  static constexpr int micro_size_x = 16;
  static constexpr int micro_size_y = 16;
  static constexpr int micro_size_k = 16;

  // This part comes from the Codegen
  static constexpr int M_Tile = M;
  static constexpr int N_Tile = N;
  static constexpr int K_Tile = K;

  static constexpr int block_row_warps = num_warp_m;
  static constexpr int block_col_warps = num_warp_n;

  static constexpr int inner_k = K_Tile / (micro_size_k * kPack);
  static constexpr int warp_rows = M_Tile / (block_row_warps * micro_size_x);
  static constexpr int warp_cols = N_Tile / (block_col_warps * micro_size_y);

29
30
  // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen
  // part.
31
32
33
34
35
36
37
38
  static constexpr bool kPadA = true;
  static constexpr bool kPadB = true;
  static constexpr bool kPadC = true;

  static constexpr int BANK_SIZE_BYTES = 128;

  static constexpr int warp_size = 64;

39
40
41
42
  TL_DEVICE static constexpr auto reverse_index_map(int thread_id,
                                                    int local_id) {
    return std::make_pair(thread_id % 16,
                          (thread_id / 16) * (4 * kPack) + local_id);
43
44
  }

45
46
47
48
  TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id,
                                                               int local_id) {
    return std::make_pair((thread_id / 16) * (4 * kPack) + local_id,
                          thread_id % 16);
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
  }

  /*
   * Detailed Implementation please
   * checkout bitblas/tl/utils.py:get_swizzle_layout
   */
  template <int continuous = 32, int element_size = 2>
  TL_DEVICE static auto make_mfma_swizzle_layout(const int row, const int col) {
    const auto dtype_bits = element_size * 8;

    const int numBanks = 32;
    const int bankBitWidth = 32;
    const int SIMDWidth = 16;
    const int vecSize = 4 * kPack;
    const int innerDimLength = continuous;
    const int typeWidthInBit = dtype_bits;

    const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
    const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
68
69
    const int maxPhase =
        std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
70
71
72
73
74
75
76
77
78
79

    const int phase = (row / perPhase) % maxPhase;
    const int colOffSwizzled = (((col / vecSize) ^ phase) * vecSize);
    const int colOffOrdered = col % vecSize;
    const int colOff = colOffSwizzled + colOffOrdered;

    return std::make_pair(row, colOff);
  }

  template <int continuous = 32, int element_size = 2>
80
81
  TL_DEVICE static constexpr auto make_layout_padded(const int row,
                                                     const int col) {
82
83
84
85
    return std::make_pair(row, col);
  }

  template <int continuous = 32, int element_size = 2>
86
87
  TL_DEVICE static constexpr auto make_swizzle_layout(const int row,
                                                      const int col) {
88
89
90
    constexpr auto vector_size = BANK_SIZE_BYTES / (element_size * 8);

    if (continuous % (vector_size * 4) == 0) {
91
92
      auto [n_row, n_col] =
          make_mfma_swizzle_layout<continuous, element_size>(row, col);
93
94
95
96
97
98
99
100
101
102
      return n_row * continuous + n_col;
    } else {
      auto [n_row, n_col] = make_layout_padded(row, col);
      int padded = continuous;
      if ((element_size * 8 * continuous) % 256 == 0)
        padded += BANK_SIZE_BYTES / (element_size * 8);
      return n_row * padded + n_col;
    }
  }

103
104
  static TL_DEVICE void body(A_type *A_shared, B_type *B_shared,
                             C_type *C_local) {
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    auto tid = threadIdx.x;
    auto warp_id = tid / warp_size;
    auto warp_n = warp_id / block_row_warps;
    auto warp_m = warp_id % block_row_warps;
    auto warp_row_tiles = warp_rows * micro_size_x;
    auto warp_col_tiles = warp_cols * micro_size_y;

    auto lane_id = tid % warp_size;
    auto tx = lane_id;

    constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
    constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
    constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;

    constexpr auto last_dim_a = TransposeA ? M_Tile : K_Tile;
    constexpr auto last_dim_b = TransposeB ? K_Tile : N_Tile;

    A_type A_local[warp_rows * kPack * local_size_a];
    B_type B_local[warp_cols * kPack * local_size_b];

    for (int ki = 0; ki < inner_k; ki++) {
      // Fetch A into register
      for (int i = 0; i < warp_rows; i++) {
        const auto l = warp_m * warp_row_tiles + i * micro_size_x;
        const auto r = ki * (kPack * micro_size_k);
        for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
          auto [row, col] = reverse_index_map(lane_id, local_id);
          A_local[i * kPack * local_size_a + local_id] =
133
134
              A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
                  l + row, r + col)];
135
136
137
138
139
140
141
142
143
144
        }
      }

      // Fetch B into register
      for (int j = 0; j < warp_cols; j++) {
        const auto l = warp_n * warp_col_tiles + j * micro_size_y;
        const auto r = ki * (kPack * micro_size_k);
        for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
          auto [row, col] = reverse_index_map(lane_id, local_id);
          B_local[j * kPack * local_size_b + local_id] =
145
146
              B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                  l + row, r + col)];
147
148
149
150
151
152
153
        }
      }

      // Compute
      for (int kp = 0; kp < kPack; kp++) {
        for (int i = 0; i < warp_rows; ++i) {
          for (int j = 0; j < warp_cols; ++j) {
154
155
156
157
158
            *(((float32x4 *)C_local) + ((i * warp_cols) + j)) =
                __builtin_amdgcn_mfma_f32_16x16x16f16(
                    *(((float16x4 *)B_local) + j * kPack + kp),
                    *(((float16x4 *)A_local) + i * kPack + kp),
                    *(((float32x4 *)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
159
160
161
162
163
164
          }
        }
      }
    }
  }

165
166
  static TL_DEVICE void body_rs(A_type *A_local, B_type *B_shared,
                                C_type *C_local) {
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
    auto tid = threadIdx.x;
    auto warp_id = tid / warp_size;
    auto warp_n = warp_id / block_row_warps;
    auto warp_m = warp_id % block_row_warps;
    auto warp_row_tiles = warp_rows * micro_size_x;
    auto warp_col_tiles = warp_cols * micro_size_y;

    auto lane_id = tid % warp_size;
    auto tx = lane_id;

    constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
    constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
    constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;

    constexpr auto last_dim_a = TransposeA ? M_Tile : K_Tile;
    constexpr auto last_dim_b = TransposeB ? K_Tile : N_Tile;

    B_type B_local[warp_cols * kPack * local_size_b];

    for (int ki = 0; ki < inner_k; ki++) {
      // Fetch B into register
      for (int j = 0; j < warp_cols; j++) {
        const auto l = warp_n * warp_col_tiles + j * micro_size_y;
        const auto r = ki * kPack * micro_size_k;
        for (int local_id = 0; local_id < kPack * local_size_b; local_id++) {
          auto [row, col] = reverse_index_map(lane_id, local_id);
          B_local[j * local_size_b + local_id] =
194
195
              B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                  l + row, r + col)];
196
197
198
199
200
201
202
        }
      }

      // Compute
      for (int kp = 0; kp < kPack; kp++) {
        for (int i = 0; i < warp_rows; ++i) {
          for (int j = 0; j < warp_cols; ++j) {
203
204
205
206
207
208
            *(((float32x4 *)C_local) + ((i * warp_cols) + j)) =
                __builtin_amdgcn_mfma_f32_16x16x16f16(
                    *(((float16x4 *)B_local) + j * kPack + kp),
                    *(((float16x4 *)A_local) + ki * warp_rows * kPack +
                      i * kPack + kp),
                    *(((float32x4 *)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
209
210
211
212
213
214
215
          }
        }
      }
    }
  }
};

} // namespace tl

namespace tl {

/**
 * Entry point for the GEMM variant in which both operands are read through
 * the swizzled layout: forwards to GemmTensorOp<...>::body.
 *
 * pA, pB   Operand buffers (GemmTensorOp::body names them A_shared and
 *          B_shared).
 * accum    Accumulator buffer that is updated in place.
 *
 * AccDataType of GemmTensorOp is left at its default.
 */
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, int kPack, typename A_type, typename B_type,
          typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                               trans_B, kPack, A_type, B_type, C_type>;
  Compute::body(pA, pB, accum);
}

/**
 * Entry point for the GEMM variant in which A is supplied as ready-made
 * fragments and only B is read through the swizzled layout: forwards to
 * GemmTensorOp<...>::body_rs.
 *
 * pA       Fragment buffer for A (GemmTensorOp::body_rs names it A_local).
 * pB       Operand buffer for B (named B_shared in body_rs).
 * accum    Accumulator buffer that is updated in place.
 *
 * AccDataType of GemmTensorOp is left at its default.
 */
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, int kPack, typename A_type, typename B_type,
          typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using Compute = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                               trans_B, kPack, A_type, B_type, C_type>;
  Compute::body_rs(pA, pB, accum);
}

} // namespace tl