gemm_layouts.cc 33.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/*!
 * \file layout/gemm_layouts.cc
 * \brief Defines the layouts used by MMA (matrix multiply-accumulate) and
 *        related operations.
 */

#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <cmath>
#include <string>

#include "layout.h"

namespace tvm {
namespace tl {

16
// Builds a data-parallel IterVar named `name` whose domain is [0, dom).
// The underlying Var takes the dtype of `dom`.
IterVar make_itervar(std::string name, PrimExpr dom) {
  Var loop_var(name, dom->dtype);
  Range domain(0, dom);
  return IterVar(domain, loop_var, IterVarType::kDataPar);
}

21
22
23
24
25
26
27
28
29
// Thread/register mapping for an 8x4 tile over 32 threads.
// thread = 4 * row + col; every thread holds a single element (slot 0).
Fragment makeGemmFragment8x4() {
  IterVar row = make_itervar("i", 8);
  IterVar col = make_itervar("j", 4);
  IterVar replica = make_itervar("rep", 1);
  Var col_var = col->var;
  PrimExpr owner_thread = FloorDiv(col_var, 1) + 4 * row;
  PrimExpr slot = FloorMod(col_var, 1);
  return Fragment({row, col}, {slot}, owner_thread, replica);
}

30
31
32
33
34
35
36
37
// Thread/register mapping for an 8x8 tile over 32 threads.
// thread = 4 * row + col / 2; each thread holds 2 consecutive columns
// (slot = col % 2).
Fragment makeGemmFragment8x8() {
  IterVar row = make_itervar("i", 8);
  IterVar col = make_itervar("j", 8);
  IterVar replica = make_itervar("rep", 1);
  Var col_var = col->var;
  PrimExpr owner_thread = FloorDiv(col_var, 2) + 4 * row;
  PrimExpr slot = FloorMod(col_var, 2);
  return Fragment({row, col}, {slot}, owner_thread, replica);
}
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

// Thread/register mapping for an 8x16 tile over 32 threads.
// thread = 4 * row + col / 4; each thread holds 4 consecutive columns
// (slot = col % 4).
Fragment makeGemmFragment8x16() {
  IterVar row = make_itervar("i", 8);
  IterVar col = make_itervar("j", 16);
  IterVar replica = make_itervar("rep", 1);
  Var col_var = col->var;
  PrimExpr owner_thread = FloorDiv(col_var, 4) + 4 * row;
  PrimExpr slot = FloorMod(col_var, 4);
  return Fragment({row, col}, {slot}, owner_thread, replica);
}

// Transposed counterpart of makeGemmFragment8x8. Roles of row and column
// are swapped: thread = 4 * col + row / 2, slot = row % 2 (32 threads).
Fragment makeGemmFragment8x8Transposed() {
  IterVar row = make_itervar("i", 8);
  IterVar col = make_itervar("j", 8);
  IterVar replica = make_itervar("rep", 1);
  Var row_var = row->var;
  PrimExpr owner_thread = FloorDiv(row_var, 2) + 4 * col;
  PrimExpr slot = FloorMod(row_var, 2);
  return Fragment({row, col}, {slot}, owner_thread, replica);
}

57
58
59
60
61
/*
From https://github.com/RadeonOpenCompute/amd_matrix_instruction_calculator
./matrix_calculator.py --architecture cdna1 --instruction v_mfma_f32_16x16x16f16
--detail-instruction
*/
62
Fragment makeGemmFragmentAB16x16CDNA(const int k_pack) {
63
  IterVar i = make_itervar("i", 16);
64
65
66
67
68
69
70
71
72
  IterVar j = make_itervar("j", 16 * k_pack);
  IterVar rep = make_itervar("rep", 1);
  PrimExpr forward_thread = 16 * FloorDiv(j->var, 4 * k_pack) + i;
  PrimExpr index = FloorMod(j->var, 4 * k_pack);
  return Fragment({i, j}, {index}, forward_thread, rep);
}

// Transposed form of makeGemmFragmentAB16x16CDNA: a (16 * k_pack) x 16 tile.
// thread = 16 * (row / (4 * k_pack)) + col, slot = row % (4 * k_pack).
Fragment makeGemmFragmentAB16x16CDNATransposed(const int k_pack) {
  const int rows_per_thread = 4 * k_pack;
  IterVar row = make_itervar("i", 16 * k_pack);
  IterVar col = make_itervar("j", 16);
  IterVar replica = make_itervar("rep", 1);
  Var row_var = row->var;
  PrimExpr lane = 16 * FloorDiv(row_var, rows_per_thread) + col;
  PrimExpr slot = FloorMod(row_var, rows_per_thread);
  return Fragment({row, col}, {slot}, lane, replica);
}

80
// A/B operand fragment for a 16 x (32 * k_pack) tile on CDNA.
// thread = 16 * (col / (8 * k_pack)) + row, slot = col % (8 * k_pack).
// The result is 64 threads, each holding 8 * k_pack consecutive columns of
// one row.
Fragment makeGemmFragmentAB16x32CDNA(const int k_pack) {
  const int cols_per_thread = 8 * k_pack;
  IterVar row = make_itervar("i", 16);
  IterVar col = make_itervar("j", 32 * k_pack);
  IterVar replica = make_itervar("rep", 1);
  Var col_var = col->var;
  PrimExpr lane = 16 * FloorDiv(col_var, cols_per_thread) + row;
  PrimExpr slot = FloorMod(col_var, cols_per_thread);
  return Fragment({row, col}, {slot}, lane, replica);
}

// Transposed form of makeGemmFragmentAB16x32CDNA: a (32 * k_pack) x 16 tile.
// thread = 16 * (row / (8 * k_pack)) + col, slot = row % (8 * k_pack).
Fragment makeGemmFragmentAB16x32CDNATransposed(const int k_pack) {
  const int rows_per_thread = 8 * k_pack;
  IterVar row = make_itervar("i", 32 * k_pack);
  IterVar col = make_itervar("j", 16);
  IterVar replica = make_itervar("rep", 1);
  Var row_var = row->var;
  PrimExpr lane = 16 * FloorDiv(row_var, rows_per_thread) + col;
  PrimExpr slot = FloorMod(row_var, rows_per_thread);
  return Fragment({row, col}, {slot}, lane, replica);
}

// Accumulator fragment for a 16x16 tile on CDNA.
// thread = 16 * (col / 4) + row, slot = col % 4. The result is 64 threads,
// each holding 4 consecutive columns of one row.
Fragment makeGemmFragmentC16x16CDNA() {
  IterVar row = make_itervar("i", 16);
  IterVar col = make_itervar("j", 16);
  IterVar replica = make_itervar("rep", 1);
  Var col_var = col->var;
  PrimExpr lane = 16 * FloorDiv(col_var, 4) + row;
  PrimExpr slot = FloorMod(col_var, 4);
  return Fragment({row, col}, {slot}, lane, replica);
}

107
108
// Accumulator (C) fragment for 64-bit elements.
// The 8x8 base tile is built in two steps:
//   1. Repeat (block_m / warp_m) x (block_n / warp_n) times with the first
//      Repeat flag set to true (presumably spreading it over warps).
//   2. Repeat (warp_m / 8) x (warp_n / 8) times with both flags false.
Fragment makeGemmFragmentC_F64(const int block_m, const int block_n,
                               const int warp_m, const int warp_n) {
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0);
  ICHECK(warp_n % 8 == 0);
  const int warps_along_m = block_m / warp_m;
  const int warps_along_n = block_n / warp_n;
  auto tile = makeGemmFragment8x8();
  auto spread = tile->Repeat({warps_along_m, warps_along_n}, true, false);
  return spread->Repeat({warp_m / 8, warp_n / 8}, false, false);
}

121
122
// Accumulator (C) fragment.
// 64-bit elements are delegated to makeGemmFragmentC_F64. Otherwise:
//   1. A 16x8 tile is formed by stacking the 8x8 base twice along M.
//   2. It is repeated over (block_m / warp_m) x (block_n / warp_n) with the
//      first Repeat flag true.
//   3. It is then repeated (warp_m / 16) x (warp_n / 8) times with both
//      flags false.
Fragment makeGemmFragmentC(const int block_m, const int block_n,
                           const int warp_m, const int warp_n,
                           const int element_size) {
  if (element_size == 64) {
    return makeGemmFragmentC_F64(block_m, block_n, warp_m, warp_n);
  }
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
  ICHECK(warp_n % 8 == 0) << "warp_n=" << warp_n;
  auto tile16x8 = makeGemmFragment8x8()->Repeat({2, 1}, false);
  auto spread =
      tile16x8->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
  return spread->Repeat({warp_m / 16, warp_n / 8}, false, false);
}

138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
// Accumulator (C) fragment for the sparse GEMM path.
// It uses the same 16x8 tile as makeGemmFragmentC but applies the two
// repeats in the opposite order:
//   1. In-warp repeat (warp_m / 16) x (warp_n / 8), both flags false.
//   2. Block-level repeat (block_m / warp_m) x (block_n / warp_n), first
//      flag true.
// 64-bit elements are rejected.
Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
                                 const int warp_m, const int warp_n,
                                 const int element_size) {
  if (element_size == 64) {
    ICHECK(false) << "Not supported";
  }
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
  ICHECK(warp_n % 8 == 0) << "warp_n=" << warp_n;
  auto tile16x8 = makeGemmFragment8x8()->Repeat({2, 1}, false);
  // NOTE: This func wasn't implemented by following the CUTLASS 2 iterator
  // but by inspecting the output, it appears that we first need to
  // repeat the warp layout while avoiding duplicate thread mappings.
  auto in_warp = tile16x8->Repeat({warp_m / 16, warp_n / 8}, false, false);
  return in_warp->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
}

