layout.cc 20.8 KB
Newer Older
1
2
3
4
5
6
/*!
 * \file layout/layout.cc
 *
 */

#include "layout.h"
7
#include <tvm/ffi/reflection/registry.h>
8
9
10
11
12
13
14
15
16
17
18
19
20

#include <tvm/arith/pattern.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include "arith/pattern_match.h"
#include "utils.h"

namespace tvm {
namespace tl {

using namespace tir;

21
/*!
 * \brief Look up (or lazily create) the placeholder variable with a given
 *        name.
 *
 * Placeholders are kept in a function-local static table keyed by name, so
 * every call with the same name yields the cached Var that was created on the
 * first request.
 *
 * \param s Name of the placeholder variable.
 * \return The cached Var registered under \p s.
 */
static Var getPlaceholder(const std::string &s) {
  static std::unordered_map<std::string, Var> placeholders;
  auto it = placeholders.find(s);
  if (it == placeholders.end()) {
    it = placeholders.emplace(s, Var(s)).first;
  }
  return it->second;
}

/*!
 * \brief Placeholder variable named "_rep".
 *
 * Used throughout this file as the symbolic replication index of a fragment.
 */
Var ReplicationPlaceholder() {
  return getPlaceholder("_rep");
}
30
31
32
/*!
 * \brief Placeholder variable standing for the idx-th input dimension.
 *
 * The name is an underscore followed by the character 'i' + idx, i.e. "_i",
 * "_j", "_k", ... for idx = 0, 1, 2, ...
 *
 * \param idx Zero-based input dimension.
 * \return The cached placeholder Var for that dimension.
 */
Var InputPlaceholder(size_t idx) {
  std::string name = "_";
  name.push_back(char('i' + idx));
  return getPlaceholder(name);
}
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47

/*!
 * \brief Build the domain of every input placeholder.
 * \return A map from the i-th input placeholder to the Range initialised
 *         from {0, input_size_[i]}.
 */
Map<Var, Range> LayoutNode::getVarMap() const {
  Map<Var, Range> domains;
  const size_t ndim = InputDim();
  for (size_t dim = 0; dim != ndim; ++dim) {
    domains.Set(InputPlaceholder(dim), {0, input_size_[dim]});
  }
  return domains;
}

/*!
 * \brief Placeholder domains of a fragment.
 * \return The input-placeholder domains produced by LayoutNode::getVarMap(),
 *         extended with an entry for the replication placeholder initialised
 *         from {0, ReplicateExtent()}.
 */
Map<Var, Range> FragmentNode::getVarMap() const {
  Map<Var, Range> domains = LayoutNode::getVarMap();
  domains.Set(ReplicationPlaceholder(), {0, ReplicateExtent()});
  return domains;
}

48
49
/*!
 * \brief Construct a layout node from its input extents and forward index
 *        expressions.
 *
 * Each forward index expression is simplified with an analyzer that has the
 * placeholder domains bound.
 *
 * \param input_size Extents stored as input_size_.
 * \param forward_index Index expressions; stored simplified as
 *        forward_index_.
 */
LayoutNode::LayoutNode(Array<PrimExpr> input_size,
                       Array<PrimExpr> forward_index) {
  // input_size_ has to be assigned before UpdateAnalyzer(): the bound
  // domains are derived from it through getVarMap().
  input_size_ = input_size;

  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);

  auto simplify = [&analyzer](const PrimExpr &expr) {
    return analyzer.Simplify(expr);
  };
  forward_index_ = forward_index.Map(simplify);
}

/*!
 * \brief Build a layout from user iteration variables.
 *
 * Every iteration variable is replaced by the matching input placeholder in
 * the forward index expressions, and its extent becomes the corresponding
 * input size. Each variable's domain is checked to start at zero.
 *
 * \param forward_var Iteration variables, one per input dimension.
 * \param forward_index Index expressions written in terms of forward_var.
 */
Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
  Map<Var, PrimExpr> vmap;
  Array<PrimExpr> input_size;
  for (size_t i = 0; i < forward_var.size(); i++) {
    const IterVar iter = forward_var[i];
    vmap.Set(iter->var, InputPlaceholder(i));
    CHECK(is_zero(forward_var[i]->dom->min));
    input_size.push_back(iter->dom->extent);
  }

