// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file layout/gemm_layouts.cc
 * \brief Define Layout used in MMA and other operations.
 *
 */

#include <tvm/tir/stmt_functor.h>

#include <cmath>

#include "layout.h"

namespace tvm {
namespace tl {

/*!
 * \brief Create a data-parallel IterVar over Range(0, dom).
 * \param name Name hint given to the freshly created Var.
 * \param dom Second argument of the Range whose first argument is 0.
 * \return An IterVar of type IterVarType::kDataPar.
 */
static IterVar make_itervar(std::string name, PrimExpr dom) {
  Range domain(0, dom);
  return IterVar(domain, Var(name), IterVarType::kDataPar);
}

/*!
 * \brief Base 8x8 fragment.
 *
 * Element (i, j) is assigned to thread floor(j / 2) + 4 * i and to local
 * index j mod 2. The replicate axis "rep" has extent 1.
 */
Fragment makeGemmFragment8x8() {
  auto row = make_itervar("i", 8);
  auto col = make_itervar("j", 8);
  auto replica = make_itervar("rep", 1);
  // Two consecutive columns share a thread; each row contributes 4 threads.
  PrimExpr thread_id = FloorDiv(col->var, 2) + 4 * row;
  PrimExpr local_id = FloorMod(col->var, 2);
  return Fragment({row, col}, {local_id}, thread_id, replica);
}
/*
From https://github.com/RadeonOpenCompute/amd_matrix_instruction_calculator
./matrix_calculator.py --architecture cdna1 --instruction v_mfma_f32_16x16x16f16
--detail-instruction
*/
/*!
 * \brief Base 16x16 fragment used for the A/B operands on the CDNA path.
 *
 * Element (i, j) is assigned to thread 16 * floor(j / 4) + i and to local
 * index j mod 4. The replicate axis "rep" has extent 1.
 */
Fragment makeGemmFragmentAB16x16CDNA() {
  auto row = make_itervar("i", 16);
  auto col = make_itervar("j", 16);
  auto replica = make_itervar("rep", 1);
  // Four consecutive columns stay in one thread; rows select the thread
  // within a group of 16.
  PrimExpr thread_id = 16 * FloorDiv(col->var, 4) + row;
  PrimExpr local_id = FloorMod(col->var, 4);
  return Fragment({row, col}, {local_id}, thread_id, replica);
}

/*!
 * \brief Transposed variant of the 16x16 A/B base fragment on the CDNA path.
 *
 * The roles of i and j are swapped relative to makeGemmFragmentAB16x16CDNA:
 * element (i, j) is assigned to thread 16 * floor(i / 4) + j and to local
 * index i mod 4. The replicate axis "rep" has extent 1.
 */
Fragment makeGemmFragmentAB16x16CDNATransposed() {
  auto row = make_itervar("i", 16);
  auto col = make_itervar("j", 16);
  auto replica = make_itervar("rep", 1);
  // Four consecutive rows stay in one thread; columns select the thread
  // within a group of 16.
  PrimExpr thread_id = 16 * FloorDiv(row->var, 4) + col;
  PrimExpr local_id = FloorMod(row->var, 4);
  return Fragment({row, col}, {local_id}, thread_id, replica);
}

/*!
 * \brief Base 16x16 fragment used for the C (accumulator) on the CDNA path.
 *
 * Element (i, j) is assigned to thread 16 * floor(j / 4) + i and to local
 * index j mod 4, i.e. the same mapping as makeGemmFragmentAB16x16CDNA.
 */
Fragment makeGemmFragmentC16x16CDNA() {
  auto row = make_itervar("i", 16);
  auto col = make_itervar("j", 16);
  auto replica = make_itervar("rep", 1);
  PrimExpr thread_id = 16 * FloorDiv(col->var, 4) + row;
  PrimExpr local_id = FloorMod(col->var, 4);
  return Fragment({row, col}, {local_id}, thread_id, replica);
}

/*!
 * \brief Transposed variant of the base 8x8 fragment.
 *
 * Element (i, j) is assigned to thread floor(i / 2) + 4 * j and to local
 * index i mod 2. The replicate axis "rep" has extent 1.
 */
Fragment makeGemmFragment8x8Transposed() {
  auto row = make_itervar("i", 8);
  auto col = make_itervar("j", 8);
  auto replica = make_itervar("rep", 1);
  // Two consecutive rows share a thread; each column contributes 4 threads.
  PrimExpr thread_id = FloorDiv(row->var, 2) + 4 * col;
  PrimExpr local_id = FloorMod(row->var, 2);
  return Fragment({row, col}, {local_id}, thread_id, replica);
}

/*!
 * \brief Base 8x16 fragment.
 *
 * Element (i, j) is assigned to thread floor(j / 4) + 4 * i and to local
 * index j mod 4. The replicate axis "rep" has extent 1.
 */
Fragment makeGemmFragment8x16() {
  auto row = make_itervar("i", 8);
  auto col = make_itervar("j", 16);
  auto replica = make_itervar("rep", 1);
  // Four consecutive columns share a thread; each row contributes 4 threads.
  PrimExpr thread_id = FloorDiv(col->var, 4) + 4 * row;
  PrimExpr local_id = FloorMod(col->var, 4);
  return Fragment({row, col}, {local_id}, thread_id, replica);
}

82
83
Fragment makeGemmFragmentC_F64(const int block_m, const int block_n,
                               const int warp_m, const int warp_n) {
84
85
86
87
88
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0);
  ICHECK(warp_n % 16 == 0);
  auto base_layout = makeGemmFragment8x8();
89
90
91
92
  auto warp_layout =
      base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
  auto block_layout =
      warp_layout->Repeat({warp_m / 8, warp_n / 8}, false, false);
93
94
95
  return block_layout;
}

96
97
Fragment makeGemmFragmentC(const int block_m, const int block_n,
                           const int warp_m, const int warp_n,
98
                           const int element_size) {
99
100
  if (element_size == 64)
    return makeGemmFragmentC_F64(block_m, block_n, warp_m, warp_n);
101
102
103
104
105
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
  ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
  auto base_layout = makeGemmFragment8x8()->Repeat({2, 1}, false);
106
107
108
109
  auto warp_layout =
      base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
  auto block_layout =
      warp_layout->Repeat({warp_m / 16, warp_n / 8}, false, false);
110
111
112
  return block_layout;
}

113
114
115
116
117
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size) {
  if (element_size == 64)
    LOG(FATAL) << "Not supported";
118
119
120
121
122
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
  ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
  auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
123
124
125
126
  auto warp_layout =
      base_layout->Repeat({warp_m / 16, warp_n / 16}, false, true);
  auto block_layout =
      warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
127
128
129
  return block_layout;
}

