// gemm.h — tile-lang (tl) GEMM primitives for AMD GPUs built on MFMA intrinsics.
#pragma once

#include "common.h"

#include <algorithm> // std::max, std::min (make_mfma_swizzle_layout)
#include <type_traits>
#include <utility> // std::make_pair

namespace tl {

8
9
10
// Trait to determine the MFMA instruction to use based on data type
template <typename T> struct MfmaTraits;

11
12
13
14
15
16
17
18
19
20
21
22
// Specialization for int8
template <> struct MfmaTraits<int8_t> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) {
    int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
    int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));

    *c = __builtin_amdgcn_mfma_i32_16x16x32_i8(*b_packed, *a_packed, *c, 0, 0,
                                               0);
  }
};

23
24
25
26
27
28
29
30
31
// Specialization for half/float16
template <> struct MfmaTraits<half> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const half *b, const half *a, AccType *c) {
    *c = __builtin_amdgcn_mfma_f32_16x16x16f16(*((float16x4 *)b),
                                               *((float16x4 *)a), *c, 0, 0, 0);
  }
};

32
33
// Specialization for bfloat16_t
template <> struct MfmaTraits<bfloat16_t> {
34
  template <typename AccType>
35
36
  static TL_DEVICE void mfma_op(const bfloat16_t *b, const bfloat16_t *a,
                                AccType *c) {
37
38
39
    bfloat16x4_vec b_vec, a_vec;

    // Reinterpret the pointers
40
41
    short *b_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(b));
    short *a_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(a));
42
43
44
45
46
47
48
49
50
51
52
53

    // Copy the data
    for (int i = 0; i < 4; ++i) {
      b_vec[i] = b_short[i];
      a_vec[i] = a_short[i];
    }

    // Call the intrinsic and store the result directly to c
    *c = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(b_vec, a_vec, *c, 0, 0, 0);
  }
};

54
55
56
57
58
59
60
61
62
63
64
65
66
#if defined(HIP_FP8_ENABLED)
// Specialization for fp8_e4_t
template <> struct MfmaTraits<fp8_e4_t> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a,
                                AccType *c) {
    int64_t a_val = *reinterpret_cast<const int64_t *>(a);
    int64_t b_val = *reinterpret_cast<const int64_t *>(b);
    *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0);
  }
};
#endif

67
// ref to bitblas/tl/mfma_macro_generator.py::kPack
68
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA,
69
70
          bool TransposeB, bool clear_accum, int kPack, typename A_type,
          typename B_type, typename C_type, typename AccDataType = float>
71
class GemmTensorOp {
72
public:
73
74
  static_assert(!clear_accum, "clear_accum=true is not supported yet");

75
76
  static constexpr int micro_size_x = 16;
  static constexpr int micro_size_y = 16;
77
78
  static constexpr int micro_size_k = 32 / sizeof(A_type);
  static constexpr int vec_size = 8 / sizeof(A_type);
79
80
81
82
83
84
85
86
87
88
89
90
91

  // This part comes from the Codegen
  static constexpr int M_Tile = M;
  static constexpr int N_Tile = N;
  static constexpr int K_Tile = K;

  static constexpr int block_row_warps = num_warp_m;
  static constexpr int block_col_warps = num_warp_n;

  static constexpr int inner_k = K_Tile / (micro_size_k * kPack);
  static constexpr int warp_rows = M_Tile / (block_row_warps * micro_size_x);
  static constexpr int warp_cols = N_Tile / (block_col_warps * micro_size_y);

92
93
  // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen
  // part.
94
95
96
97
98
99
100
101
  static constexpr bool kPadA = true;
  static constexpr bool kPadB = true;
  static constexpr bool kPadC = true;

  static constexpr int BANK_SIZE_BYTES = 128;

  static constexpr int warp_size = 64;

102
103
104
  TL_DEVICE static constexpr auto reverse_index_map(int thread_id,
                                                    int local_id) {
    return std::make_pair(thread_id % 16,
105
                          (thread_id / 16) * (vec_size * kPack) + local_id);
106
107
  }

108
109
  TL_DEVICE static constexpr auto reverse_index_map_transposed(int thread_id,
                                                               int local_id) {
110
    return std::make_pair((thread_id / 16) * (vec_size * kPack) + local_id,
111
                          thread_id % 16);
112
113
114
115
116
117
118
119
120
121
122
123
124
  }