  auto to_placeholders = [&vmap](const PrimExpr &expr) {
    return Substitute(expr, vmap);
  };
  data_ = tvm::ffi::make_object<LayoutNode>(input_size,
                                            forward_index.Map(to_placeholders));
}

/*!
 * \brief Build a layout directly from input extents and forward index
 *        expressions.
 *
 * The arguments are forwarded unchanged to the LayoutNode constructor.
 */
Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
  data_ = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
}

76
77
78
79
80
void LayoutNode::RegisterReflection() {
  namespace refl = tvm::ffi::reflection;
  refl::ObjectDef<LayoutNode>()
      .def_ro("input_size", &LayoutNode::input_size_)
      .def_ro("forward_index", &LayoutNode::forward_index_);
81
82
}

83
84
/*!
 * \brief Bind every placeholder domain from getVarMap() into an analyzer.
 * \param analyzer Analyzer that receives the bindings.
 */
void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
  for (const auto &entry : getVarMap()) {
    analyzer->Bind(entry.first, entry.second);
  }
}

89
90
91
92
93
94
95
96
/*!
 * \brief List the input placeholders of this layout.
 * \return One placeholder per input dimension, in dimension order.
 */
Array<PrimExpr> LayoutNode::GetForwardVars() const {
  Array<PrimExpr> placeholders;
  const size_t ndim = InputDim();
  for (size_t dim = 0; dim != ndim; ++dim) {
    placeholders.push_back(InputPlaceholder(dim));
  }
  return placeholders;
}

97
98
99
100
101
102
103
104
105
106
/*!
 * \brief Derive the extent of every output dimension.
 *
 * For each forward index expression the analyzer's integer set of
 * (index + 1) is computed under the bound placeholder domains, and its
 * maximum is used as the extent. When that set is unbounded on both sides
 * the extent falls back to input_size_ at the same position.
 *
 * \return One extent per output dimension.
 */
Array<PrimExpr> LayoutNode::OutputShape() const {
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);

  Array<PrimExpr> shape(OutputDim(), 1);
  for (size_t dim = 0; dim < shape.size(); ++dim) {
    auto bound = analyzer.int_set(forward_index_[dim] + 1);
    const bool unbounded =
        arith::is_neg_inf(bound.min()) && arith::is_pos_inf(bound.max());
    // Unbounded case: the original note attributes it to X-OR expressions.
    // Bounded case: the lower bound is not verified (a former
    // CHECK(is_one(ist.min())) is disabled).
    shape.Set(dim, unbounded ? input_size_[dim] : bound.max());
  }
  return shape;
}

114
115
116
/*!
 * \brief Apply the layout to a list of index expressions.
 *
 * Only the trailing InputDim() entries of \p vars are transformed; any
 * leading entries are copied to the front of the result unchanged.
 *
 * \param vars Index expressions; must contain at least InputDim() entries
 *        unless it is empty.
 * \return forward_index_ itself when \p vars is empty; otherwise the leading
 *         entries followed by the forward index expressions with the input
 *         placeholders substituted by the trailing entries.
 */
Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
  if (vars.empty()) {
    return forward_index_;
  }
  ICHECK_GE(vars.size(), InputDim());

  // Number of leading entries that pass through untouched.
  const size_t num_leading = vars.size() - InputDim();

  Map<Var, PrimExpr> vmap;
  for (size_t dim = 0; dim < InputDim(); ++dim) {
    vmap.Set(InputPlaceholder(dim), vars[num_leading + dim]);
  }

  Array<PrimExpr> result;
  for (size_t k = 0; k < num_leading; ++k) {
    result.push_back(vars[k]);
  }
  for (const PrimExpr &index : forward_index_) {
    result.push_back(Substitute(index, vmap));
  }
  return result;
}