130
131
132
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
                                 const int warp_m, const int warp_n,
                                 const int element_size) {
133
134
135
  ICHECK(block_m % warp_m == 0);
  // ICHECK(block_n == warp_n);
  ICHECK(warp_m % 16 == 0);
136
137
138
139
  auto warp_layout = makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false,
                                                   false); // 16 x N (1 warp)
  auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n},
                                          true, false); // 16*Y x N (Y warp)
140
141
142
  return block_layout->Repeat({warp_m / 16, 1}, false, false);
}

143
144
145
Fragment makeGemmFragmentA(const int block_m, const int block_n,
                           const int block_k, const int warp_m,
                           const int warp_n, const int element_size) {
146
147
148
149
150
151
152
153
154
  // assume not transposed
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0);
  ICHECK(block_k % 16 == 0);
  // Only support 8-bit and 16-bit
  ICHECK(element_size == 8 || element_size == 16);
  if (element_size == 8) {
    auto base_layout = makeGemmFragment8x16()->Repeat({2, 2}, false, false);
155
156
157
158
    auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
                           ->Replicate(block_n / warp_n);
    auto block_layout =
        warp_layout->Repeat({warp_m / 16, block_k / 32}, false, false);
159
160
161
    return block_layout;
  } else if (element_size == 16) {
    auto base_layout = makeGemmFragment8x8()->Repeat({2, 2}, false, false);
162
163
164
165
    auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
                           ->Replicate(block_n / warp_n);
    auto block_layout =
        warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
166
167
168
169
170
171
172
    return block_layout;
  } else {
    ICHECK(0);
    return Fragment();
  }
}

