/*!
 * \file layout_inference.cc
 * \brief infer the fragment/shared memory layout
 */

#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../layout/utils.h"
#include "../op/copy.h"
#include "../op/parallel.h"
#include "../op/region.h"

#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "common/loop_parallel_transform_utils.h"
#include "common/union_find.h"
#include "layout_reducer.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Collect the thread bindings (e.g. threadIdx.x) introduced by
 *        AttrStmt nodes with attr::thread_extent.
 */
class ThreadBindingCollector : public StmtExprVisitor {
public:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      thread_binding_[iv->var.get()] = iv;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // The thread binding map
  std::unordered_map<const VarNode *, IterVar> thread_binding_;
};

using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

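/*!
 * \brief Aggregated result of layout inference: the buffer-to-layout mapping,
 *        the loop layout chosen for each parallel For, and the optional
 *        thread predicate guarding each partitioned loop.
 */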
struct LayoutInferenceResult {
  Map<Buffer, Layout> layout_map;
  Map<For, Fragment> for_map;
  Map<For, PrimExpr> predicate_map;
};

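/*!
 * \brief Walk a PrimFunc, collect every tile operator (copies, parallel loops,
 *        etc.) together with the buffers it uses, and drive the
 *        strict -> common -> free layout-inference phases over that use-def
 *        graph in Run().
 */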
class BufferUseDefCollector : public IRVisitorWithAnalyzer {
public:
  BufferUseDefCollector(bool skip_thread_partition)
      : skip_thread_partition_(skip_thread_partition) {}

  using arith::IRVisitorWithAnalyzer::IRVisitorWithAnalyzer;

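  /*!
   * \brief Run layout inference for the cur_infer_id-th operator at the given
   *        level. Returned buffer-to-layout updates are merged into
   *        layout_map; when update_queue is set, every other operator that
   *        uses a buffer whose layout was newly fixed is re-enqueued so that
   *        FinishInferQueue can propagate the constraint.
   */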
  void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
                    LayoutMap &layout_map, const LayoutMap &strict_layout_map,
                    std::queue<int> &q, std::vector<bool> &in_queue) {
    auto num_infer = infer_list_.size();

    // Range check for cur_infer_id
    ICHECK_GE(cur_infer_id, 0) << "cur_infer_id is negative, which is invalid.";
    ICHECK_LT(cur_infer_id, num_infer)
        << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
        << num_infer << ".";

    // Make sure we can safely access infer_list_[cur_infer_id] and
    // thread_var_vec_[cur_infer_id]
    auto &next = infer_list_[cur_infer_id];
    auto iter_var = thread_var_vec_[cur_infer_id];
    auto thread_bounds = thread_bounds_vec_[cur_infer_id];
    auto buffer_oob = buffer_oob_vec_[cur_infer_id];
    // Double-check that 'next' is valid
    ICHECK(next.defined()) << "infer_list_[" << cur_infer_id
                           << "] is null inside run_infer_step.";

    // Check iter_var->dom and dom->extent
    ICHECK(iter_var.defined())
        << "thread_var_vec_[" << cur_infer_id << "] is not defined.";
    ICHECK(iter_var->dom.defined())
        << "iter_var->dom is not defined for infer_list_[" << cur_infer_id
        << "].";
    ICHECK(iter_var->dom->extent.defined())
        << "iter_var->dom->extent is not defined for infer_list_["
        << cur_infer_id << "].";

    const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
    ICHECK(extent_ptr != nullptr)
        << "iter_var->dom->extent is not a constant integer, which is "
           "required for layout inference.";

    // Run InferLayout
    DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
    auto updates =
        next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
                                          &analyzer_, buffer_oob},
                          level);
    // Process the returned updates
    for (const auto &[buffer, layout] : updates) {
      DLOG(INFO) << "    consider update " << buffer << " as "
                 << layout->DebugOutput() << '\n';

      // Basic validity checks
      ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
      ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";

      if (layout_map.count(buffer)) {
        // If new layout contains the old one, update map
        if (buffer.scope() == "local.fragment" &&
            level != InferLevel::kStrict && !strict_layout_map.count(buffer)) {
          // This check is already performed in ParallelOp::InferLayout;
          // repeat it here to guard against missing implementations in other
          // `TileOperator`s.

          auto dst_layout_opt = layout.as<Fragment>();
          ICHECK(dst_layout_opt.has_value())
              << "Failed to cast layout to Fragment for buffer " << buffer
              << ", layout type is " << layout->GetTypeKey();
          const auto &dst_layout = dst_layout_opt.value();
          auto src_layout_opt = layout_map[buffer].as<Fragment>();
          ICHECK(src_layout_opt.has_value())
              << "Failed to cast layout_map[buffer] to Fragment for buffer "
              << buffer << ", layout type is "
              << layout_map[buffer]->GetTypeKey();
          const auto &src_layout = src_layout_opt.value();
          ICHECK(dst_layout->InputDim() == src_layout->InputDim());
          Array<PrimExpr> indices;
          indices.reserve(dst_layout->InputDim());
          arith::Analyzer inner_analyzer;
          for (int i = 0; i < dst_layout->InputDim(); ++i) {
            auto x = InputPlaceholder(i);
            indices.push_back(x);
            // Both shapes should be compile-time constants, so the difference
            // simplifies to zero with any analyzer.
            ICHECK(is_zero(inner_analyzer.Simplify(
                dst_layout->InputShape()[i] - src_layout->InputShape()[i])));
            inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
          }
          if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
                                    inner_analyzer)) {
            layout_map.Set(buffer, layout);
            DLOG(INFO) << "    layout broadcast from "
                       << src_layout->DebugOutput() << ", accepted" << '\n';
            continue;
          }
        }
        // If already in map, ensure they are structurally equal
        ICHECK(layout->IsEqual(layout_map[buffer].get()))
            << "Get different layout for " << buffer
            << "\n current layout: " << layout->DebugOutput()
            << "\n previous layout: " << layout_map[buffer]->DebugOutput();
      } else {
        // Otherwise, update map
        layout_map.Set(buffer, layout);
        DLOG(INFO) << "    new layout accepted" << '\n';
        if (!update_queue)
          continue;

        // Check if buffer exists in use_list_
        if (!use_list_.count(buffer)) {
          LOG(WARNING) << "Layout inference failed for buffer " << buffer
                       << ". "
                       << "The buffer cannot be inferred with current layout "
                          "inference rules.";
          continue;
        }

        // Push back into BFS queue
        for (int idx : use_list_[buffer]) {
          ICHECK_GE(idx, 0)
              << "Index in use_list_ for buffer " << buffer << " is negative.";
          ICHECK_LT(idx, num_infer)
              << "Index in use_list_ for buffer " << buffer
              << " out of range: " << idx << " >= " << num_infer << ".";

          if (!in_queue[idx] && idx != cur_infer_id) {
            in_queue[idx] = true;
            q.push(idx);
          }
        }
      }
    }
  };

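  /*!
   * \brief Drain the BFS queue, re-running RunInferStep until no operator is
   *        left whose buffers gained a new layout, i.e. until the current
   *        inference level reaches a fixed point.
   */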
  void FinishInferQueue(InferLevel level, LayoutMap &layout_map,
                        const LayoutMap &strict_layout_map, std::queue<int> &q,
                        std::vector<bool> &in_queue) {
    auto num_infer = infer_list_.size();
    while (!q.empty()) {
      int cur_infer_id = q.front();
      q.pop();
      // Range check again, just to be safe
      ICHECK_GE(cur_infer_id, 0);
      ICHECK_LT(cur_infer_id, num_infer);

      in_queue[cur_infer_id] = false;
      RunInferStep(cur_infer_id, level, true, layout_map, strict_layout_map, q,
                   in_queue);
    }
  };

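  /*!
   * \brief Drive the whole inference: seed with annotated layouts, run a
   *        strict pass over every operator, propagate common-level constraints
   *        via BFS, relax to free mode per connected component, and finally
   *        collect the loop layouts and thread predicates for each For.
   */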
  LayoutInferenceResult Run() {
    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
    // same size
    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
           "length.";
    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
           "length.";
    ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size())
        << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
           "length.";

    DLOG(INFO) << "[InferLayout] all participating operators:" << '\n';
    for (int i = 0; i < infer_list_stmt_.size(); ++i) {
      DLOG(INFO) << "    op " << i << ":" << infer_list_stmt_[i] << '\n';
    }

    // Seed the working layout map with the user-annotated layouts
    Map<Buffer, Layout> layout_map = annotated_layout_map_;
    Map<Buffer, Layout> strict_layout_map;
    int num_infer = infer_list_.size();

    // Prepare BFS queue for iterative inference
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
    for (int i = 0; i < num_infer; i++) {
      // Check that each infer_list_ entry is valid
      ICHECK(infer_list_[i].defined())
          << "infer_list_[" << i
          << "] is null. The inference object is not allocated properly.";

      // Fall back to the default thread_var_ when no thread binding was
      // recorded for this entry
      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
        thread_var_vec_[i] = thread_var_;
      }
      q.push(i);
    }

    // step 1: infer strict layout
    for (int i = 0; i < num_infer; i++) {
      RunInferStep(i, InferLevel::kStrict, false, layout_map, strict_layout_map,
                   q, in_queue);
    }

    for (const auto &[buffer, layout] : layout_map) {
      strict_layout_map.Set(buffer, layout);
    }

    // step 2: infer common layout with BFS
    FinishInferQueue(InferLevel::kCommon, layout_map, strict_layout_map, q,
                     in_queue);

    // step 3: relax constraints to free and re-run
    InferInFreeMode(layout_map, strict_layout_map);

    // Check that all local.fragment buffers have inferred layouts
    for (const auto &[buffer, _] : use_list_) {
      if (buffer.scope() == "local.fragment") {
        ICHECK_NE(layout_map.count(buffer), 0)
            << "The layout for fragment " << buffer
            << " can not be inferred correctly.";
      }
    }

    // Collect layout info for For nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    ICHECK(infer_list_.size() == thread_var_vec_.size())
        << "infer_list_ and thread_var_vec_ size mismatch";
    for (int i = 0; i < infer_list_.size(); i++) {
      TileOperator base_infer = std::move(infer_list_[i]);
      auto thread_var = thread_var_vec_[i];

      // Check if base_infer is valid
      ICHECK(base_infer.defined()) << "Null pointer encountered in "
                                      "infer_list_ while collecting for_map.";
      if (auto for_infer = base_infer.as<ParallelOpNode>()) {
        // Check that the loop layout is defined
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The layout for this parallel loop cannot be inferred:\n"
            << for_infer->GetRoot();
        for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
        // thread_var_ should be defined if we rely on it
        ICHECK(thread_var.defined())
            << "thread_var is not defined. Cannot retrieve predicate.";

        if (auto predicate = for_infer->GetPredicate(thread_var->var)) {
          predicate_map.Set(for_infer->GetRoot(), predicate.value());
        }
      }
    }

    return {layout_map, for_map, predicate_map};
  }

  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "Layout_Inference: Require the target attribute";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
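  // Each call that parses into a TileOperator becomes one inference node:
  // record the buffers it accesses, the thread bounds under which it runs,
  // and (for copies) whether any access may be out of bounds.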
  void VisitExpr_(const CallNode *op) final {
    IRVisitorWithAnalyzer::VisitExpr_(op);
    // Do not analyze calls to global functions (sub-function calls).
    if (op->op.as<GlobalVarNode>())
      return;

    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p.defined()) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
      }
      // Compute thread_var_ and thread_bounds_
      thread_var_vec_.push_back(thread_var_);
      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto min_value = const_int_bound->min_value;
        auto max_value = const_int_bound->max_value;
        auto extent = max_value - min_value + 1;
        auto dtype = thread_var_->var.dtype();
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }

      // Compute buffer oob for each buffer in the op
      if (const auto *copy = p.as<CopyNode>()) {
        auto src_tensor = copy->src;
        auto dst_tensor = copy->dst;
        auto src_range = copy->src_range;
        auto dst_range = copy->dst_range;
        bool src_oob = false;
        bool dst_oob = false;
        for (size_t i = 0; i < src_range.size(); i++) {
          if (!analyzer_.CanProve(src_range[i]->min + src_range[i]->extent <=
                                      src_tensor->shape[i],
                                  arith::ProofStrength::kSymbolicBound)) {
            src_oob = true;
            break;
          }
        }
        for (size_t i = 0; i < dst_range.size(); i++) {
          if (!analyzer_.CanProve(dst_range[i]->min + dst_range[i]->extent <=
                                      dst_tensor->shape[i],
                                  arith::ProofStrength::kSymbolicBound)) {
            dst_oob = true;
            break;
          }
        }
        buffer_oob_vec_.push_back(src_oob || dst_oob);
      } else {
        buffer_oob_vec_.push_back(false);
      }

      // Add the tile operator to infer_list_
      infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
      infer_list_.push_back(std::move(p));
    }
  }

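  // Resolve the Buffer referenced by an operator argument, which is either a
  // tvm_access_ptr(...) intrinsic or a RegionOp call wrapping a BufferLoad.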
  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    if (!call) {
      return std::nullopt;
    }
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      auto var_opt = call->args[1].as<Var>();
      if (!var_opt.has_value()) {
        DLOG(WARNING) << "[getBufferFromAccessPtr] args[1] is not a Var, type: "
                      << call->args[1]->GetTypeKey();
        return std::nullopt;
      }
      const auto &var = var_opt.value();
      return buffer_data_to_buffer_[var];
    } else if (call->op.same_as(RegionOp::Get())) {
      return call->args[0].as<BufferLoadNode>()->buffer;
    }
    return std::nullopt;
  }

  void addToUseList(const Buffer &buffer) {
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
    }
    use_list_[buffer].push_back(infer_idx);
  }

  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {
      auto infer = ParallelOp(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
      infer_list_.push_back(std::move(infer));
      thread_var_vec_.push_back(thread_var_);
      if (thread_var_.defined() &&
          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto dtype = thread_var_->var.dtype();
        auto extent =
            const_int_bound->max_value - const_int_bound->min_value + 1;
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
      buffer_oob_vec_.push_back(false);
    } else {
      IRVisitorWithAnalyzer::VisitStmt(op->body);
    }
  }

  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
      // Check if the layout map is Map<Var, Layout>
      auto map =
          op->annotations.Get(attr::kLayoutMap)->as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
        ICHECK(buffer_data_to_buffer_.count(var))
            << "buffer " << var << " is not found in the block";
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<ObjectRef> infer_list_stmt_;
  std::vector<TileOperator> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  // This is a workaround for cpu backend,
  // we need to define a thread_var for the serial loop.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  std::vector<IterVar> thread_var_vec_;
  std::vector<Range> thread_bounds_vec_;
  std::vector<bool> buffer_oob_vec_;
  Target target_;
  LayoutMap annotated_layout_map_;
  bool skip_thread_partition_{false};

  std::vector<TileOperator> BackupInferList() {
    std::vector<TileOperator> back_infer_list;
    back_infer_list.reserve(infer_list_.size());
    for (auto &&p : infer_list_) {
      back_infer_list.push_back(p->Clone());
    }
    return back_infer_list;
  }

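  /*!
   * \brief Free-mode inference: group operators into connected components with
   *        a union-find over shared buffers, then try every member of a
   *        component as the inference root and keep the plan whose fragment
   *        layouts use the fewest registers.
   */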
  void InferInFreeMode(LayoutMap &layout_map,
                       const LayoutMap &strict_layout_map) {

    DLOG(INFO) << "Enforced layout maps:" << '\n';
    for (auto &&[k, v] : layout_map) {
      DLOG(INFO) << "    " << k << ": " << v->DebugOutput() << '\n';
    }
    DLOG(INFO) << '\n';

    // Group operators into connected components
    UnionFind<int> uf;
    for (int i = 0; i < infer_list_.size(); i++) {
      uf.MakeSet(i);
    }
    for (const auto &[buffer, infer_indices] : use_list_) {
      if (infer_indices.empty())
        continue;

      // Union all infer_list_ indices that share the same buffer
      int first_idx = infer_indices[0];
      for (size_t i = 1; i < infer_indices.size(); i++) {
        uf.Union(first_idx, infer_indices[i]);
      }
    }
    std::unordered_map<int, std::vector<int>> components;
    for (int i = 0; i < infer_list_.size(); i++) {
      int root = uf.Find(i);
      components[root].push_back(i);
    }
    // Create a map from root to buffers
    std::unordered_map<int, std::vector<Buffer>> components_buffers;
    for (const auto &[buffer, infer_indices] : use_list_) {
      int root = uf.Find(infer_indices[0]);
      components_buffers[root].push_back(buffer);
    }
    // Keep components_buffers for debug purpose
    (void)components_buffers;

    // For each component, try each op as the inference root and pick the
    // plan whose fragment layouts use the fewest registers.
    std::queue<int> q;
    std::vector<bool> in_queue(infer_list_.size(), false);

    for (auto &&[root, members] : components) {
      DLOG(INFO) << "======================= processing component " << root
                 << '\n';
      decltype(infer_list_) best_infer_list;
      LayoutMap best_layout_map;
      int64_t min_reg_num = INT64_MAX;
      int min_reg_num_infer_root = -1;

      // Try each member as the root of inference for this component
      for (int attempt_infer_root : members) {
        DLOG(INFO) << "----------------------- try root " << attempt_infer_root
                   << '\n';
        // Backup the current infer_list_ state
        auto back_infer_list = BackupInferList();
        // Copy the current layout_map for temporary use
        LayoutMap tmp_layout_map = layout_map;
        bool do_update = true;
        try {
          // Run inference starting from attempt_infer_root
          RunInferStep(attempt_infer_root, InferLevel::kFree, true,
                       tmp_layout_map, strict_layout_map, q, in_queue);
          FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
                           q, in_queue);

          // After the first search, run inference for all other members in
          // order
          for (int other_infer_root : members) {
            if (other_infer_root != attempt_infer_root) {
              RunInferStep(other_infer_root, InferLevel::kFree, true,
                           tmp_layout_map, strict_layout_map, q, in_queue);
              FinishInferQueue(InferLevel::kFree, tmp_layout_map,
                               strict_layout_map, q, in_queue);
            }
          }
        } catch (const LayoutConflictException &e) {
          do_update = false;
          DLOG(INFO) << "attempt failed due to LayoutConflictException "
                     << e.what() << '\n';
        } catch (const NormalizeIterException &e) {
          do_update = false;
          DLOG(INFO) << "attempt failed due to NormalizeIterException "
                     << e.what() << '\n';
        }

        if (do_update) {
          // Compute the total register number for this layout
          int64_t reg_num = 0;
          for (const auto &[buffer, layout] : tmp_layout_map) {
            if (auto frag = layout.as<Fragment>()) {
              int64_t frag_reg_num = 1;
              for (auto i : frag.value()->OutputShape()) {
                auto pci = as_const_int(i);
                ICHECK(pci != nullptr);
                frag_reg_num *= *pci;
              }
              reg_num += frag_reg_num;
            }
          }
          // Update the best plan if this one uses fewer registers
          if (reg_num < min_reg_num) {
            best_infer_list =
                BackupInferList(); // Use backup to avoid moving out infer_list_
            best_layout_map = tmp_layout_map;
            min_reg_num = reg_num;
            min_reg_num_infer_root = attempt_infer_root;
          }
        }
        // Restore infer_list_ state for the next attempt
        infer_list_ = std::move(back_infer_list);
      }
      ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n';
      // Apply the best plan for this component
      infer_list_ = std::move(best_infer_list);
      layout_map = best_layout_map;
      DLOG(INFO) << "[InferInFreeMode] Final selection is attempt_infer_root = "
                 << min_reg_num_infer_root << '\n';
    }
  }
};

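/*!
 * \brief Rewrite the PrimFunc with the inferred layouts: annotate blocks with
 *        the layout map, and partition, vectorize, and predicate parallel
 *        loops according to the results of BufferUseDefCollector.
 */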
class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector(skip_thread_partition);
    collector.Collect(f);
    auto result = collector.Run();
    LayoutInferencer substituter(result, skip_thread_partition, &analyzer);
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult &result,
                   bool skip_thread_partition, arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
        skip_thread_partition_(skip_thread_partition) {};

  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

  /**
   * @brief Visit and mutate a Block node to attach inferred layout information.
   *
   * Converts the visited Block via the base visitor, asserts that every buffer
   * allocated with scope "local.fragment" has an inferred layout in
   * result_.layout_map, and attaches result_.layout_map to the Block's
   * annotations under attr::kLayoutMap.
   *
   * If any "local.fragment" buffer lacks an entry in result_.layout_map an
   * ICHECK will fail with the offending buffer printed.
   *
   * @return Stmt The (possibly modified) Block statement with the layout-map
   * annotation set.
   */
  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
      if (buffer.scope() == "local.fragment") {
        ICHECK(result_.layout_map.count(buffer))
            << "Cannot inference fragment layout for " << buffer;
      }
    }
    auto block_ptr = block.CopyOnWrite();
    block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
    return block;
  }

  /**
   * @brief Visit and transform For nodes according to inferred layout
   * information.
   *
   * If the For node is present in result_.for_map, this method applies
   * loop-level layout-driven transformations: it optionally partitions the loop
   * across the thread index, vectorizes the loop body, and wraps the loop with
   * a predicate if one was inferred for the loop root.
   *
   * Detailed behavior:
   * - Reads reducer information from the For node's attr::kReducerInfo
   * annotation (if present) to detect reduction targets.
   * - Detects register-local buffer stores (buffers with scope "local") in the
   *   original loop body; if only register-local stores are present the loop is
   *   treated as a register-local scenario and is not partitioned across
   * threads.
   * - Obtains the loop layout from result_.for_map[root] and, unless the loop
   * is register-local or skip_thread_partition_ is set, partitions the loop via
   *   PartitionLoop using thread_var_ and analyzer_.
   * - Scans the transformed loop body to determine whether it accesses any
   *   non-local buffers (scopes other than "local" or "local.fragment").
   * - Scans the transformed loop body to detect reducers (based on
   * reducer_info). If a reducer is present the loop is NOT vectorized
   * (reduction axes are excluded from vectorization as a conservative
   * workaround).
   * - If the loop has non-local accesses and no reducer, the loop is vectorized
   *   via VectorizeLoop.
   * - If a predicate exists in result_.predicate_map for the loop root and the
   *   loop was partitioned, the method returns an IfThenElse surrounding the
   *   (possibly partitioned/vectorized) loop with that predicate; otherwise it
   *   returns the transformed For.
   *
   * @return The possibly transformed For statement (or an IfThenElse wrapping
   * it)
   */
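  // Rough illustration (hypothetical extents; assuming 128 threads bound to
  // threadIdx.x and a replication-free fragment layout):
  //   for i in T.Parallel(1024):
  //     C_frag[i] = A_shared[i]
  // is partitioned into a per-thread loop of extent 8 indexed by threadIdx.x,
  // the resulting loop is vectorized, and an IfThenElse predicate is added
  // only if result_.predicate_map carries one for this loop root.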
  Stmt VisitStmt_(const ForNode *op) final {
    Map<Var, ReducerInfo> reducer_info;
    if (op->annotations.count(attr::kReducerInfo))
      reducer_info = op->annotations.Get(attr::kReducerInfo)
                         ->as<Map<Var, ReducerInfo>>()
                         .value();

    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto root = GetRef<For>(op);
      // This check is a workaround to support T.Parallel for local buffers.
      // For example:
      //   for i in T.Parallel(1024):
      //     A_local[i] = A_global[i]
      // Here, A_local is a register-local buffer held independently by each
      // thread, so explicit thread binding is not required.
      bool store_into_local = false;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() == "local") {
            store_into_local = true;
          }
          // In a mixed case like:
          //   for i in T.Parallel(1024):
          //     A_local[i] = B_global[i]
          //     A_frag[i] = A_global[i]
          // an exception will be raised in Parallel::LayoutInference.
        }
      });
      // This check is for loops that only manipulate "local" buffers, e.g.:
      //   for i in T.Parallel(1024):
      //     A_local[i] = B_local[i]
      // (though such a loop might be illegal). We use PostOrderVisit to detect
      // whether the loop touches only "local" buffers, which indicates pure
      // register usage and justifies skipping thread binding.
      bool local_register_only = true;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() != "local") {
            local_register_only = false;
          }
        } else if (const auto *load = obj.as<BufferLoadNode>()) {
          if (load->buffer.scope() != "local") {
            local_register_only = false;
          }
        }
      });

      auto loop_layout = result_.for_map[root];
      // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
      // NOTE(lei): a bit ugly, we should rethink about this part in future.
      bool parallel_loop =
          !skip_thread_partition_ && !local_register_only && !store_into_local;

      if (parallel_loop) {
        for_node =
            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      }
      // If no thread bindings are provided, partition the loop
      bool has_non_local = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (const auto *load = obj.as<BufferLoadNode>()) {
          String scope = load->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        } else if (const auto *store = obj.as<BufferStoreNode>()) {
          String scope = store->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        }
      });
      // Workaround: if a reducer is present, don't vectorize the loop.
      // A better solution would isolate the reduction axis from vectorization.
      bool has_reducer = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (!has_reducer)
          if (const auto *store = obj.as<BufferStoreNode>()) {
            has_reducer = reducer_info.count(store->buffer->data) != 0;
          }
      });

      if (has_non_local && !has_reducer) {
        for_node = VectorizeLoop(for_node);
      }

      if (result_.predicate_map.count(root) && parallel_loop) {
        return IfThenElse(result_.predicate_map[root], for_node);
      } else {
        return for_node;
      }
    }
    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  bool skip_thread_partition_{false};
};

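/*!
 * \brief Build the tl.LayoutInference pass: transform parallel loops, detect
 *        whether any thread binding exists (none is present for the CPU-style
 *        path noted above), and run LayoutInferencer::Substitute accordingly.
 */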
tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
    ThreadBindingCollector collector;
    collector(f->body);
    bool has_thread_binding = !collector.thread_binding_.empty();
    bool skip_thread_partition = !has_thread_binding;
    return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference);
});

} // namespace tl
} // namespace tvm