144
145
/*!
 * \brief Tile this fragment along its input dimensions.
 *
 * The new input extent of dimension i is input_size_[i] * repeats[i]. Inside
 * the existing expressions every input placeholder is replaced by its
 * remainder modulo InputShape()[i]; the quotients are linearised into a
 * single repeat index.
 *
 * \param repeats Repeat factor per input dimension; its size must equal
 *        InputDim().
 * \param repeat_on_thread When true the repeat index, scaled by
 *        ThreadExtent(), is added to the thread expression. When false it is
 *        scaled by OutputShape()[0] and added to the single forward index
 *        (OutputDim() must be 1).
 * \param lower_dim_first When true the linearisation walks dimensions from
 *        last to first, otherwise from first to last.
 * \return The repeated fragment, keeping replicate_size_.
 */
Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
                              bool repeat_on_thread,
                              bool lower_dim_first) const {
  ICHECK_EQ(repeats.size(), InputDim());

  Array<PrimExpr> new_input_size;
  Map<Var, PrimExpr> vmap;
  for (size_t dim = 0; dim < InputDim(); ++dim) {
    new_input_size.push_back(input_size_[dim] * repeats[dim]);
    vmap.Set(InputPlaceholder(dim),
             FloorMod(InputPlaceholder(dim), InputShape()[dim]));
  }

  // Linearise the per-dimension tile coordinates into one repeat index.
  PrimExpr repeats_index = 0;
  PrimExpr repeat_stride = 1;
  auto accumulate = [&](size_t dim) {
    repeats_index +=
        repeat_stride * FloorDiv(InputPlaceholder(dim), InputShape()[dim]);
    repeat_stride *= repeats[dim];
  };
  if (lower_dim_first) {
    for (size_t dim = InputDim(); dim-- > 0;) {
      accumulate(dim);
    }
  } else {
    for (size_t dim = 0; dim < InputDim(); ++dim) {
      accumulate(dim);
    }
  }

  Array<PrimExpr> new_forward_index;
  PrimExpr new_forward_thread;
  if (repeat_on_thread) {
    const PrimExpr thread_size = ThreadExtent();
    new_forward_index = forward_index_.Map(
        [&](const PrimExpr &expr) { return Substitute(expr, vmap); });
    new_forward_thread =
        Substitute(forward_thread_, vmap) + thread_size * repeats_index;
  } else {
    ICHECK(OutputDim() == 1);
    const PrimExpr frag_len = OutputShape()[0];
    new_forward_index.push_back(Substitute(forward_index_[0], vmap) +
                                frag_len * repeats_index);
    new_forward_thread = Substitute(forward_thread_, vmap);
  }
  return Fragment(new_input_size, new_forward_index, new_forward_thread,
                  replicate_size_, std::nullopt);
}

/*!
 * \brief Multiply the replication extent of this fragment.
 *
 * In the thread expression the replication placeholder is replaced by its
 * remainder modulo ReplicateExtent(), and ThreadExtent() times the quotient
 * is added on top.
 *
 * \param repeats Replication factor; must be at least 1.
 * \return A fragment with the same input sizes and forward index and a
 *         replicate extent of ReplicateExtent() * repeats.
 */
Fragment FragmentNode::Replicate(int repeats) const {
  ICHECK(repeats >= 1);

  const Var rep = ReplicationPlaceholder();
  const auto rep_extent = ReplicateExtent();

  Map<Var, PrimExpr> vmap;
  vmap.Set(rep, FloorMod(rep, rep_extent));

  PrimExpr new_forward_thread = Substitute(forward_thread_, vmap) +
                                ThreadExtent() * FloorDiv(rep, rep_extent);
  return Fragment(input_size_, forward_index_, new_forward_thread,
                  rep_extent * repeats, std::nullopt);
}