Lukinon's avatar
Lukinon committed
159
// Accumulator (C) fragment built from the 16x16 CDNA accumulator tile
// (judging by the name, intended for DCU targets — confirm).
// Steps:
//   1. In-warp repeat (warp_m / 16) x (warp_n / 16), flags (false, false).
//   2. Block-level repeat (block_m / warp_m) x (block_n / warp_n), flags
//      (true, false).
// 64-bit elements are rejected.
Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
                              const int warp_m, const int warp_n,
                              const int element_size) {
  if (element_size == 64) {
    LOG(FATAL) << "Not supported";
  }
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
  ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
  auto tile = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
  auto in_warp = tile->Repeat({warp_m / 16, warp_n / 16}, false, false);
  return in_warp->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
}

176
177
178
179
180
// Accumulator (C) fragment for CDNA, built from the 16x16 CDNA accumulator
// tile. It differs from makeGemmFragmentCDCU only in the last flag of the
// in-warp Repeat, which is true here.
// 64-bit elements are rejected.
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size) {
  if (element_size == 64) {
    LOG(FATAL) << "Not supported";
  }
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
  ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
  auto tile = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
  auto in_warp = tile->Repeat({warp_m / 16, warp_n / 16}, false, true);
  return in_warp->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
}

193
194
195
// Accumulator (C) fragment for Hopper.
//
// Construction:
//   1. One warp owns a 16 x warp_n tile: the 8x8 base repeated 2x along M
//      and warp_n / 8 times along N.
//   2. That tile is repeated (block_m / warp_m) x (block_n / warp_n) times
//      with the first Repeat flag set to true.
//   3. It is then repeated warp_m / 16 times along M with both flags false.
//
// The integer divisions above truncate silently. Divisibility is therefore
// checked up front, matching makeGemmFragmentC.
//
// `element_size` is not used by this builder. It is kept for signature
// parity with the other C-fragment builders.
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
                                 const int warp_m, const int warp_n,
                                 const int element_size) {
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0)
      << "block_n=" << block_n << ", warp_n=" << warp_n;
  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
  // warp_n is tiled by 8-wide base fragments; a remainder would be dropped.
  ICHECK(warp_n % 8 == 0) << "warp_n=" << warp_n;

  auto warp_layout = makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false,
                                                   false); // 16 x N (1 warp)
  auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n},
                                          true, false); // 16*Y x N (Y warp)
  return block_layout->Repeat({warp_m / 16, 1}, false, false);
}