173
174
175
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                               const int block_k, const int warp_m,
                               const int warp_n, bool transposed) {
176
177
178
179
180
181
  // assume not transposed
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0);
  ICHECK(block_k % 16 == 0);
  if (transposed) {
182
183
184
185
186
187
    auto base_layout =
        makeGemmFragmentAB16x16CDNATransposed()->Repeat({1, 1}, false, false);
    auto warp_layout =
        base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
    auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
                            ->Replicate(block_n / warp_n);
188
189
    return block_layout;
  } else {
190
191
192
193
194
195
    auto base_layout =
        makeGemmFragmentAB16x16CDNA()->Repeat({1, 1}, false, false);
    auto warp_layout =
        base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
    auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
                            ->Replicate(block_n / warp_n);
196
197
198
199
    return block_layout;
  }
}

200
201
202
Fragment makeGemmFragmentB(const int block_m, const int block_n,
                           const int block_k, const int warp_m,
                           const int warp_n) {
203
204
205
  // transposed
  ICHECK(warp_n % 8 == 0);
  ICHECK(block_k % 16 == 0);
206
207
208
209
210
211
  auto base_layout =
      makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false);
  auto warp_layout = base_layout->Replicate(block_m / warp_m)
                         ->Repeat({1, block_n / warp_n}, true);
  auto block_layout =
      warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true);
212
213
214
215
216
217
218
219
220
221
222
  return block_layout;
}

/*!
 * \brief Base 32x32 fragment, with a thread/index mapping that depends on the
 *        element width.
 *
 * Used by makeGemmVoltaFragmentC as its base layout. Both variants compute
 * the thread id and the per-thread index from the low bits of i and j; the
 * 32-bit variant additionally splits the lowest two bits of i and j between
 * thread and index.
 *
 * \param element_size Element width in bits; must be 16 or 32.
 * \return The 32x32 fragment with replicate extent 1.
 */
Fragment makeGemmFragment32x32(int element_size) {
  IterVar i = make_itervar("i", 32);
  IterVar j = make_itervar("j", 32);
  IterVar rep = make_itervar("rep", 1);
  ICHECK(element_size == 16 || element_size == 32);
  if (element_size == 16) {
    PrimExpr thd = FloorMod(i, 4) + FloorDiv(FloorMod(i, 16), 8) * 4 +
                   FloorDiv(FloorMod(j, 16), 8) * 8 + FloorDiv(i, 16) * 16;
    PrimExpr idx = FloorMod(j, 4) + FloorDiv(j, 16) * 4 +
                   FloorDiv(FloorMod(i, 8), 4) * 8 +
                   FloorDiv(FloorMod(j, 8), 4) * 16;
    return Fragment({i, j}, {idx}, thd, rep);
  } else {
    PrimExpr thd = FloorMod(i, 2) + 2 * FloorDiv(FloorMod(j, 4), 2) +
                   FloorDiv(FloorMod(i, 16), 8) * 4 +
                   FloorDiv(FloorMod(j, 16), 8) * 8 + FloorDiv(i, 16) * 16;
    PrimExpr idx = FloorMod(j, 2) + 2 * FloorDiv(FloorMod(i, 4), 2) +
                   FloorDiv(j, 16) * 4 + FloorDiv(FloorMod(i, 8), 4) * 8 +
                   FloorDiv(FloorMod(j, 8), 4) * 16;
    return Fragment({i, j}, {idx}, thd, rep);
  }
}

238
239
240
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
                                const int warp_m, const int warp_n,
                                int element_size) {
241
242
243
244
245
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 32 == 0);
  ICHECK(warp_n % 32 == 0);
  auto base_layout = makeGemmFragment32x32(element_size);
246
247
248
249
  auto warp_layout =
      base_layout->Repeat({warp_m / 32, warp_n / 32}, false, false);
  auto block_layout =
      warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true);
250
251
252
  return block_layout;
}

253
254
255
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
                                const int block_k, const int warp_m,
                                const int warp_n) {
256
257
258
259
260
261
262
263
264
  // assume not transposed
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 32 == 0);
  ICHECK(block_k % 4 == 0);
  // this is a special case
  IterVar i = make_itervar("i", 32);
  IterVar j = make_itervar("j", 4);
  IterVar rep = make_itervar("rep", 2);
265
266
  PrimExpr thd = FloorDiv(FloorMod(i, 16), 8) * 4 + 16 * FloorDiv(i, 16) +
                 FloorMod(i, 4) + 8 * rep;
267
268
  PrimExpr idx = j + FloorDiv(FloorMod(i, 8), 4) * 4;
  Fragment base_layout = Fragment({i, j}, {idx}, thd, rep);