/*!
 * \brief Fold a common factor of the replication extent and the output
 *        extent out of the replication.
 *
 * Requires OutputDim() == 1. When both ReplicateExtent() and OutputShape()[0]
 * are constants, their GCD is used as the folding factor; otherwise, or when
 * the factor is 1, the fragment is returned unchanged.
 *
 * With a factor f, the replication placeholder in the thread expression is
 * replaced by rep * f + floormod(index, f), the forward index becomes
 * floordiv(index, f), and the replicate extent becomes rep_size / f.
 *
 * \return The de-replicated fragment, or this fragment when nothing can be
 *         folded.
 */
Fragment FragmentNode::DeReplicate() const {
  ICHECK(OutputDim() == 1);

  // Folding is only attempted when both extents are compile-time constants.
  const int64_t *rep_size = as_const_int(ReplicateExtent());
  const int64_t *idx_size = as_const_int(OutputShape()[0]);
  if (rep_size == nullptr || idx_size == nullptr) {
    return tvm::ffi::GetRef<Fragment>(this);
  }

  const int factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
  if (factor == 1) {
    return tvm::ffi::GetRef<Fragment>(this);
  }

  Map<Var, PrimExpr> vmap;
  vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
                                         FloorMod(forward_index_[0], factor));
  PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
  Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)};
  return Fragment(input_size_, new_forward_index, new_forward_thread,
                  int(*rep_size) / factor, std::nullopt);
}

224
/*!
 * \brief Return a copy of this fragment with thread_range_ replaced.
 * \param thread_range Range stored in the copy's thread_range_.
 * \return The new fragment; this node is left untouched.
 */
Fragment FragmentNode::BindThreadRange(Range thread_range) const {
  auto copy = tvm::ffi::make_object<FragmentNode>(*this);
  copy->thread_range_ = thread_range;
  return Fragment(copy);
}

230
std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
231
  arith::Analyzer analyzer;
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
  auto collect_symbolic = [&](const Array<PrimExpr> &shape) {
    Array<PrimExpr> symbolic_dims;
    for (const auto &dim : shape) {
      if (!as_const_int(dim)) {
        symbolic_dims.push_back(dim);
      }
    }
    return symbolic_dims;
  };
  Array<PrimExpr> symbolic_dims = collect_symbolic(input_size_);
  Array<PrimExpr> output_shape = OutputShape();
  symbolic_dims.insert(symbolic_dims.end(), output_shape.begin(),
                       output_shape.end());
  symbolic_dims = collect_symbolic(symbolic_dims);
  bool is_static_shape = symbolic_dims.empty();
  auto level = is_static_shape ? arith::IterMapLevel::Bijective
                               : arith::IterMapLevel::NoCheck;
  if (!is_static_shape) {
    // Runtime guards keep dynamic tails safe, so we allow NoCheck here and
    // warn.
    LOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to "
                    "NoCheck; symbolic dims: "
                 << symbolic_dims;
  }
256
  arith::IterMapResult res =
257
      arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);
258
259
  ICHECK(res->errors.empty())
      << "Layout " << DebugOutput() << " has errors: " << res->errors;
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277

  auto outputs_shape = OutputShape();
  Array<PrimExpr> outputs;
  for (size_t i = 0; i < OutputDim(); i++) {
    outputs.push_back(InputPlaceholder(i));
  }

  auto inv = arith::InverseAffineIterMap(res->indices, outputs);

  Array<PrimExpr> backward_index;
  for (size_t i = 0; i < InputDim(); i++) {
    if (inv.find(InputPlaceholder(i)) != inv.end()) {
      backward_index.push_back(inv[InputPlaceholder(i)]);
    } else {
      backward_index.push_back(0);
    }
  }

278
  return {Layout(outputs_shape, backward_index), level};
279
280
}

281
282
283
284
/*!
 * \brief Inverse layout, discarding the check level reported by
 *        InverseWithLevel().
 */