206
207
// A operand fragment.
//
// Tile orientation:
//   - transposed == true: the tile is indexed (K, M).
//   - transposed == false: the tile is indexed (M, K).
//
// In the non-transposed case the base tile depends on the element width:
//   - 8-bit:  8x16 base -> 16x32 tile (K step 32)
//   - 16-bit: 8x8 base  -> 16x16 tile (K step 16)
//   - 32-bit: 8x4 base  -> 16x8 tile  (K step 8)
//
// The fragment is replicated block_n / warp_n times.
Fragment makeGemmFragmentA(const int block_m, const int block_n,
                           const int block_k, const int warp_m,
                           const int warp_n, const int element_size,
                           bool transposed) {
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0);
  ICHECK(block_k % 16 == 0);
  // Supported operand widths: 8, 16 and 32 bits.
  ICHECK(element_size == 8 || element_size == 16 || element_size == 32)
      << "unsupported element bitwidth=" << element_size;

  if (transposed) {
    // NOTE(review): this path uses the 8x8 (16-bit) base layout regardless
    // of element_size — confirm this is intended for 8/32-bit operands.
    auto base_layout =
        makeGemmFragment8x8Transposed()->Repeat({2, 2}, false, true);
    auto warp_layout = base_layout->Repeat({1, block_m / warp_m}, true, false)
                           ->Replicate(block_n / warp_n);
    auto block_layout =
        warp_layout->Repeat({block_k / 16, warp_m / 16}, false, true);
    return block_layout;
  } else {
    if (element_size == 8) {
      // One 16x32 tile covers 32 K-columns. If block_k is not a multiple of
      // 32, block_k / 32 truncates (to 0 for block_k == 16) and the trailing
      // K columns get no mapping.
      ICHECK(block_k % 32 == 0)
          << "block_k=" << block_k
          << " must be a multiple of 32 for 8-bit A operands";
      auto base_layout = makeGemmFragment8x16()->Repeat({2, 2}, false, false);
      auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
                             ->Replicate(block_n / warp_n);
      auto block_layout =
          warp_layout->Repeat({warp_m / 16, block_k / 32}, false, false);
      return block_layout;
    } else if (element_size == 16) {
      auto base_layout = makeGemmFragment8x8()->Repeat({2, 2}, false, false);
      auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
                             ->Replicate(block_n / warp_n);
      auto block_layout =
          warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
      return block_layout;
    } else if (element_size == 32) {
      auto base_layout = makeGemmFragment8x4()->Repeat({2, 2}, false, false);
      auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
                             ->Replicate(block_n / warp_n);
      auto block_layout =
          warp_layout->Repeat({warp_m / 16, block_k / 8}, false, false);
      return block_layout;
    } else {
      // Unreachable: element_size was validated above.
      ICHECK(0);
      return Fragment();
    }
  }
}

// B operand fragment.
// Tile orientation:
//   - transposed == false: indexed (K, N), built from the transposed 8x8
//     base.
//   - transposed == true: indexed (N, K), built from the plain 8x8 base.
// In both cases the fragment is replicated block_m / warp_m times.
Fragment makeGemmFragmentB(const int block_m, const int block_n,
                           const int block_k, const int warp_m,
                           const int warp_n, bool transposed) {
  ICHECK(warp_n % 8 == 0);
  ICHECK(block_k % 16 == 0);
  if (!transposed) {
    auto tile =
        makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false);
    auto replicated = tile->Replicate(block_m / warp_m);
    auto spread = replicated->Repeat({1, block_n / warp_n}, true);
    return spread->Repeat({block_k / 16, warp_n / 8}, false, true);
  }
  auto tile = makeGemmFragment8x8()->Repeat({1, 2}, false, false);
  auto replicated = tile->Replicate(block_m / warp_m);
  auto spread = replicated->Repeat({block_n / warp_n, 1}, true, false);
  return spread->Repeat({warp_n / 8, block_k / 16}, false, false);
}

280
281
// A operand fragment for CDNA matrix cores.
//
// Only 8- and 16-bit elements are supported. One MFMA step covers
// mfma_k = k_pack * (16 for 16-bit, 32 for 8-bit) K-columns, and block_k
// must be a multiple of it.
//
// Tile orientation:
//   - transposed == true: the tile is indexed (K, M).
//   - transposed == false: the tile is indexed (M, K).
//
// In both cases the fragment is replicated block_n / warp_n times.
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                               const int block_k, const int warp_m,
                               const int warp_n, const int element_size,
                               const int k_pack, bool transposed) {
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0);
  // Validate the inputs mfma_k is derived from before using them. An
  // unsupported width then reports itself instead of surfacing as a
  // confusing block_k failure, and k_pack <= 0 cannot reach `% mfma_k`.
  ICHECK(element_size == 8 || element_size == 16)
      << "element bitwidth=" << element_size;
  ICHECK(k_pack > 0) << "k_pack=" << k_pack;
  const int mfma_k = k_pack * (element_size == 16 ? 16 : 32);
  ICHECK(block_k % mfma_k == 0)
      << "block_k=" << block_k << ", mfma_k=" << mfma_k;
  if (transposed) {
    auto base_layout =
        element_size == 16
            ? makeGemmFragmentAB16x16CDNATransposed(k_pack)->Repeat(
                  {1, 1}, false, false)
            : makeGemmFragmentAB16x32CDNATransposed(k_pack)->Repeat(
                  {1, 1}, false, false);
    auto warp_layout =
        base_layout->Repeat({block_k / mfma_k, warp_m / 16}, false, true);
    auto block_layout = warp_layout->Repeat({1, block_m / warp_m}, true, true)
                            ->Replicate(block_n / warp_n);
    return block_layout;
  } else {
    auto base_layout =
        element_size == 16
            ? makeGemmFragmentAB16x16CDNA(k_pack)->Repeat({1, 1}, false, false)
            : makeGemmFragmentAB16x32CDNA(k_pack)->Repeat({1, 1}, false, false);
    auto warp_layout =
        base_layout->Repeat({warp_m / 16, block_k / mfma_k}, false, false);
    auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
                            ->Replicate(block_n / warp_n);
    return block_layout;
  }
}

