"INSTALL/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "a37c6af8d083a98e64bb6c30efccf60eb3f71a6d"
layout_inference.cc 31.7 KB
Newer Older
1
2
3
4
5
/*!
 * \file layout_inference.cc
 * \brief infer the fragment/shared memory layout
 */

#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../layout/utils.h"
#include "../op/parallel.h"
#include "../op/region.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "common/loop_parallel_transform_utils.h"
#include "common/union_find.h"
#include "layout_reducer.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Collect the thread bindings (thread_extent attributes) declared in
 *        the function body.
 */
class ThreadBindingCollector : public StmtExprVisitor {
public:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      thread_binding_[iv->var.get()] = iv;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // The thread binding map
  std::unordered_map<const VarNode *, IterVar> thread_binding_;
};

using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

struct LayoutInferenceResult {
  Map<Buffer, Layout> layout_map;
  Map<For, Fragment> for_map;
  Map<For, PrimExpr> predicate_map;
};

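/*!
 * \brief Walks a PrimFunc and collects, for every tile operator and parallel
 *        loop, the inputs needed for layout inference: the parsed operator,
 *        the buffers it touches, and the thread binding/bounds in scope.
 */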
class BufferUseDefCollector : public IRVisitorWithAnalyzer {
public:
  BufferUseDefCollector(bool skip_thread_partition)
      : skip_thread_partition_(skip_thread_partition) {}

  /**
   * @brief Execute a single layout-inference step for the infer node at the
   * given index.
   *
   * Runs InferLayout on the TileOperator at cur_infer_id with the provided
   * InferLevel and thread bounds, applies returned buffer->layout updates into
   * layout_map (respecting strict_layout_map constraints for fragment buffers),
   * and optionally propagates changes to dependent infer nodes by enqueueing
   * them into q and marking in_queue.
   *
   * The function mutates layout_map and, when update_queue is true, may modify
   * q and in_queue. It performs internal sanity checks via ICHECK and will
   * LOG(WARNING) for buffers that cannot be propagated; ICHECK failures abort
   * execution.
   *
   * @param cur_infer_id Index of the infer operator in infer_list_ to run (must
   * be within range).
   * @param level Inference relaxation level to pass to the operator's
   * InferLayout.
   * @param update_queue If true, discovered layout changes will enqueue
   * dependent infer nodes.
   * @param layout_map Mutable map of inferred layouts that will be updated with
   * returned layouts.
   * @param strict_layout_map Read-only map of layouts produced in the strict
   * phase; used to enforce containment checks for local.fragment buffers when
   * relaxing.
   * @param q BFS queue used to propagate dependent inference indices; new
   * indices may be pushed.
   * @param in_queue Parallel boolean vector tracking queued status; entries
   * corresponding to enqueued indices will be set to true.
   */
  void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
                    LayoutMap &layout_map, const LayoutMap &strict_layout_map,
                    std::queue<int> &q, std::vector<bool> &in_queue) {
    auto num_infer = infer_list_.size();

    // Range check for cur_infer_id
    ICHECK_GE(cur_infer_id, 0) << "cur_infer_id is negative, which is invalid.";
    ICHECK_LT(cur_infer_id, num_infer)
        << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
        << num_infer << ".";

    // Make sure we can safely access infer_list_[cur_infer_id] and
    // thread_var_vec_[cur_infer_id]
    auto &next = infer_list_[cur_infer_id];
    auto iter_var = thread_var_vec_[cur_infer_id];
    auto thread_bounds = thread_bounds_vec_[cur_infer_id];
    // Double-check that 'next' is valid
    ICHECK(next.defined()) << "infer_list_[" << cur_infer_id
                           << "] is null inside run_infer_step.";

    // Check iter_var->dom and dom->extent
    ICHECK(iter_var.defined())
        << "thread_var_vec_[" << cur_infer_id << "] is not defined.";
    ICHECK(iter_var->dom.defined())
        << "iter_var->dom is not defined for infer_list_[" << cur_infer_id
        << "].";
    ICHECK(iter_var->dom->extent.defined())
        << "iter_var->dom->extent is not defined for infer_list_["
        << cur_infer_id << "].";

    const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
    ICHECK(extent_ptr != nullptr)
        << "iter_var->dom->extent is not a constant integer, which is "
           "required for layout inference.";

    // Run InferLayout
    auto updates = next->InferLayout(
        LayoutInferArgs{target_, thread_bounds, layout_map}, level);

    // Process the returned updates
    for (const auto &[buffer, layout] : updates) {
      // Basic validity checks
      ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
      ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";

      if (layout_map.count(buffer)) {
        // If new layout contains the old one, update map
        if (buffer.scope() == "local.fragment" &&
            level != InferLevel::kStrict && !strict_layout_map.count(buffer)) {
          // Actually this test has been done in ParallelOp::InferLayout
          // already. Just do it again to avoid missing implementations in other
          // `TileOperator`s.
          auto dst_layout = layout.as<Fragment>().value();
          auto src_layout = layout_map[buffer].as<Fragment>().value();
          ICHECK(dst_layout->InputDim() == src_layout->InputDim());
          Array<PrimExpr> indices;
          indices.reserve(dst_layout->InputDim());
          arith::Analyzer inner_analyzer;
          for (int i = 0; i < dst_layout->InputDim(); ++i) {
            auto x = InputPlaceholder(i);
            indices.push_back(x);
            // should be literal - literal = 0, any analyzer will work
            ICHECK(is_zero(inner_analyzer.Simplify(
                dst_layout->InputShape()[i] - src_layout->InputShape()[i])));
            inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
          }
          if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
                                    inner_analyzer)) {
            layout_map.Set(buffer, layout);
            continue;
          }
        }
        // If already in map, ensure they are structurally equal
        ICHECK(StructuralEqual()(layout, layout_map[buffer]))
            << "Get different layout for " << buffer
            << "\n current layout: " << layout->DebugOutput()
            << "\n previous layout: " << layout_map[buffer]->DebugOutput();
      } else {
        // Otherwise, update map
        layout_map.Set(buffer, layout);
        if (!update_queue)
          continue;

        // Check if buffer exists in use_list_
        if (!use_list_.count(buffer)) {
          LOG(WARNING) << "Layout inference failed for buffer " << buffer
                       << ". "
                       << "The buffer cannot be inferred with current layout "
                          "inference rules.";
          continue;
        }

        // Push back into BFS queue
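        // Example (hypothetical indices): if this step just fixed a layout for
        // fragment buffer B and use_list_[B] == {2, 5}, operators 2 and 5 are
        // enqueued below so they can re-run InferLayout under the new
        // constraint.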
        for (int idx : use_list_[buffer]) {
          ICHECK_GE(idx, 0)
              << "Index in use_list_ for buffer " << buffer << " is negative.";
          ICHECK_LT(idx, num_infer)
              << "Index in use_list_ for buffer " << buffer
              << " out of range: " << idx << " >= " << num_infer << ".";

          if (!in_queue[idx] && idx != cur_infer_id) {
            in_queue[idx] = true;
            q.push(idx);
          }
        }
      }
    }
  };

  void FinishInferQueue(InferLevel level, LayoutMap &layout_map,
                        const LayoutMap &strict_layout_map, std::queue<int> &q,
                        std::vector<bool> &in_queue) {
    auto num_infer = infer_list_.size();
    while (!q.empty()) {
      int cur_infer_id = q.front();
      q.pop();
      // Range check again, just to be safe
      ICHECK_GE(cur_infer_id, 0);
      ICHECK_LT(cur_infer_id, num_infer);

      in_queue[cur_infer_id] = false;
      RunInferStep(cur_infer_id, level, true, layout_map, strict_layout_map, q,
                   in_queue);
    }
  };

  /**
   * @brief Run the multi-stage layout inference and return the collected
   * results.
   *
   * Performs layout inference over the collected TileOperator entries in three
   * phases: (1) strict per-operator inference, (2) common inference via a BFS
   * propagation queue, and (3) a free-mode relaxation phase that explores
   * alternative root orderings within connected components to reduce register
   * footprint. After inference completes, verifies that all local.fragment
   * buffers have inferred layouts and collects loop (For) -> Fragment layouts
   * and any per-loop predicates discovered during inference.
   *
   * The method consumes/permutes internal inference state (notably moves
   * entries out of infer_list_) and returns a LayoutInferenceResult containing:
   * - layout_map: inferred Layout for each Buffer,
   * - for_map: mapping from For nodes to their inferred Fragment layout,
   * - predicate_map: optional loop predicates keyed by For nodes.
   *
   * The function performs internal consistency checks (ICHECK) on sizes and
   * required definitions; violations will terminate via ICHECK failure.
   *
   * @return LayoutInferenceResult A tuple-like struct with the inferred
   *         layout_map, for_map, and predicate_map.
   */
  LayoutInferenceResult Run() {
    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
    // same size
    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
           "length.";
    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
           "length.";

    // Additional sanity checks (e.g. that annotated_layout_map_ is non-empty)
    // could be added here if needed.

    // Copy the annotated layout map to local variable
    Map<Buffer, Layout> layout_map = annotated_layout_map_;
    Map<Buffer, Layout> strict_layout_map;
    int num_infer = infer_list_.size();

    // Prepare BFS queue for iterative inference
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
    for (int i = 0; i < num_infer; i++) {
      // Check that each infer_list_ entry is valid
      ICHECK(infer_list_[i].defined())
          << "infer_list_[" << i
          << "] is null. The inference object is not allocated properly.";

      // Check that each thread_var_vec_ entry is defined
      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
        thread_var_vec_[i] = thread_var_;
      }
      q.push(i);
    }

    // step 1: infer strict layout
    for (int i = 0; i < num_infer; i++) {
      RunInferStep(i, InferLevel::kStrict, false, layout_map, strict_layout_map,
                   q, in_queue);
    }

    for (const auto &[buffer, layout] : layout_map) {
      strict_layout_map.Set(buffer, layout);
    }
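    // Snapshot the layouts fixed by the strict phase: buffers recorded here
    // must keep exactly this layout, while other fragment buffers may later be
    // relaxed to a containing layout (see RunInferStep).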

    // step 2: infer common layout with BFS
    FinishInferQueue(InferLevel::kCommon, layout_map, strict_layout_map, q,
                     in_queue);

    // step 3: relax constraints to free and re-run
    InferInFreeMode(layout_map, strict_layout_map);

    // Check that all local.fragment buffers have inferred layouts
    for (const auto &[buffer, _] : use_list_) {
      if (buffer.scope() == "local.fragment") {
        ICHECK_NE(layout_map.count(buffer), 0)
            << "The layout for fragment " << buffer
            << " can not be inferred correctly.";
      }
    }

    // Collect layout info for For nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    ICHECK(infer_list_.size() == thread_var_vec_.size())
        << "infer_list_ and thread_var_vec_ size mismatch";
    for (int i = 0; i < infer_list_.size(); i++) {
      TileOperator base_infer = std::move(infer_list_[i]);
      auto thread_var = thread_var_vec_[i];

      // Check if base_infer is valid
      ICHECK(base_infer.defined()) << "Null pointer encountered in "
                                      "infer_list_ while collecting for_map.";
      if (auto for_infer = base_infer.as<ParallelOpNode>()) {
        // Check that the loop layout is defined
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The layout for the parallel loop cannot be inferred correctly:\n"
            << for_infer->GetRoot();
        for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
        // thread_var must be defined before we can query the loop predicate
        ICHECK(thread_var.defined())
            << "thread_var is not defined. Cannot retrieve predicate.";

        if (auto predicate = for_infer->GetPredicate(thread_var->var)) {
          predicate_map.Set(for_infer->GetRoot(), predicate.value());
        }
      }
    }

    return {layout_map, for_map, predicate_map};
  }

  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "Layout_Inference: Require the target attribute";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
  /**
   * @brief Visits a Call expression to collect tile-operator-based inference
   * inputs.
   *
   * Processes non-global function calls by parsing them into a TileOperator
   * (via ParseOperator). If the parse succeeds, records:
   * - buffers referenced by call arguments into the collector's use lists,
   * - the call AST node into infer_list_stmt_,
   * - the parsed TileOperator into infer_list_,
   * - the current thread IterVar into thread_var_vec_,
   * - the thread iteration bounds into thread_bounds_vec_ (uses analyzer const
   * bounds when available; otherwise [0,1]).
   *
   * Calls to global functions (where op->op is a GlobalVar) are ignored.
   *
   * @param op The Call node being visited.
   */
  void VisitExpr_(const CallNode *op) final {
    IRVisitorWithAnalyzer::VisitExpr_(op);
    // Do not analyze calls to global functions.
    if (op->op.as<GlobalVarNode>())
      return;

    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p.defined()) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
      }
      infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
      infer_list_.push_back(std::move(p));
      thread_var_vec_.push_back(thread_var_);
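      // Derive the thread bounds visible at this call site from the analyzer's
      // constant bounds on thread_var_ (e.g. threadIdx.x in [0, 128)); fall
      // back to the single-thread range [0, 1) when no bound is known.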
      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto min_value = const_int_bound->min_value;
        auto max_value = const_int_bound->max_value;
        auto extent = max_value - min_value + 1;
        auto dtype = thread_var_->var.dtype();
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    }
  }

  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    if (!call) {
      return std::nullopt;
    }
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      auto var = call->args[1].as<Var>().value();
      return buffer_data_to_buffer_[var];
    } else if (call->op.same_as(RegionOp::Get())) {
      return call->args[0].as<BufferLoadNode>()->buffer;
    }
    return std::nullopt;
  }

  void addToUseList(const Buffer &buffer) {
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
    }
    use_list_[buffer].push_back(infer_idx);
  }

  /**
   * @brief Handles For nodes during IR traversal.
   *
   * When the loop is a parallel loop (ForKind::kParallel), records it as a
   * ParallelOp:
   * - constructs a ParallelOp for the loop and appends it to the internal infer
   * lists (infer_list_ and infer_list_stmt_),
   * - registers all buffers referenced by the loop indices with use-list
   * bookkeeping,
   * - captures the current thread IterVar context and its compile-time extent
   * (if available) into thread_var_vec_ and thread_bounds_vec_ (falls back to
   * range [0,1] when unknown).
   *
   * For non-parallel loops, continues recursive traversal into the loop body.
   *
   * Side effects:
   * - Mutates infer_list_, infer_list_stmt_, use_list_ (via addToUseList),
   * thread_var_vec_, and thread_bounds_vec_.
   */
  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {
      auto infer = ParallelOp(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
      infer_list_.push_back(std::move(infer));
      thread_var_vec_.push_back(thread_var_);
      if (thread_var_.defined() &&
          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto dtype = thread_var_->var.dtype();
        auto extent =
            const_int_bound->max_value - const_int_bound->min_value + 1;
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    } else {
      IRVisitorWithAnalyzer::VisitStmt(op->body);
    }
  }

  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
      // The layout map annotation is expected to be a Map<Var, Layout>.
      auto map =
          op->annotations.Get(attr::kLayoutMap)->as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
        ICHECK(buffer_data_to_buffer_.count(var))
            << "buffer " << var << " is not found in the block";
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
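      // Only the threadIdx.x binding is tracked; thread bounds used during
      // inference are derived from this IterVar.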
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<ObjectRef> infer_list_stmt_;
  std::vector<TileOperator> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  // This is a workaround for the CPU backend:
  // we need to define a thread_var for the serial loop.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  std::vector<IterVar> thread_var_vec_;
  std::vector<Range> thread_bounds_vec_;
  Target target_;
  LayoutMap annotated_layout_map_;
  bool skip_thread_partition_{false};

  /**
   * @brief Create a deep copy of the current inference operator list.
   *
   * Returns a vector containing clones of each TileOperator in the collector's
   * internal infer_list_. The returned list is independent of the original so
   * subsequent modifications to either do not affect the other.
   *
   * @return std::vector<TileOperator> Cloned copy of infer_list_.
   */
  std::vector<TileOperator> BackupInferList() {
    std::vector<TileOperator> back_infer_list;
    back_infer_list.reserve(infer_list_.size());
    for (auto &&p : infer_list_) {
      back_infer_list.push_back(p->Clone());
    }
    return back_infer_list;
  }

  /**
   * @brief Explore alternative inference orders within connected components to
   * relax layouts.
   *
   * This function performs a "free-mode" exploration that attempts different
   * root operators within each connected component of the operator-use graph in
   * order to find a layout assignment with lower register (fragment) usage.
   *
   * Detailed behavior:
   * - Builds connected components of infer_list_ by unioning operators that
   * share buffer uses (use_list_).
   * - For each component, iterates each member operator as a candidate root:
   *   - Backs up the current infer_list_ and uses a temporary copy of
   * layout_map.
   *   - Runs RunInferStep and FinishInferQueue in InferLevel::kFree starting
   * from the candidate root and then (as a fallback) runs the remaining members
   *     to try to cover the whole component.
   *   - If inference succeeds, computes a coarse register usage metric by
   *     summing the product of OutputShape dimensions for all Fragment layouts
   * in the temporary layout map.
   *   - Tracks the candidate that yields the smallest register usage.
   * - If a better plan is found for a component, replaces the global
   * infer_list_ and updates layout_map with the best layout_map found.
   *
   * Side effects:
   * - Mutates layout_map to the best-found free-mode layout assignment when a
   *   better plan is discovered.
   * - Mutates the member infer_list_ (backed up and restored during attempts;
   *   finally set to the best plan if found).
   *
   * Notes:
   * - LayoutConflictException and NormalizeIterException raised during attempts
   *   are caught and treated as failed attempts; they do not propagate out of
   *   this function.
   * - The register-usage metric is a heuristic (sum of fragment output element
   *   counts) used to prefer less-replicated layouts.
   *
   * @param layout_map[in,out] The current global layout map to be updated with
   * a better free-mode result if found.
   * @param strict_layout_map Read-only map of layouts inferred in strict mode,
   *                          used to constrain free-mode inference.
   */
  void InferInFreeMode(LayoutMap &layout_map,
                       const LayoutMap &strict_layout_map) {
    // Group operators into connected components
    UnionFind<int> uf;
    for (int i = 0; i < infer_list_.size(); i++) {
      uf.MakeSet(i);
    }
    for (const auto &[buffer, infer_indices] : use_list_) {
      if (infer_indices.empty())
        continue;

      // Union all infer_list_ indices that share the same buffer
      int first_idx = infer_indices[0];
      for (size_t i = 1; i < infer_indices.size(); i++) {
        uf.Union(first_idx, infer_indices[i]);
      }
    }
    std::unordered_map<int, std::vector<int>> components;
    for (int i = 0; i < infer_list_.size(); i++) {
      int root = uf.Find(i);
      components[root].push_back(i);
    }
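    // Hypothetical example: if operators 0 and 2 share fragment A while
    // operators 2 and 3 share fragment B, then {0, 2, 3} end up in one
    // component and are explored together below.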
    // Create a map from root to buffers
    std::unordered_map<int, std::vector<Buffer>> components_buffers;
    for (const auto &[buffer, infer_indices] : use_list_) {
      int root = uf.Find(infer_indices[0]);
      components_buffers[root].push_back(buffer);
    }
    // Keep components_buffers for debugging purposes
    (void)components_buffers;

    // For each component, try each op as root, and determine the least
    // replicated one
    std::queue<int> q;
    std::vector<bool> in_queue(infer_list_.size(), false);

    for (auto &&[root, members] : components) {
      decltype(infer_list_) best_infer_list;
      LayoutMap best_layout_map;
      int64_t min_reg_num = INT64_MAX;

      for (int attempt_infer_root : members) {
        // back up the class member infer_list_
        auto back_infer_list = BackupInferList();
        // create a temporary layout_map; copying the handle gives
        // copy-on-write semantics, so the original layout_map stays untouched
        LayoutMap tmp_layout_map = layout_map;
        // infer from attempt_infer_root in free mode
        bool do_update = true;
        try {
          RunInferStep(attempt_infer_root, InferLevel::kFree, true,
                       tmp_layout_map, strict_layout_map, q, in_queue);
          FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
                           q, in_queue);
          // Silly workaround: we have no clue whether a single root will reach
          // the entire component, since the InferLayout implementations have
          // complicated conditioning inside and we know nothing about it.
          // This would constantly result in incomplete layouts for buffers in
          // this component. Instead of trying all combinations of root
          // selection orders, we simply go through all other members in order
          // after the first search from attempt_infer_root.
          for (int other_infer_root : members) {
            if (other_infer_root != attempt_infer_root) {
              RunInferStep(other_infer_root, InferLevel::kFree, true,
                           tmp_layout_map, strict_layout_map, q, in_queue);
              // must also be kFree here to avoid conflicts.
              FinishInferQueue(InferLevel::kFree, tmp_layout_map,
                               strict_layout_map, q, in_queue);
            }
          }
        } catch (const LayoutConflictException &e) {
          // such an order fails, try others
          do_update = false;
        } catch (const NormalizeIterException &e) {
          // such an order encounters iterators that are not normalizable
          // (e.g. i * 576 % 2048), try others
          do_update = false;
        }

        if (do_update) {
          // compute total register number
          int64_t reg_num = 0;
          for (auto &&[buffer, layout] : tmp_layout_map) {
            if (auto frag = layout.as<Fragment>()) {
              int64_t frag_reg_num = 1;
              for (auto i : frag.value()->OutputShape()) {
                auto pci = as_const_int(i);
                ICHECK(pci != nullptr);
                frag_reg_num *= *pci;
              }
              reg_num += frag_reg_num;
            }
          }
          // if it's any better, update the best_* storage
          if (reg_num < min_reg_num) {
            best_infer_list = std::move(infer_list_);
            best_layout_map = tmp_layout_map;
            min_reg_num = reg_num;
          }
        }
        // restore the stateful infer_list_ before trying the next root
        infer_list_ = std::move(back_infer_list);
      }
      if (min_reg_num < INT64_MAX) {
        // now apply the best plan for this component
        infer_list_ = std::move(best_infer_list);
        layout_map = best_layout_map;
      }
    }
  }
};

class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector(skip_thread_partition);
    collector.Collect(f);
    auto result = collector.Run();
    LayoutInferencer substituter(result, skip_thread_partition, &analyzer);
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult result,
                   bool skip_thread_partition, arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
        skip_thread_partition_(skip_thread_partition) {}

  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
      if (buffer.scope() == "local.framgent") {
        ICHECK(result_.layout_map.count(buffer))
            << "Cannot inference fragment layout for " << buffer;
      }
    }
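    // Record the final buffer -> layout decisions on the block so that later
    // lowering passes can look them up via the kLayoutMap annotation.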
    auto block_ptr = block.CopyOnWrite();
    block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
    return block;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    Map<Var, ReducerInfo> reducer_info;
    if (op->annotations.count(attr::kReducerInfo))
      reducer_info = op->annotations.Get(attr::kReducerInfo)
                         ->as<Map<Var, ReducerInfo>>()
                         .value();

    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto root = GetRef<For>(op);
      // This check is a workaround to support T.Parallel for local buffers.
      // For example:
      //   for i in T.Parallel(1024):
      //     A_local[i] = A_global[i]
      // Here, A_local is a register-local buffer held independently by each
      // thread, so explicit thread binding is not required.
      //
      // We use PostOrderVisit to detect whether the buffer store targets a
      // "local" buffer, which indicates register usage and justifies skipping
      // thread binding.
      bool is_register_store = false;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() == "local") {
            is_register_store = true;
          }
        }
      });

      auto loop_layout = result_.for_map[root];
      bool parallel_loop = !is_register_store && !skip_thread_partition_;

      if (parallel_loop) {
        for_node =
            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      }
      // Decide whether to vectorize: check whether the loop body touches any
      // buffer outside the local / local.fragment scopes.
      bool has_non_local = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (const auto *load = obj.as<BufferLoadNode>()) {
          String scope = load->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        } else if (const auto *store = obj.as<BufferStoreNode>()) {
          String scope = store->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        }
      });
      // Workaround: if a reducer is present, don't vectorize the loop.
      // A better solution would isolate the reduction axis from vectorization.
      bool has_reducer = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (!has_reducer)
          if (const auto *store = obj.as<BufferStoreNode>()) {
            has_reducer = reducer_info.count(store->buffer->data) != 0;
          }
      });

      if (has_non_local && !has_reducer) {
        for_node = VectorizeLoop(for_node);
      }

      if (result_.predicate_map.count(root) && parallel_loop) {
        return IfThenElse(result_.predicate_map[root], for_node);
      } else {
        return for_node;
      }
    }
    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  bool skip_thread_partition_{false};
};

tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
    ThreadBindingCollector collector;
    collector(f->body);
    bool has_thread_binding = collector.thread_binding_.size() > 0;
    bool skip_thread_partition = !has_thread_binding;
    return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}
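// Usage sketch (illustrative only): besides the FFI registration below, the
// pass can be composed directly in C++, e.g.
//   IRModule mod = ...;
//   mod = tvm::transform::Sequential({tvm::tl::LayoutInference()})(mod);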

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference);
});

} // namespace tl
} // namespace tvm