Layout LayoutNode::Inverse() const {
  return InverseWithLevel().first;
}
285
286
287
288
289
/*!
 * \brief Infer a fragment index from the thread expression.
 *
 * The iterator splits returned by DivideUnusedIterators() for the thread
 * expression are filtered to drop those whose source is the replication
 * placeholder, and the remainder is flattened into one expression.
 *
 * \param input_iters Placeholder domains.
 * \param forward_thread Thread expression of the fragment.
 * \param analyzer Analyzer forwarded to DivideUnusedIterators().
 * \return The flattened expression over the kept splits.
 */
PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
                              const PrimExpr &forward_thread,
                              arith::Analyzer *analyzer) {
  Array<arith::IterSplitExpr> all_splits = DivideUnusedIterators(
      {forward_thread}, ToIterVars(input_iters), analyzer);

  const Var rep = ReplicationPlaceholder();
  Array<arith::IterSplitExpr> kept;
  for (const auto &split : all_splits) {
    // Every split is required to originate from a plain variable.
    CHECK(split->source->source.as<Var>());
    const bool from_rep =
        split->source->source.as<Var>().value().same_as(rep);
    if (!from_rep) {
      kept.push_back(split);
    }
  }
  return MakeFlattenedExpression(kept);
}

302
303
/*!
 * \brief Construct a fragment node.
 *
 * The thread expression and every forward index expression are simplified
 * with an analyzer that has the placeholder domains bound. When no forward
 * index is supplied, a single index is inferred from the simplified thread
 * expression via infer_fragment_index().
 *
 * \param input_size Extents stored as input_size_.
 * \param forward_index Index expressions; may be empty.
 * \param forward_thread Thread expression.
 * \param replicate_size Value stored as replicate_size_.
 */
FragmentNode::FragmentNode(Array<PrimExpr> input_size,
                           Array<PrimExpr> forward_index,
                           PrimExpr forward_thread, PrimExpr replicate_size) {
  // Both members are assigned before UpdateAnalyzer() is called.
  input_size_ = input_size;
  replicate_size_ = replicate_size;

  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  forward_thread_ = analyzer.Simplify(forward_thread);

  if (forward_index.empty()) {
    forward_index.push_back(
        infer_fragment_index(getVarMap(), forward_thread_, &analyzer));
  }

  auto simplify = [&analyzer](const PrimExpr &expr) {
    return analyzer.Simplify(expr);
  };
  forward_index_ = forward_index.Map(simplify);
}

/*!
 * \brief Build a fragment from user iteration variables.
 *
 * Every iteration variable is replaced by the matching input placeholder,
 * and the replicate variable (when defined) by the replication placeholder,
 * in both the forward index and the thread expression. All domains are
 * checked to start at zero.
 *
 * \param forward_var Iteration variables, one per input dimension.
 * \param forward_index Index expressions written in terms of forward_var.
 * \param forward_thread Thread expression.
 * \param thread_replicate Optional replicate iteration variable; when it is
 *        undefined the replicate size is 1.
 */
Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
                   PrimExpr forward_thread, IterVar thread_replicate) {
  Map<Var, PrimExpr> vmap;
  Array<PrimExpr> input_size;
  for (size_t i = 0; i < forward_var.size(); i++) {
    const IterVar iter = forward_var[i];
    vmap.Set(iter->var, InputPlaceholder(i));
    CHECK(is_zero(forward_var[i]->dom->min));
    input_size.push_back(iter->dom->extent);
  }

  PrimExpr replicate_size = 1;
  if (thread_replicate.defined()) {
    ICHECK(is_zero(thread_replicate->dom->min));
    replicate_size = thread_replicate->dom->extent;
    vmap.Set(thread_replicate->var, ReplicationPlaceholder());
  }

  auto to_placeholders = [&vmap](const PrimExpr &expr) {
    return Substitute(expr, vmap);
  };
  forward_index = forward_index.Map(to_placeholders);
  forward_thread = Substitute(forward_thread, vmap);

  data_ = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
                                              forward_thread, replicate_size);
}