// 32x32 accumulator tile (used by the Volta C fragment).
// The thread and register-slot expressions are chosen by element width;
// only 16- and 32-bit widths are accepted.
Fragment makeGemmFragment32x32(int element_size) {
  IterVar row = make_itervar("i", 32);
  IterVar col = make_itervar("j", 32);
  IterVar replica = make_itervar("rep", 1);
  ICHECK(element_size == 16 || element_size == 32);
  PrimExpr owner_thread;
  PrimExpr slot;
  if (element_size == 32) {
    owner_thread = FloorMod(row, 2) + 2 * FloorDiv(FloorMod(col, 4), 2) +
                   FloorDiv(FloorMod(row, 16), 8) * 4 +
                   FloorDiv(FloorMod(col, 16), 8) * 8 +
                   FloorDiv(row, 16) * 16;
    slot = FloorMod(col, 2) + 2 * FloorDiv(FloorMod(row, 4), 2) +
           FloorDiv(col, 16) * 4 + FloorDiv(FloorMod(row, 8), 4) * 8 +
           FloorDiv(FloorMod(col, 8), 4) * 16;
  } else {
    owner_thread = FloorMod(row, 4) + FloorDiv(FloorMod(row, 16), 8) * 4 +
                   FloorDiv(FloorMod(col, 16), 8) * 8 +
                   FloorDiv(row, 16) * 16;
    slot = FloorMod(col, 4) + FloorDiv(col, 16) * 4 +
           FloorDiv(FloorMod(row, 8), 4) * 8 +
           FloorDiv(FloorMod(col, 8), 4) * 16;
  }
  return Fragment({row, col}, {slot}, owner_thread, replica);
}

340
341
342
// Volta accumulator (C) fragment.
// The 32x32 tile is repeated (warp_m / 32) x (warp_n / 32) times, then
// (block_m / warp_m) x (block_n / warp_n) times with the first Repeat flag
// set to true.
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
                                const int warp_m, const int warp_n,
                                int element_size) {
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 32 == 0);
  ICHECK(warp_n % 32 == 0);
  auto tile = makeGemmFragment32x32(element_size);
  auto in_warp = tile->Repeat({warp_m / 32, warp_n / 32}, false, false);
  return in_warp->Repeat({block_m / warp_m, block_n / warp_n}, true);
}

355
356
357
// Volta A operand fragment (non-transposed).
// It is built from a hand-written 32x4 base tile in which every element is
// held by two threads: the `rep` axis of extent 2 adds 8 to the thread id.
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
                                const int block_k, const int warp_m,
                                const int warp_n) {
  // assume not transposed
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 32 == 0);
  ICHECK(block_k % 4 == 0);
  IterVar row = make_itervar("i", 32);
  IterVar col = make_itervar("j", 4);
  IterVar replica = make_itervar("rep", 2);
  PrimExpr owner_thread = FloorDiv(FloorMod(row, 16), 8) * 4 +
                          16 * FloorDiv(row, 16) + FloorMod(row, 4) +
                          8 * replica;
  PrimExpr slot = col + FloorDiv(FloorMod(row, 8), 4) * 4;
  Fragment tile = Fragment({row, col}, {slot}, owner_thread, replica);
  auto in_warp = tile->Repeat({warp_m / 32, block_k / 4}, false, false);
  auto replicated = in_warp->Replicate(block_n / warp_n);
  return replicated->Repeat({block_m / warp_m, 1}, true);
}

378
379
380
// XOR of two single-bit values, expressed arithmetically as (i + j) mod 2.
PrimExpr xor2x2(const PrimExpr &i, const PrimExpr &j) {
  PrimExpr sum = i + j;
  return FloorMod(sum, 2);
}
381

382
// XOR of two 2-bit values, applying xor2x2 to the high and low bits
// separately.
PrimExpr xor4x4(const PrimExpr &i, const PrimExpr &j) {
  PrimExpr low_bit = xor2x2(FloorMod(i, 2), FloorMod(j, 2));
  PrimExpr high_bit = xor2x2(FloorDiv(i, 2), FloorDiv(j, 2));
  return 2 * high_bit + low_bit;
}

390
// XOR of two 3-bit values. The low bit is handled by xor2x2 and the upper
// two bits by xor4x4.
PrimExpr xor8x8(const PrimExpr &i, const PrimExpr j) {
  PrimExpr low_bit = xor2x2(FloorMod(i, 2), FloorMod(j, 2));
  PrimExpr high_bits = xor4x4(FloorDiv(i, 2), FloorDiv(j, 2));
  return 2 * high_bits + low_bit;
}