  /*
   * Detailed Implementation please
   * checkout bitblas/tl/utils.py:get_swizzle_layout
   */
  template <int continuous = 32, int element_size = 2>
  TL_DEVICE static auto make_mfma_swizzle_layout(const int row, const int col) {
    const auto dtype_bits = element_size * 8;

    const int numBanks = 32;
    const int bankBitWidth = 32;
    const int SIMDWidth = 16;
125
    const int vecSize = vec_size * kPack;
126
127
128
129
130
    const int innerDimLength = continuous;
    const int typeWidthInBit = dtype_bits;

    const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
    const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
131
132
    const int maxPhase =
        std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
133
134
135
136
137
138
139
140
141
142

    const int phase = (row / perPhase) % maxPhase;
    const int colOffSwizzled = (((col / vecSize) ^ phase) * vecSize);
    const int colOffOrdered = col % vecSize;
    const int colOff = colOffSwizzled + colOffOrdered;

    return std::make_pair(row, colOff);
  }

  template <int continuous = 32, int element_size = 2>
143
144
  TL_DEVICE static constexpr auto make_layout_padded(const int row,
                                                     const int col) {
145
146
147
148
    return std::make_pair(row, col);
  }

  template <int continuous = 32, int element_size = 2>
149
150
  TL_DEVICE static constexpr auto make_swizzle_layout(const int row,
                                                      const int col) {
151
152
153
    auto [n_row, n_col] =
        make_mfma_swizzle_layout<continuous, element_size>(row, col);
    return n_row * continuous + n_col;
154
155
  }

156
157
  static TL_DEVICE void body(A_type *A_shared, B_type *B_shared,
                             C_type *C_local) {
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
    auto tid = threadIdx.x;
    auto warp_id = tid / warp_size;
    auto warp_n = warp_id / block_row_warps;
    auto warp_m = warp_id % block_row_warps;
    auto warp_row_tiles = warp_rows * micro_size_x;
    auto warp_col_tiles = warp_cols * micro_size_y;

    auto lane_id = tid % warp_size;
    auto tx = lane_id;

    constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
    constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
    constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;

    constexpr auto last_dim_a = TransposeA ? M_Tile : K_Tile;
    constexpr auto last_dim_b = TransposeB ? K_Tile : N_Tile;

    A_type A_local[warp_rows * kPack * local_size_a];
    B_type B_local[warp_cols * kPack * local_size_b];

    for (int ki = 0; ki < inner_k; ki++) {
      // Fetch A into register
      for (int i = 0; i < warp_rows; i++) {
        const auto l = warp_m * warp_row_tiles + i * micro_size_x;
        const auto r = ki * (kPack * micro_size_k);
        for (int local_id = 0; local_id < (kPack * local_size_a); local_id++) {
184
185
186
187
          if constexpr (TransposeA) {
            auto [row, col] = reverse_index_map_transposed(lane_id, local_id);
            A_local[i * kPack * local_size_a + local_id] =
                A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
188
                    r + row, l + col)];
189
190
191
192
193
194
          } else {
            auto [row, col] = reverse_index_map(lane_id, local_id);
            A_local[i * kPack * local_size_a + local_id] =
                A_shared[make_swizzle_layout<last_dim_a, sizeof(A_type)>(
                    l + row, r + col)];
          }
195
196
197
198
199
200
201
        }
      }
      // Fetch B into register
      for (int j = 0; j < warp_cols; j++) {
        const auto l = warp_n * warp_col_tiles + j * micro_size_y;
        const auto r = ki * (kPack * micro_size_k);
        for (int local_id = 0; local_id < (kPack * local_size_b); local_id++) {
202
203
204
205
206
207
208
209
210
211
212
          if constexpr (TransposeB) {
            auto [row, col] = reverse_index_map(lane_id, local_id);
            B_local[j * kPack * local_size_b + local_id] =
                B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                    l + row, r + col)];
          } else {
            auto [row, col] = reverse_index_map_transposed(lane_id, local_id);
            B_local[j * kPack * local_size_b + local_id] =
                B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                    r + row, l + col)];
          }