/*!
 * \brief Build a fragment from input extents and expressions already
 *        written over the input placeholders.
 *
 * \param input_size Input extents.
 * \param forward_index Index expressions (may be empty; see FragmentNode).
 * \param forward_thread Thread expression.
 * \param replicate_size Replication extent.
 * \param replicate_var When defined, this variable is replaced by the
 *        replication placeholder inside forward_thread.
 */
Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
                   PrimExpr forward_thread, PrimExpr replicate_size,
                   Optional<Var> replicate_var) {
  if (replicate_var.defined()) {
    Map<Var, PrimExpr> vmap;
    vmap.Set(replicate_var.value(), ReplicationPlaceholder());
    forward_thread = Substitute(forward_thread, vmap);
  }
  data_ = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
                                              forward_thread, replicate_size);
}

354
355
356
357
358
359
360
// which means the forward_thread is rep_var -> lambda i, rep: rep
bool FragmentNode::IsCompletedReplicated() const {
  arith::Analyzer analyzer;
  return ExprDeepEqual()(analyzer.Simplify(forward_thread_),
                         ReplicationPlaceholder());
}

361
362
363
364
365
366
367
368
/*!
 * \brief Extent of the thread expression.
 *
 * Computed as the maximum of the analyzer's integer set for
 * (forward_thread_ + 1) under the bound placeholder domains.
 *
 * \return The upper bound of forward_thread_ + 1.
 */
PrimExpr FragmentNode::ThreadExtent() const {
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  auto ist = analyzer.int_set(forward_thread_ + 1);
  return ist.max();
}

369
370
371
372
373
374
375
376
377
378
379
/*!
 * \brief List the placeholders this fragment is expressed over.
 *
 * The replication placeholder comes first and is included when the
 * replication extent is a constant greater than 1, or when the extent is not
 * a compile-time constant (it cannot be shown to be 1, so the placeholder is
 * listed conservatively). The input placeholders follow in dimension order.
 *
 * \return The replication placeholder (if applicable) followed by one
 *         placeholder per input dimension.
 */
Array<PrimExpr> FragmentNode::GetForwardVars() const {
  Array<PrimExpr> vars;
  // as_const_int yields nullptr for a symbolic extent; it must not be
  // dereferenced unconditionally.
  const int64_t *rep_extent = as_const_int(ReplicateExtent());
  if (rep_extent == nullptr || *rep_extent > 1) {
    vars.push_back(ReplicationPlaceholder());
  }
  for (size_t i = 0; i < InputDim(); i++) {
    vars.push_back(InputPlaceholder(i));
  }
  return vars;
}

380
381
/*!
 * \brief Evaluate the thread expression for concrete index expressions.
 *
 * \param vars Exactly InputDim() expressions substituted for the input
 *        placeholders.
 * \param rep_var When defined, substituted for the replication placeholder;
 *        otherwise that placeholder is left in the result.
 * \return forward_thread_ after substitution.
 */
PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr> &vars,
                                     const Optional<PrimExpr> &rep_var) const {
  ICHECK_EQ(vars.size(), InputDim());

  Map<Var, PrimExpr> vmap;
  for (size_t dim = 0; dim < InputDim(); ++dim) {
    vmap.Set(InputPlaceholder(dim), vars[dim]);
  }
  if (rep_var.defined()) {
    vmap.Set(ReplicationPlaceholder(), rep_var.value());
  }
  return Substitute(forward_thread_, vmap);
}

/*!
 * \brief Inverse of the fragment, discarding the check level reported by
 *        InverseWithLevel().
 */
Layout FragmentNode::Inverse() const {
  return InverseWithLevel().first;
}

std::pair<Layout, arith::IterMapLevel> FragmentNode::InverseWithLevel() const {
399
400
401
402
  auto input_size_copy = input_size_;
  input_size_copy.push_back(ReplicateExtent());
  auto forward_index_copy = forward_index_;
  forward_index_copy.push_back(
403
404
      Substitute(forward_thread_,
                 {{ReplicationPlaceholder(), InputPlaceholder(InputDim())}}));
405
  auto fwd = Layout(input_size_copy, forward_index_copy);
406
  return fwd->InverseWithLevel();
407
408
409
410
411
412
413
414
}