398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
// Layout swizzling for 32 bytes
// Swizzles 1 bit:
//   - Columns are grouped into 128-bit vectors, and vectors into pairs.
//   - Within each 8-row group, the vector-in-pair index is XORed with
//     (row % 8) / 4.
// Output coordinates: {column-pair tile, row tile, in-tile index}.
Layout makeQuarterBankSwizzleLayout(int stride, int continuous,
                                    int element_size) {
  Var row = InputPlaceholder(0);
  Var col = InputPlaceholder(1);
  const int vector_size = 128 / element_size;
  ICHECK(stride % 8 == 0) << "stride=" << stride;
  ICHECK(continuous % (vector_size * 2) == 0)
      << "continuous=" << continuous << ", vector_size=" << vector_size;
  PrimExpr vec_idx = FloorDiv(col, vector_size);
  PrimExpr row_tile = FloorDiv(row, 8);
  PrimExpr row_in_tile = FloorMod(row, 8);
  PrimExpr col_tile = FloorDiv(vec_idx, 2);
  PrimExpr vec_in_tile = FloorMod(vec_idx, 2);
  PrimExpr elem_in_vec = FloorMod(col, vector_size);
  PrimExpr swizzled = xor2x2(vec_in_tile, FloorDiv(row_in_tile, 4));
  PrimExpr offset = elem_in_vec + (swizzled + row_in_tile * 2) * vector_size;
  return Layout(Array<PrimExpr>{stride, continuous},
                {col_tile, row_tile, offset});
}

// Layout swizzling for 64 bytes
419
420
421
422
423
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size) {
  // Swizzle 2 bit
  Var i = InputPlaceholder(0);
  Var j = InputPlaceholder(1);
  int vector_size = 128 / element_size;
424
425
426
  ICHECK(stride % 8 == 0) << "stride=" << stride;
  ICHECK(continuous % (vector_size * 4) == 0)
      << "continuous=" << continuous << ", vector_size=" << vector_size;
427
428
429
430
431
432
433
434
435
436
  PrimExpr ts = FloorDiv(i, 8);
  PrimExpr s = FloorMod(i, 8);
  PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 4);
  PrimExpr c = FloorMod(FloorDiv(j, vector_size), 4);
  PrimExpr vec = FloorMod(j, vector_size);
  PrimExpr c_swizzle = xor4x4(c, FloorDiv(s, 2));
  PrimExpr index = vec + (c_swizzle + s * 4) * vector_size;
  return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
}

437
// Layout swizzling for 128 bytes
438
439
440
441
442
Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size) {
  // Swizzle 3 bit
  Var i = InputPlaceholder(0);
  Var j = InputPlaceholder(1);
  int vector_size = 128 / element_size;
443
444
445
  ICHECK(stride % 8 == 0) << "stride=" << stride;
  ICHECK(continuous % (vector_size * 8) == 0)
      << "continuous=" << continuous << ", vector_size=" << vector_size;
446
447
448
449
450
451
452
453
454
455
  PrimExpr ts = FloorDiv(i, 8);
  PrimExpr s = FloorMod(i, 8);
  PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 8);
  PrimExpr c = FloorMod(FloorDiv(j, vector_size), 8);
  PrimExpr vec = FloorMod(j, vector_size);
  PrimExpr c_swizzle = xor8x8(c, s);
  PrimExpr index = vec + (c_swizzle + s * 8) * vector_size;
  return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
}

456
457
458
459
// Detail implementation please ref to
// bitblas::tl::mfma_layout::make_mfma_swizzle_layout
//
// XOR-swizzled layout for CDNA matrix-core operands. Columns are grouped
// into vectors of vecSize = (64 / element_size) * kPack elements. Each
// row's vector index is XORed with a per-row phase, presumably to spread
// rows over the 32 banks.
//
// Layout parameters (plain integers computed while building the layout):
//   elemsPerOneBanksRow = (32 banks * 32 bits) / element_size
//   perPhase = max(1, elemsPerOneBanksRow / continuous)  (rows per phase)
//   maxPhase = min(16 / perPhase, continuous / vecSize)  (phase count)
//
// Every integer divisor is validated first. Bad arguments then fail with a
// message instead of dividing by zero while building the layout (or later,
// via FloorMod by maxPhase).
Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size,
                                   int kPack = 1) {
  const int numBanks = 32;
  const int bankBitWidth = 32;
  const int SIMDWidth = 16;
  // 64 / element_size must be >= 1, otherwise vecSize is 0.
  ICHECK(element_size > 0 && element_size <= 64)
      << "element_size=" << element_size;
  ICHECK(kPack > 0) << "kPack=" << kPack;
  ICHECK(continuous > 0) << "continuous=" << continuous;
  const int vecSize = (64 / element_size) * kPack;
  const int innerDimLength = continuous;
  const int typeWidthInBit = element_size;

  const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
  const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
  const int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
  // maxPhase is the FloorMod divisor below. It is 0 when `continuous` is
  // narrower than one vector, or when perPhase exceeds the SIMD width.
  ICHECK(maxPhase > 0) << "continuous=" << continuous
                       << ", vecSize=" << vecSize
                       << ", perPhase=" << perPhase;

  IterVar row = make_itervar("row", stride);
  IterVar col = make_itervar("col", continuous);
  PrimExpr phase = FloorMod(row / perPhase, maxPhase);
  PrimExpr colOffSwizzled = ((col / vecSize) ^ phase) * vecSize;
  PrimExpr colOffOrdered = FloorMod(col, vecSize);
  PrimExpr colOff = colOffSwizzled + colOffOrdered;

  return Layout(Array{row, col}, {row, colOff});
}