269
270
271
272
  auto warp_layout =
      base_layout->Repeat({warp_m / 32, block_k / 4}, false, false);
  auto block_layout = warp_layout->Replicate(block_n / warp_n)
                          ->Repeat({block_m / warp_m, 1}, true);
273
274
275
  return block_layout;
}

276
277
278
PrimExpr xor2x2(const PrimExpr &i, const PrimExpr &j) {
  return FloorMod(i + j, 2);
}
279

280
PrimExpr xor4x4(const PrimExpr &i, const PrimExpr &j) {
281
282
283
284
285
286
287
  PrimExpr i0 = FloorMod(i, 2);
  PrimExpr j0 = FloorMod(j, 2);
  PrimExpr i1 = FloorDiv(i, 2);
  PrimExpr j1 = FloorDiv(j, 2);
  return 2 * xor2x2(i1, j1) + xor2x2(i0, j0);
}

288
PrimExpr xor8x8(const PrimExpr &i, const PrimExpr j) {
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
  PrimExpr i0 = FloorMod(i, 2);
  PrimExpr j0 = FloorMod(j, 2);
  PrimExpr i1 = FloorDiv(i, 2);
  PrimExpr j1 = FloorDiv(j, 2);
  return 2 * xor4x4(i1, j1) + xor2x2(i0, j0);
}

/*!
 * \brief Swizzled layout using a 2-bit XOR over groups of 4 vectors.
 *
 * A "vector" is 128 / element_size elements. Within each 8-row by 4-vector
 * tile the vector column is XOR-ed (xor4x4) with floor(row_in_tile / 2).
 *
 * \param stride Extent of the first input dimension; must be a multiple of 8.
 * \param continuous Extent of the second input dimension; must be a multiple
 *        of 4 vectors.
 * \param element_size Element width in bits.
 * \return A Layout mapping (i, j) to {column tile, row tile, offset in tile}.
 */
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size) {
  // Swizzle 2 bit
  int elems_per_vec = 128 / element_size;
  ICHECK(stride % 8 == 0);
  ICHECK(continuous % (elems_per_vec * 4) == 0);
  Var row = InputPlaceholder(0);
  Var col = InputPlaceholder(1);
  // Row coordinates: which 8-row tile, and the row inside it.
  PrimExpr row_tile = FloorDiv(row, 8);
  PrimExpr row_in_tile = FloorMod(row, 8);
  // Column coordinates: which 4-vector tile, the vector inside it, and the
  // element inside the vector.
  PrimExpr col_tile = FloorDiv(FloorDiv(col, elems_per_vec), 4);
  PrimExpr vec_in_tile = FloorMod(FloorDiv(col, elems_per_vec), 4);
  PrimExpr elem_in_vec = FloorMod(col, elems_per_vec);
  PrimExpr swizzled_vec = xor4x4(vec_in_tile, FloorDiv(row_in_tile, 2));
  PrimExpr offset =
      elem_in_vec + (swizzled_vec + row_in_tile * 4) * elems_per_vec;
  return Layout(Array<PrimExpr>{stride, continuous},
                {col_tile, row_tile, offset});
}

/*!
 * \brief Swizzled layout using a 3-bit XOR over groups of 8 vectors.
 *
 * A "vector" is 128 / element_size elements. Within each 8-row by 8-vector
 * tile the vector column is XOR-ed (xor8x8) with the row inside the tile.
 *
 * \param stride Extent of the first input dimension; must be a multiple of 8.
 * \param continuous Extent of the second input dimension; must be a multiple
 *        of 8 vectors.
 * \param element_size Element width in bits.
 * \return A Layout mapping (i, j) to {column tile, row tile, offset in tile}.
 */
Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size) {
  // Swizzle 3 bit
  int elems_per_vec = 128 / element_size;
  ICHECK(stride % 8 == 0);
  ICHECK(continuous % (elems_per_vec * 8) == 0);
  Var row = InputPlaceholder(0);
  Var col = InputPlaceholder(1);
  // Row coordinates: which 8-row tile, and the row inside it.
  PrimExpr row_tile = FloorDiv(row, 8);
  PrimExpr row_in_tile = FloorMod(row, 8);
  // Column coordinates: which 8-vector tile, the vector inside it, and the
  // element inside the vector.
  PrimExpr col_tile = FloorDiv(FloorDiv(col, elems_per_vec), 8);
  PrimExpr vec_in_tile = FloorMod(FloorDiv(col, elems_per_vec), 8);
  PrimExpr elem_in_vec = FloorMod(col, elems_per_vec);
  PrimExpr swizzled_vec = xor8x8(vec_in_tile, row_in_tile);
  PrimExpr offset =
      elem_in_vec + (swizzled_vec + row_in_tile * 8) * elems_per_vec;
  return Layout(Array<PrimExpr>{stride, continuous},
                {col_tile, row_tile, offset});
}