/*!
 * \brief Rebuild the fragment with the replication variable compressed by
 *        CompressIterator().
 *
 * \return A fragment with the same input sizes and forward index, the thread
 *         expression returned by CompressIterator(), and the replicate
 *         extent and variable taken from the returned iteration variable.
 */
Fragment FragmentNode::CondenseReplicateVar() const {
  arith::Analyzer analyzer;

  Map<Var, Range> input_iters = getVarMap();
  input_iters.Set(ReplicationPlaceholder(), {0, ReplicateExtent()});

  auto [new_forward_thread, new_thread_replicate] =
      CompressIterator(forward_thread_, ToIterVars(input_iters),
                       ReplicationPlaceholder(), &analyzer);

  return Fragment(input_size_, forward_index_, new_forward_thread,
                  new_thread_replicate->dom->extent, new_thread_replicate->var);
}

422
423
std::string LayoutNode::DebugOutput() const {
  std::stringstream ss;
424
425
426
  ss << "Layout(" << InputShape() << " -> " << OutputShape()
     << ", transform: " << GetForwardVars() << " -> " << GetForwardIndex()
     << ")";
427
  return ss.str();
428
429
}

430
431
std::string FragmentNode::DebugOutput() const {
  std::stringstream ss;
432
433
434
435
436
437
438
439
  ss << "Fragment(" << InputShape() << " -> " << OutputShape()
     << ", replicate: " << ReplicateExtent() << ", thread: " << ThreadExtent()
     << ", forward_thread: " << forward_thread_
     << ", forward_index: " << GetForwardIndex();
  if (thread_range_.defined()) {
    ss << ", thread_range: " << thread_range_;
  }
  ss << ")";
440
  return ss.str();
441
442
}

443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const {
  bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
  ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
  if (!skip_index) {
    ret &= StructuralEqual()(this->forward_index_, other->forward_index_);
  }
  return ret;
}

bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const {
  // Fragment Layout Comparison can skip the index comparison
  // when the output shape is the same, as we can do
  // a[i, j] = b[j, i] in register level.

  bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
458
459
460
461
  if (!ret) {
    // may be broadcast case
    return true;
  }
462
463
464
  if (this->thread_range_.defined() && other->thread_range_.defined()) {
    ret &= StructuralEqual()(this->thread_range_, other->thread_range_);
  }
465
466
467
468
469
470
471
472
473
  ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
  ret &= StructuralEqual()(this->ReplicateExtent(), other->ReplicateExtent());
  ret &= StructuralEqual()(this->ThreadExtent(), other->ThreadExtent());
  if (!skip_index) {
    ret &= StructuralEqual()(this->forward_index_, other->forward_index_);
  }
  return ret;
}

474
475
476
477
478
479
480
void FragmentNode::RegisterReflection() {
  namespace refl = tvm::ffi::reflection;
  refl::ObjectDef<FragmentNode>()
      .def_ro("forward_thread", &FragmentNode::forward_thread_)
      .def_ro("replicate_size", &FragmentNode::replicate_size_);
}