// FP64 A/B shared-memory layout for the K-inner case: Swizzle<2, 0, 4>.
// The tile is 4 rows x 16 columns. Within each group of 4 columns, the
// column index is XORed with the row index (row % 4).
Layout makeGemmABLayoutF64_Kinner(int stride, int continuous) {
  Var row = InputPlaceholder(0);
  Var col = InputPlaceholder(1);
  PrimExpr col_tile = FloorDiv(col, 16);
  PrimExpr row_tile = FloorDiv(row, 4);
  PrimExpr col_in_tile = FloorMod(col, 16);
  PrimExpr row_in_tile = FloorMod(row, 4);
  PrimExpr swizzled =
      FloorDiv(col_in_tile, 4) * 4 + xor4x4(FloorMod(col_in_tile, 4), row_in_tile);
  PrimExpr offset = swizzled + row_in_tile * 16;
  return Layout(Array<PrimExpr>{stride, continuous},
                {col_tile, row_tile, offset});
}

// FP64 A/B shared-memory layout for the K-outer case: Swizzle<2, 2, 2>.
// The tile is 4 rows x 16 columns. The 4-column group index is XORed with
// the row index (row % 4).
Layout makeGemmABLayoutF64_Kouter(int stride, int continuous) {
  Var row = InputPlaceholder(0);
  Var col = InputPlaceholder(1);
  PrimExpr col_tile = FloorDiv(col, 16);
  PrimExpr row_tile = FloorDiv(row, 4);
  PrimExpr col_in_tile = FloorMod(col, 16);
  PrimExpr row_in_tile = FloorMod(row, 4);
  PrimExpr swizzled =
      FloorMod(col_in_tile, 4) + xor4x4(FloorDiv(col_in_tile, 4), row_in_tile) * 4;
  PrimExpr offset = swizzled + row_in_tile * 16;
  return Layout(Array<PrimExpr>{stride, continuous},
                {col_tile, row_tile, offset});
}

// The Default Layout for Tensor Access
// Plain row-major layout: element (i, j) maps to offset i * continuous + j.
Layout makeGemmLayoutLinear(int stride, int continuous) {
  IterVar row = make_itervar("i", stride);
  IterVar col = make_itervar("j", continuous);
  PrimExpr offset = row * continuous + col;
  return Layout(Array{row, col}, {offset});
}

// Row-major layout with optional row padding.
// When a row spans a multiple of 256 bits, the row pitch is extended by
// 128 bits (128 / element_size elements). Otherwise the pitch equals
// `continuous`.
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size) {
  IterVar row = make_itervar("i", stride);
  IterVar col = make_itervar("j", continuous);
  const bool pad_row = (element_size * continuous) % 256 == 0;
  const int row_pitch =
      pad_row ? continuous + 128 / element_size : continuous;
  return Layout(Array{row, col}, {row * row_pitch + col});
}

// Volta "crosswise" A/B shared-memory layout (used when K is the inner
// dimension).
// Columns are grouped into 4-element vectors. Within each 16-row band, the
// vector position is permuted by two XOR-style bits derived from the row
// and the vector index.
Layout MakeGemmVoltaABLayoutCrosswise(int stride, int continuous) {
  ICHECK(stride % 32 == 0 && continuous % 32 == 0);
  IterVar row = make_itervar("i", stride);
  IterVar col = make_itervar("j", continuous);
  PrimExpr vec_col = FloorDiv(col, 4);
  PrimExpr vec_in_tile = FloorMod(vec_col, 8);

  PrimExpr high_bit =
      FloorMod(FloorDiv(FloorMod(row, 32), 16) + FloorDiv(FloorMod(row, 16), 8) +
                   FloorDiv(vec_in_tile, 4),
               2);
  PrimExpr low_bit = xor2x2(FloorDiv(FloorMod(row, 8), 4),
                            FloorDiv(FloorMod(vec_in_tile, 4), 2));
  PrimExpr permuted_vec =
      FloorDiv(row, 16) * 16 + FloorMod(row, 4) * 4 + high_bit * 2 + low_bit;

  PrimExpr offset =
      FloorMod(col, 4) + permuted_vec * 4 + vec_col * stride * 4;
  return Layout(Array{row, col}, {offset});
}

// Volta "congruous" shared-memory layout for the A operand.
// Structure:
//   - Columns are grouped into 8-element vectors.
//   - Tiles are 8 vectors wide and 4 rows tall.
//   - Inside a tile the vector position is permuted. The new row is
//     vec / 2, and the new column combines vec % 2 with xor4x4 of the row
//     and the new row.
Layout MakeGemmVoltaALayoutCongruous(int stride, int continuous) {
  ICHECK(stride % 4 == 0 && continuous % 64 == 0);
  IterVar row = make_itervar("i", stride);
  IterVar col = make_itervar("j", continuous);
  PrimExpr vec_col = FloorDiv(col, 8);
  PrimExpr tile_col = FloorDiv(vec_col, 8);
  PrimExpr tile_row = FloorDiv(row, 4);
  PrimExpr col_in_tile = FloorMod(vec_col, 8);
  PrimExpr row_in_tile = FloorMod(row, 4);

  PrimExpr new_row_in_tile = FloorDiv(col_in_tile, 2);
  PrimExpr new_col_in_tile = FloorMod(col_in_tile, 2) * 4 +
                             xor4x4(row_in_tile, new_row_in_tile);

  PrimExpr out_row = new_row_in_tile + tile_row * 4;
  PrimExpr out_col =
      FloorMod(col, 8) + (new_col_in_tile + tile_col * 8) * 8;
  return Layout(Array{row, col}, {out_row * continuous + out_col});
}