213
214
215
216
217
218
        }
      }
      // Compute
      for (int kp = 0; kp < kPack; kp++) {
        for (int i = 0; i < warp_rows; ++i) {
          for (int j = 0; j < warp_cols; ++j) {
219
            auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j);
220
221
            auto b_ptr = ((B_type *)B_local) + (j * kPack + kp) * vec_size;
            auto a_ptr = ((A_type *)A_local) + (i * kPack + kp) * vec_size;
222

223
224
            // Use the trait to select the correct MFMA instruction, either fp8,
            // fp16 or bf16 currently
225
            MfmaTraits<A_type>::mfma_op(b_ptr, a_ptr, acc_ptr);
226
227
228
229
230
231
          }
        }
      }
    }
  }

232
233
  static TL_DEVICE void body_rs(A_type *A_local, B_type *B_shared,
                                C_type *C_local) {
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
    auto tid = threadIdx.x;
    auto warp_id = tid / warp_size;
    auto warp_n = warp_id / block_row_warps;
    auto warp_m = warp_id % block_row_warps;
    auto warp_row_tiles = warp_rows * micro_size_x;
    auto warp_col_tiles = warp_cols * micro_size_y;

    auto lane_id = tid % warp_size;
    auto tx = lane_id;

    constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
    constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
    constexpr auto local_size_c = (micro_size_x * micro_size_y) / warp_size;

    constexpr auto last_dim_a = TransposeA ? M_Tile : K_Tile;
    constexpr auto last_dim_b = TransposeB ? K_Tile : N_Tile;

    B_type B_local[warp_cols * kPack * local_size_b];

    for (int ki = 0; ki < inner_k; ki++) {
      // Fetch B into register
      for (int j = 0; j < warp_cols; j++) {
        const auto l = warp_n * warp_col_tiles + j * micro_size_y;
        const auto r = ki * kPack * micro_size_k;
        for (int local_id = 0; local_id < kPack * local_size_b; local_id++) {
259
260
          if constexpr (TransposeB) {
            auto [row, col] = reverse_index_map(lane_id, local_id);
261
            B_local[j * kPack * local_size_b + local_id] =
262
263
264
265
                B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                    l + row, r + col)];
          } else {
            auto [row, col] = reverse_index_map_transposed(lane_id, local_id);
266
            B_local[j * kPack * local_size_b + local_id] =
267
268
269
                B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
                    r + row, l + col)];
          }
270
271
272
273
274
275
276
        }
      }

      // Compute
      for (int kp = 0; kp < kPack; kp++) {
        for (int i = 0; i < warp_rows; ++i) {
          for (int j = 0; j < warp_cols; ++j) {
277
            auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j);
278
            auto b_ptr = ((B_type *)B_local) + (j * kPack + kp) * vec_size;
279
            auto a_ptr = ((A_type *)A_local) +
280
                         (ki * warp_rows * kPack + i * kPack + kp) * vec_size;
281

282
283
            // Use the trait to select the correct MFMA instruction, either fp8,
            // fp16 or bf16 currently
284
            MfmaTraits<A_type>::mfma_op(b_ptr, a_ptr, acc_ptr);
285
286
287
288
289
290
291
          }
        }
      }
    }
  }
};

} // namespace tl

namespace tl {

296
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
297
298
          bool trans_B, bool clear_accum, int kPack, typename A_type,
          typename B_type, typename C_type>
299
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
300
301
302
  using Compute =
      GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
                   clear_accum, kPack, A_type, B_type, C_type>;
303
304
305
  Compute::body(pA, pB, accum);
}

306
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
307
308
          bool trans_B, bool clear_accum, int kPack, typename A_type,
          typename B_type, typename C_type>
309
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
310
311
312
  using Compute =
      GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
                   clear_accum, kPack, A_type, B_type, C_type>;
313
314
315
  Compute::body_rs(pA, pB, accum);
}

316
} // namespace tl