481
TVM_FFI_STATIC_INIT_BLOCK() {
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def_packed("tl.Layout",
                  [](PackedArgs args, Any *rv) {
                    *rv = Layout(args[0].cast<Array<IterVar>>(),
                                 args[1].cast<Array<PrimExpr>>());
                  })
      .def("tl.Layout_input_shape",
           [](Layout layout) { return layout->InputShape(); })
      .def("tl.Layout_output_shape",
           [](Layout layout) { return layout->OutputShape(); })
      .def("tl.Layout_inverse", [](Layout layout) { return layout->Inverse(); })
      .def("tl.Layout_index",
           [](Layout layout) { return layout->GetForwardIndex(); })
      .def("tl.Layout_forward_vars",
           [](Layout layout) { return layout->GetForwardVars(); })
498
499
500
501
502
      .def("tl.Layout_is_equal",
           [](Layout layout, Layout other) {
             const LayoutNode *other_node = other.as<LayoutNode>();
             return layout->IsEqual(other_node);
           })
503
504
505
506
507
508
509
510
      .def_packed("tl.Fragment",
                  [](PackedArgs args, Any *rv) {
                    *rv = Fragment(
                        /*forward_var=*/args[0].cast<Array<IterVar>>(),
                        /*forward_index=*/args[1].cast<Array<PrimExpr>>(),
                        /*forward_thread=*/args[2].cast<PrimExpr>(),
                        /*thread_replicate=*/args[3].cast<IterVar>());
                  })
511
512
513
514
515
      .def("tl.Fragment_is_equal",
           [](Fragment fragment, Fragment other) {
             const FragmentNode *other_node = other.as<FragmentNode>();
             return fragment->IsEqual(other_node);
           })
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
      .def("tl.Fragment_thread_size",
           [](Fragment fragment) { return fragment->ThreadExtent(); })
      .def("tl.Fragment_thread",
           [](Fragment fragment) { return fragment->GetForwardThread(); })
      .def("tl.Fragment_repeat",
           [](Fragment fragment, Array<PrimExpr> repeats, bool repeat_on_thread,
              bool lower_dim_first) {
             return fragment->Repeat(repeats, repeat_on_thread,
                                     lower_dim_first);
           })
      .def("tl.Fragment_replicate",
           [](Fragment fragment, int repeats) {
             return fragment->Replicate(repeats);
           })
      .def("tl.Fragment_condense_rep_var",
           [](Fragment fragment) { return fragment->CondenseReplicateVar(); })
      .def("tl.make_swizzled_layout",
533
534
535
536
537
538
539
540
541
542
           [](int stride, int continuous, int element_size, bool k_inner,
              bool allow_pad = true) {
             if (allow_pad) {
               return makeGemmABLayout(stride, continuous, continuous,
                                       element_size, k_inner);
             } else {
               return makeGemmABLayoutHopper(stride, continuous, continuous,
                                             element_size, k_inner);
             }
           })
543
544
545
546
547
      .def("tl.make_volta_swizzled_layout",
           [](int stride, int mat_continuous, bool is_a, bool k_inner) {
             return makeGemmVoltaABLayout(stride, mat_continuous, is_a,
                                          k_inner);
           })
548
549
550
551
552
      .def("tl.make_wgmma_swizzled_layout",
           [](int stride, int mat_continuous, int continuity, int element_size,
              bool k_inner) {
             return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
                                           element_size, k_inner);
553
554
555
556
557
558
           })
      .def("tl.make_tcgen05mma_swizzled_layout",
           [](int stride, int mat_continuous, int continuity, int element_size,
              bool k_inner) {
             return makeGemmABLayoutSm100(stride, mat_continuous, continuity,
                                          element_size, k_inner);
559
560
           })
      .def("tl.make_full_bank_swizzled_layout",
561
           [](int stride, int continuous, int element_size) {
562
563
564
565
566
567
568
569
570
571
572
573
574
575
             return makeFullBankSwizzleLayout(stride, continuous, element_size);
           })
      .def("tl.make_half_bank_swizzled_layout",
           [](int stride, int continuous, int element_size) {
             return makeHalfBankSwizzleLayout(stride, continuous, element_size);
           })
      .def("tl.make_quarter_bank_swizzled_layout",
           [](int stride, int continuous, int element_size) {
             return makeQuarterBankSwizzleLayout(stride, continuous,
                                                 element_size);
           })
      .def("tl.make_linear_layout", [](int stride, int continuous) {
        return makeGemmLayoutLinear(stride, continuous);
      });
576
}
577

578
// Registers the reflected fields of the layout node types at static
// initialisation time.
TVM_FFI_STATIC_INIT_BLOCK() {
  LayoutNode::RegisterReflection();
  FragmentNode::RegisterReflection();
}
583

584
585
} // namespace tl
} // namespace tvm