// Volta "congruous" shared-memory layout for the B operand.
// The tiling matches MakeGemmVoltaALayoutCongruous; only the in-tile
// permutation differs:
//   - new row = vec % 4
//   - new column = (vec / 4) * 4 + xor4x4(row, new row)
Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
  ICHECK(stride % 4 == 0 && continuous % 64 == 0);
  IterVar row = make_itervar("i", stride);
  IterVar col = make_itervar("j", continuous);
  PrimExpr vec_col = FloorDiv(col, 8);
  PrimExpr tile_col = FloorDiv(vec_col, 8);
  PrimExpr tile_row = FloorDiv(row, 4);
  PrimExpr col_in_tile = FloorMod(vec_col, 8);
  PrimExpr row_in_tile = FloorMod(row, 4);

  PrimExpr new_row_in_tile = FloorMod(col_in_tile, 4);
  PrimExpr new_col_in_tile = FloorDiv(col_in_tile, 4) * 4 +
                             xor4x4(row_in_tile, new_row_in_tile);

  PrimExpr out_row = new_row_in_tile + tile_row * 4;
  PrimExpr out_col =
      FloorMod(col, 8) + (new_col_in_tile + tile_col * 8) * 8;
  return Layout(Array{row, col}, {out_row * continuous + out_col});
}

595
// Selects the Volta shared-memory layout for an A or B operand.
// Candidates are tried in this order:
//   1. Crosswise, when k_inner holds and both extents are multiples of 32.
//   2. The operand's congruous layout, when continuous % 64 == 0 and
//      stride % 4 == 0.
//   3. Otherwise the padded row-major layout, with 16-bit elements assumed.
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
                             bool k_inner) {
  const bool fits_crosswise =
      k_inner && continuous % 32 == 0 && stride % 32 == 0;
  if (fits_crosswise) {
    return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
  }
  const bool fits_congruous = continuous % 64 == 0 && stride % 4 == 0;
  if (fits_congruous) {
    return is_a ? MakeGemmVoltaALayoutCongruous(stride, continuous)
                : MakeGemmVoltaBLayoutCongruous(stride, continuous);
  }
  return makeGemmABLayoutPadded(stride, continuous, 16);
}

606
607
// ref:
// https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/layout/tensor_op_multiplicand_sm75.h#L54
// Although CUTLASS uses distinct layouts for the four settings (T or NT),
// they appeared to result in the same memory layout.
Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous,
                                int elementsize, int crosswise) {
  /// This layout is optimized for 128b accesses
  static int const kAccessSize = 128;
  int kCrosswise = crosswise;

  int kElementSize = elementsize;
  int kElementsPerAccess = kAccessSize / kElementSize;

  /// Contiguous dimension of the tile shape matches one shared memory cache
  /// line - 128B.  For 128bit access size, it equals to 8 accesses.
  int kTileShapeContiguous = 128 / (kAccessSize / 8);

  int kFactor = kTileShapeContiguous * kElementsPerAccess / kCrosswise;

  ICHECK(kFactor > 0)
      << "kCrosswise should be no large than one shared memory cache line.";

  /// The strided dimension needs to be at least (WarpSize(32) /
  /// kTileShapeContiguous) for a warp to access.  To ensure conflict free
  /// access, it also needs to be at least (kTileShapeContiguous / kFactor).
  /// See comments below
  /// Fundamental tile shape in units of vectors to guarantee bank conflict free
  /// shared memory load/store.
  /// For kFactor = 1, TileShape = <8, 8>
  /// For kFactor > 1, TileShape = <8, 4>
  int kTileShapeStride =
      ((kTileShapeContiguous / kFactor) > (32 / kTileShapeContiguous))
          ? (kTileShapeContiguous / kFactor)
          : (32 / kTileShapeContiguous);

  const int kPartitionShapeContiguous = 4;
  const int kPartitionShapeStride = 4;

  // NOTE: it's always row major for tl
  IterVar i = make_itervar("i", mat_stride);
  IterVar j = make_itervar("j", mat_continuous);

  PrimExpr vec_contiguous_idx = FloorDiv(j, kElementsPerAccess);
  PrimExpr vec_strided_idx = FloorDiv(i, kFactor);

  // Compute the fundamental tile being accessed
  PrimExpr tile_contiguous_idx =
      FloorDiv(vec_contiguous_idx, FloorDiv(kTileShapeContiguous, kFactor));

  PrimExpr tile_contiguous_residual =
      FloorMod(vec_contiguous_idx, FloorDiv(kTileShapeContiguous, kFactor)) +
      (FloorMod(i, kFactor) * FloorDiv(kTileShapeContiguous, kFactor));
  PrimExpr tile_strided_residual = FloorMod(vec_strided_idx, kTileShapeStride);

  // Compute the 'partition' within the fundamental tile
  PrimExpr partition_contiguous_idx =
      FloorDiv(tile_contiguous_residual, kPartitionShapeContiguous);
  PrimExpr partition_strided_idx =
      FloorDiv(tile_strided_residual, kPartitionShapeStride);

  PrimExpr partition_contiguous_residual =
      FloorMod(tile_contiguous_residual, kPartitionShapeContiguous);
  PrimExpr partition_strided_residual =
      FloorMod(tile_strided_residual, kPartitionShapeStride);

  //
  // Then swizzle
  //

  PrimExpr permuted_vec_contiguous_within_partition = xor4x4(
      partition_contiguous_residual, FloorMod(partition_strided_residual, 4));

  PrimExpr permuted_partition_contiguous_within_tile =
      xor2x2(partition_contiguous_idx, FloorMod(partition_strided_idx, 2));

  //
  // Compute final element location
  //

  PrimExpr element_contiguous =
      (tile_contiguous_idx * kTileShapeContiguous +
       permuted_partition_contiguous_within_tile * kPartitionShapeContiguous +
       permuted_vec_contiguous_within_partition) *
          kElementsPerAccess +
      FloorMod(j, kElementsPerAccess);

  const PrimExpr &element_strided = vec_strided_idx;

  const int stride = mat_continuous;

  return Layout(Array{i, j},
                {element_contiguous + element_strided * stride * kFactor});
}