330
331
332
333
// Detail implementation please ref to
// bitblas::tl::mfma_layout::make_mfma_swizzle_layout
Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size,
                                   int kPack = 1) {
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
  const int numBanks = 32;
  const int bankBitWidth = 32;
  const int SIMDWidth = 16;
  const int vecSize = 4 * kPack;
  const int innerDimLength = continuous;
  const int typeWidthInBit = element_size;

  const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
  const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
  const int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);

  IterVar row = make_itervar("row", stride);
  IterVar col = make_itervar("col", continuous);
  PrimExpr phase = FloorMod(row / perPhase, maxPhase);
  PrimExpr colOffSwizzled = ((col / vecSize) ^ phase) * vecSize;
  PrimExpr colOffOrdered = FloorMod(col, vecSize);
  PrimExpr colOff = colOffSwizzled + colOffOrdered;

  return Layout(Array{row, col}, {row, colOff});
}

/*!
 * \brief Swizzled A/B layout for 64-bit elements, "K inner" variant.
 *
 * Works on 4-row by 16-column tiles. Inside a tile the low two bits of the
 * column are XOR-ed (xor4x4) with the row inside the tile; the upper column
 * bits are kept.
 *
 * \param stride Extent of the first input dimension.
 * \param continuous Extent of the second input dimension.
 * \return A Layout mapping (i, j) to {column tile, row tile, offset in tile}.
 */
Layout makeGemmABLayoutF64_Kinner(int stride, int continuous) {
  // Swizzle<2, 0, 4>
  Var row = InputPlaceholder(0);
  Var col = InputPlaceholder(1);
  PrimExpr col_tile = FloorDiv(col, 16);
  PrimExpr row_tile = FloorDiv(row, 4);
  PrimExpr col_in_tile = FloorMod(col, 16);
  PrimExpr row_in_tile = FloorMod(row, 4);
  // Keep the group of four, permute the position inside the group.
  PrimExpr swizzled_col = FloorDiv(col_in_tile, 4) * 4 +
                          xor4x4(FloorMod(col_in_tile, 4), row_in_tile);
  PrimExpr offset = swizzled_col + row_in_tile * 16;
  return Layout(Array<PrimExpr>{stride, continuous},
                {col_tile, row_tile, offset});
}

/*!
 * \brief Swizzled A/B layout for 64-bit elements, "K outer" variant.
 *
 * Works on 4-row by 16-column tiles. Inside a tile the upper two bits of the
 * column (the group-of-four index) are XOR-ed (xor4x4) with the row inside
 * the tile; the low two bits are kept.
 *
 * \param stride Extent of the first input dimension.
 * \param continuous Extent of the second input dimension.
 * \return A Layout mapping (i, j) to {column tile, row tile, offset in tile}.
 */
Layout makeGemmABLayoutF64_Kouter(int stride, int continuous) {
  // Swizzle<2, 2, 2>
  Var row = InputPlaceholder(0);
  Var col = InputPlaceholder(1);
  PrimExpr col_tile = FloorDiv(col, 16);
  PrimExpr row_tile = FloorDiv(row, 4);
  PrimExpr col_in_tile = FloorMod(col, 16);
  PrimExpr row_in_tile = FloorMod(row, 4);
  // Keep the position inside the group of four, permute the group itself.
  PrimExpr swizzled_col = FloorMod(col_in_tile, 4) +
                          xor4x4(FloorDiv(col_in_tile, 4), row_in_tile) * 4;
  PrimExpr offset = swizzled_col + row_in_tile * 16;
  return Layout(Array<PrimExpr>{stride, continuous},
                {col_tile, row_tile, offset});
}

// The Default Layout for Tensor Access
/*!
 * \brief Row-major linear layout: (i, j) maps to i * continuous + j.
 * \param stride Extent of the first dimension.
 * \param continuous Extent of the second dimension and the row pitch.
 */