/*!
 * \brief Shared-memory layout for sparse-GEMM A/B operands on Ampere.
 *
 * Delegates to makeTensorOpMultiplicand with a crosswise extent equal to the
 * number of elements fitting in 1024 bits (128 bytes), capped at
 * \p mat_continuous.
 *
 * \param mat_stride Extent of the row dimension.
 * \param mat_continuous Extent of the contiguous dimension.
 * \param elementsize Element width in bits.
 */
Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous,
                                    int elementsize) {
  // 1024 bits == 128 bytes: one shared-memory line worth of elements.
  const int line_elems = 1024 / elementsize;
  const int crosswise =
      (mat_continuous < line_elems) ? mat_continuous : line_elems;
  return makeTensorOpMultiplicand(mat_stride, mat_continuous, elementsize,
                                  crosswise);
}

707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
/*!
 * \brief Creates a memory layout for GEMM's A or B matrices.
 *
 * This function selects an appropriate memory layout based on the matrix
 * dimensions, element size, continuity, and a k-factor. It aims to optimize
 * memory access patterns, potentially using swizzling techniques or specialized
 * layouts for different data types and hardware characteristics.
 *
 * \param mat_stride The leading dimension of the matrix (e.g., K for a
 * row-major M x K matrix). This is the number of elements to skip to get to the
 * same column in the next row (row-major) or to the same row in the next column
 * (column-major). \param mat_continuous The length of the dimension stored
 * contiguously in memory (e.g., K for a row-major M x K matrix, or M for a
 * column-major M x K matrix). \param continuity The size of the dimension that
 * is continuous from the perspective of memory bank access. This is used to
 * select specific swizzling strategies. It might be the same as mat_continuous
 *                   or different based on tiling or hardware details.
 * \param element_size The size of each element in the matrix, in bits (e.g., 8,
725
 * 16, 32, 64). \param k_inner Whether the K dimension is in the inner loop.
726
727
728
 * selection, particularly for fp64 and int8 types. It often relates to how the
 * K dimension of the GEMM (M x K * K x N) is handled or tiled.
 *                - For fp64 (element_size == 64):
729
730
731
732
 *                  - k_inner == false often implies K is in the "outer" loop
 * (e.g., KxN matrix).
 *                  - k_inner == true often implies K is in the "inner" loop
 * (e.g., NxK matrix).
733
 *                - For int8 (element_size == 8):
734
 *                  - k_inner == false uses a padded layout.
735
736
 * \return A Layout object representing the chosen memory layout.
 */
737
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
738
                        int element_size, bool k_inner) {
739
  if (element_size == 64) {
740
    if (!k_inner && continuity % 16 == 0) // float64 KxN
741
      return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
742
    if (k_inner && continuity % 16 == 0) // float64 NxK
743
744
      return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous);
    return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
745
746
  }
  int vector_size = 128 / element_size;
747
  if (!k_inner && element_size == 8) // int8 KxN
748
    return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
749
  else if (mat_continuous % (vector_size * 8) == 0)
750
751
    // return makeHalfBankSwizzleLayout(mat_stride, mat_continuous,
    // element_size);
752
    return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
753
  else if (mat_continuous % (vector_size * 4) == 0)
754
    return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
755
  else {
756
    return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
757
758
759
  }
}

760
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
761
                              int continuity, int element_size, bool k_inner) {
762
  if (element_size == 64) {
763
    if (!k_inner && continuity % 16 == 0) // float64 KxN
764
      return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
765
    if (k_inner && continuity % 16 == 0) // float64 NxK
766
767
768
769
770
      return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous);
    return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
                                        element_size);
  }
  int vector_size = 128 / element_size;
771

772
773
774
775
776
  if (mat_continuous % (vector_size * 8) == 0)
    return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
  else if (mat_continuous % (vector_size * 4) == 0)
    return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
  else if (mat_continuous % (vector_size * 2) == 0)
777
778
    return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
                                        element_size);
779
780
781
782
783
  else if (mat_continuous % vector_size == 0)
    return makeGemmLayoutLinear(mat_stride, mat_continuous);
  else
    ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
              << ", continuous=" << mat_continuous
784
              << ", element_size=" << element_size << ", k_inner=" << k_inner;
785
786
787
}

/*!
 * \brief Create the shared-memory layout for a GEMM A or B operand on SM100.
 *
 * 64-bit elements are rejected. For other widths, the widest swizzle (full,
 * half, then quarter bank) is chosen whose width of 8 / 4 / 2 vectors of
 * 128 bits divides \p mat_continuous. After that comes a linear layout when
 * \p mat_continuous is a whole number of vectors. Anything else fails via
 * ICHECK.
 *
 * \param mat_stride Extent of the strided dimension (forwarded to helpers).
 * \param mat_continuous Extent of the dimension that is contiguous in memory.
 * \param continuity Unused by this variant.
 * \param element_size Element width in bits.
 * \param k_inner Only reported in the failure message.
 */
Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
                             int element_size, bool k_inner) {
  if (element_size == 64) {
    ICHECK(0) << "float64 on sm100 is not supported now";
  }

  const int elems_per_vec = 128 / element_size;

  if (mat_continuous % (elems_per_vec * 8) == 0) {
    return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
  }
  if (mat_continuous % (elems_per_vec * 4) == 0) {
    return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
  }
  if (mat_continuous % (elems_per_vec * 2) == 0) {
    return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
                                        element_size);
  }
  if (mat_continuous % elems_per_vec == 0) {
    return makeGemmLayoutLinear(mat_stride, mat_continuous);
  }

  ICHECK(0) << "Unsupported layout for sm100 with stride=" << mat_stride
            << ", continuous=" << mat_continuous
            << ", element_size=" << element_size << ", k_inner=" << k_inner;
  __builtin_unreachable(); // ICHECK(0) never returns; silences -Wreturn-type
}

809
810
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
                            int kPack) {
811
  return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack);
812
}
813
814
} // namespace tl
} // namespace tvm