Layout makeGemmLayoutLinear(int stride, int continuous) {
  auto row = make_itervar("i", stride);
  auto col = make_itervar("j", continuous);
  return Layout(Array{row, col}, {row * continuous + col});
}

/*!
 * \brief Row-major layout whose row pitch may be padded by 128 bits.
 *
 * (i, j) maps to i * padded + j, where padded equals continuous plus
 * 128 / element_size extra elements when a row is a multiple of 256 bits,
 * and equals continuous otherwise.
 *
 * \param stride Extent of the first dimension.
 * \param continuous Extent of the second dimension.
 * \param element_size Element width in bits.
 */
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size) {
  IterVar i = make_itervar("i", stride);
  IterVar j = make_itervar("j", continuous);
  int padded = continuous;
  // Add 128 bits padding when the last dim is a multiple of 256 bits
  if ((element_size * continuous) % 256 == 0)
    padded += 128 / element_size;
  return Layout(Array{i, j}, {i * padded + j});
}

/*!
 * \brief "Crosswise" A/B layout for the Volta path.
 *
 * Columns are grouped into vectors of 4 elements. Two permutation bits
 * (bit2, bit1) are derived from the row and from the vector index modulo 8,
 * and combined with row bits into a permuted vector position. The final
 * offset is element-in-vector + permuted position * 4 + vector index *
 * stride * 4.
 *
 * \param stride Extent of the first dimension; must be a multiple of 32.
 * \param continuous Extent of the second dimension; must be a multiple of 32.
 * \return A Layout mapping (i, j) to a single linear offset.
 */
Layout MakeGemmVoltaABLayoutCrosswise(int stride, int continuous) {
  ICHECK(stride % 32 == 0 && continuous % 32 == 0);
  IterVar i = make_itervar("i", stride);
  IterVar j = make_itervar("j", continuous);
  PrimExpr vec_contiguous_idx = FloorDiv(j, 4);
  PrimExpr vec_strided_within_tile = FloorMod(vec_contiguous_idx, 8);

  // Sum of three 1-bit values reduced modulo 2.
  PrimExpr bit2 =
      FloorMod(FloorDiv(FloorMod(i, 32), 16) + FloorDiv(FloorMod(i, 16), 8) +
                   FloorDiv(vec_strided_within_tile, 4),
               2);
  PrimExpr bit1 = xor2x2(FloorDiv(FloorMod(i, 8), 4),
                         FloorDiv(FloorMod(vec_strided_within_tile, 4), 2));
  PrimExpr permuted_vec_contiguous =
      FloorDiv(i, 16) * 16 + FloorMod(i, 4) * 4 + bit2 * 2 + bit1;

  PrimExpr offset = FloorMod(j, 4) + permuted_vec_contiguous * 4 +
                    vec_contiguous_idx * stride * 4;
  return Layout(Array{i, j}, {offset});
}

/*!
 * \brief "Congruous" layout for the A operand on the Volta path.
 *
 * Columns are grouped into vectors of 8 elements and the matrix is split
 * into tiles of 4 rows by 8 vectors. Inside a tile the vector residual is
 * split as (residual / 2, residual % 2) and the row residual is XOR-ed
 * (xor4x4) with residual / 2 to form the permuted vector position.
 *
 * \param stride Extent of the first dimension; must be a multiple of 4.
 * \param continuous Extent of the second dimension; must be a multiple of 64.
 * \return A Layout mapping (i, j) to a single linear offset.
 */
Layout MakeGemmVoltaALayoutCongruous(int stride, int continuous) {
  ICHECK(stride % 4 == 0 && continuous % 64 == 0);
  IterVar i = make_itervar("i", stride);
  IterVar j = make_itervar("j", continuous);
  PrimExpr vec_contiguous_idx = FloorDiv(j, 8);
  PrimExpr vec_strided_idx = i;
  PrimExpr tile_contiguous_idx = FloorDiv(vec_contiguous_idx, 8);
  PrimExpr tile_strided_idx = FloorDiv(vec_strided_idx, 4);
  PrimExpr tile_contiguous_residual = FloorMod(vec_contiguous_idx, 8);
  PrimExpr tile_strided_residual = FloorMod(vec_strided_idx, 4);

  PrimExpr permuted_strided_within_tile = FloorDiv(tile_contiguous_residual, 2);
  PrimExpr permuted_contiguous_within_tile =
      FloorMod(tile_contiguous_residual, 2) * 4 +
      xor4x4(tile_strided_residual, permuted_strided_within_tile);

  PrimExpr element_strided =
      permuted_strided_within_tile + tile_strided_idx * 4;
  PrimExpr element_contiguous =
      FloorMod(j, 8) +
      (permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8;
  PrimExpr offset = element_strided * continuous + element_contiguous;
  return Layout(Array{i, j}, {offset});
}

/*!
 * \brief "Congruous" layout for the B operand on the Volta path.
 *
 * Same tiling as MakeGemmVoltaALayoutCongruous (vectors of 8 elements, tiles
 * of 4 rows by 8 vectors), but the vector residual is split as
 * (residual % 4, residual / 4) before the row residual is XOR-ed (xor4x4)
 * with residual % 4.
 *
 * \param stride Extent of the first dimension; must be a multiple of 4.
 * \param continuous Extent of the second dimension; must be a multiple of 64.
 * \return A Layout mapping (i, j) to a single linear offset.
 */
Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
  ICHECK(stride % 4 == 0 && continuous % 64 == 0);
  IterVar i = make_itervar("i", stride);
  IterVar j = make_itervar("j", continuous);
  PrimExpr vec_contiguous_idx = FloorDiv(j, 8);
  PrimExpr vec_strided_idx = i;
  PrimExpr tile_contiguous_idx = FloorDiv(vec_contiguous_idx, 8);
  PrimExpr tile_strided_idx = FloorDiv(vec_strided_idx, 4);
  PrimExpr tile_contiguous_residual = FloorMod(vec_contiguous_idx, 8);
  PrimExpr tile_strided_residual = FloorMod(vec_strided_idx, 4);

  PrimExpr permuted_strided_within_tile = FloorMod(tile_contiguous_residual, 4);
  PrimExpr permuted_contiguous_within_tile =
      FloorDiv(tile_contiguous_residual, 4) * 4 +
      xor4x4(tile_strided_residual, permuted_strided_within_tile);

  PrimExpr element_strided =
      permuted_strided_within_tile + tile_strided_idx * 4;
  PrimExpr element_contiguous =
      FloorMod(j, 8) +
      (permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8;
  PrimExpr offset = element_strided * continuous + element_contiguous;
  return Layout(Array{i, j}, {offset});
}

469
470
471
472
473
474
475
476
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
                             int kfactor) {
  if (kfactor == 2)
    return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
  if (is_a && continuous % 64 == 0)
    return MakeGemmVoltaALayoutCongruous(stride, continuous);
  if (!is_a && continuous % 64 == 0)
    return MakeGemmVoltaBLayoutCongruous(stride, continuous);
477
478
479
  return makeGemmABLayoutPadded(stride, continuous, 16);
}

480
481
Layout makeGemmABLayout(int stride, int continuous, int element_size,
                        int kfactor) {
482
  if (element_size == 64) {
483
    if (kfactor == 1 && continuous % 16 == 0) // float64 KxN
484
      return makeGemmABLayoutF64_Kouter(stride, continuous);
485
    if (kfactor == 2 && continuous % 16 == 0) // float64 NxK
486
487
488
489
      return makeGemmABLayoutF64_Kinner(stride, continuous);
    return makeGemmABLayoutPadded(stride, continuous, element_size);
  }
  int vector_size = 128 / element_size;
490
  if (kfactor == 1 && element_size == 8) // int8 KxN
491
492
493
494
495
496
497
498
499
500
    return makeGemmABLayoutPadded(stride, continuous, element_size);
  else if (continuous % (vector_size * 8) == 0)
    return makeFullBankSwizzleLayout(stride, continuous, element_size);
  else if (continuous % (vector_size * 4) == 0)
    return makeHalfBankSwizzleLayout(stride, continuous, element_size);
  else {
    return makeGemmABLayoutPadded(stride, continuous, element_size);
  }
}

501
502
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
                            int kPack) {
503
504
505
506
507
508
509
  int vector_size = 128 / element_size;
  if (continuous % (vector_size * 4) == 0)
    return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack);
  else {
    return makeGemmABLayoutPadded(stride, continuous, element_size);
  }
}
510
511
} // namespace tl
} // namespace tvm