/*!
 * \file layout_inference.cc
 * \brief infer the fragment/shared memory layout
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Collect the thread bindings (thread_extent IterVars) that appear in
 *        a statement, keyed by the underlying thread variable.
 */
class ThreadBindingCollector : public StmtExprVisitor {
public:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      thread_binding_[iv->var.get()] = iv;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // The thread binding map
  std::unordered_map<const VarNode *, IterVar> thread_binding_;
};

using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

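/*!
 * \brief Rewrite parallel loops whose fragment accesses only cover part of a
 *        buffer dimension so that the loop spans the full extent, guarded by
 *        a predicate.
 *
 * A minimal sketch of the intended effect (illustrative pseudo-TIR, assuming
 * a fragment dimension of 32 accessed by a parallel loop of extent 16):
 *
 *   for r in T.Parallel(16):      // before
 *     A_frag[r] = ...
 *
 *   for r in T.Parallel(32):      // after
 *     if r < 16:
 *       A_frag[r] = ...
 */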
class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
public:
  static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    ParallelLoopTransformer transformer(&analyzer);
    return transformer.VisitStmt(stmt);
  }

  ParallelLoopTransformer(arith::Analyzer *analyzer)
      : IRMutatorWithAnalyzer(analyzer) {}

  Stmt VisitStmt_(const ForNode *op) final {
    if (op->kind != ForKind::kParallel)
      return StmtMutator::VisitStmt_(op);

    // Collect loop variables and ranges
    auto for_node = GetRef<For>(op);
    Array<Var> loop_vars;
    Array<PrimExpr> loop_extents;
    Stmt body = op->body;

    // Bind the range of outer loop variables
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
    loop_vars.push_back(op->loop_var);
    loop_extents.push_back(op->extent);

    // If there are inner loops, bind their ranges as well
    while (const ForNode *inner = body.as<ForNode>()) {
      analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent));
      loop_vars.push_back(inner->loop_var);
      loop_extents.push_back(inner->extent);
      body = inner->body;
    }

    ICHECK(loop_vars.size() == loop_extents.size())
        << "loop_vars and loop_extents size mismatch";

    // Collect buffer access information
    BufferAccessCollector collector;
    collector(op->body);

    PrimExpr condition;

    for (const auto &[buffer, indices] : collector.buffer_indices) {
      ICHECK(indices.size() == buffer->shape.size())
          << "indices size mismatch with buffer shape";

      for (size_t i = 0; i < indices.size(); ++i) {
        auto index = indices[i];
        auto bound = analyzer_->const_int_bound(index);
        int64_t upper_bound = bound->max_value + 1;
        int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;

        // Collect the variables that are used in the index
        std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
        // Post-order visit the index expression
        PostOrderVisit(index, [&](const ObjectRef &obj) {
          if (const VarNode *v = obj.as<VarNode>()) {
            used_vars.insert(GetRef<Var>(v));
          }
        });
        if (used_vars.size() == 0) {
          continue;
        }

        // find related loop vars
        Array<Var> related_loop_vars;
        for (size_t j = 0; j < loop_vars.size(); ++j) {
          auto loop_var = loop_vars[j];
          // if find related, pop the loop_vars and loop_extents
          if (used_vars.count(loop_var)) {
            related_loop_vars.push_back(loop_var);
          }
          ICHECK(related_loop_vars.size() <= 1)
              << "Only one related loop var is currently supported, but got "
              << related_loop_vars
              << ". Supporting multiple loop vars should not be too hard; "
              << "please open an issue if you run into this message.";

          auto bound = analyzer_->const_int_bound(index);
          int64_t upper_bound = bound->max_value + 1;
          int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
          if (upper_bound < shape) {
            PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound));
            condition =
                condition.defined() ? And(condition, predicate) : predicate;

            // Rewrite the buffer index, e.g. from A[i, f(r)] to A[i, j]:
            // substitute the original loop var through the inverse index map
            // so that the new loop var indexes the buffer dimension directly.
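            // Worked example (illustrative only, not from a real trace): if
            // the index is simply `r` with a loop extent of 16 while the
            // fragment dimension is 32, then upper_bound == 16 < 32, the
            // guard becomes `r < 16`, the loop extent is widened to 32, and
            // the inverse index map leaves the loop body unchanged.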
            auto index_map = tir::IndexMap({loop_var}, {index});
            auto inverse_index_map = index_map.Inverse(
                {Range::FromMinExtent(0, IntImm(index.dtype(), upper_bound))},
                analyzer_);

            loop_extents.Set(i, IntImm(index.dtype(), shape));
            body = tir::Substitute(body,
                                   {{loop_var, inverse_index_map->MapIndices(
                                                   {loop_var}, analyzer_)[0]}});
          }
        }
      }
    }
    if (condition.defined()) {
      body = IfThenElse(condition, body);
      for (int j = loop_vars.size() - 1; j >= 0; --j) {
        auto loop_var = loop_vars[j];
        auto loop_extent = loop_extents[j];
        body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
      }
      return Downcast<For>(body);
    }
    // Nothing to rewrite: return the outer loop unchanged without recursing
    // into the inner loops.
    return for_node;
  }

private:
  // Helper class for collecting buffer access information; only accesses to
  // "local.fragment" buffers are recorded.
  class BufferAccessCollector : public StmtExprVisitor {
  public:
    void VisitExpr_(const BufferLoadNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitExpr_(op);
    }

    void VisitStmt_(const BufferStoreNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitStmt_(op);
    }

    std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
        buffer_indices;
  };
};

struct LayoutInferenceResult {
  Map<Buffer, Layout> layout_map;
  Map<For, Fragment> for_map;
  Map<For, PrimExpr> predicate_map;
};

class BufferUseDefCollector : public IRVisitorWithAnalyzer {
public:
  BufferUseDefCollector(bool skip_thread_partition)
      : skip_thread_partition_(skip_thread_partition) {}

  LayoutInferenceResult Run() {
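    // Overall flow (see steps 1-3 below): run each operator's layout
    // inference at the strict level first, then propagate layouts between
    // operators that share buffers using a BFS work-list at the common level,
    // and finally relax to the free level to settle anything still undecided.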
    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
    // same size
    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
           "length.";
    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
           "length.";

    // If needed, you can also check that annotated_layout_map_ is not empty, or
    // anything else relevant to your setup.

    // Copy the annotated layout map to local variable
    Map<Buffer, Layout> layout_map = annotated_layout_map_;
    int num_infer = infer_list_.size();

    // Prepare BFS queue for iterative inference
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
    for (int i = 0; i < num_infer; i++) {
      // Check that each infer_list_ entry is valid
      ICHECK(infer_list_[i] != nullptr)
          << "infer_list_[" << i
          << "] is null. The inference object is not allocated properly.";

      // Check that each thread_var_vec_ entry is defined
      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
        thread_var_vec_[i] = thread_var_;
      }
      q.push(i);
    }
    auto run_infer_step = [&](int cur_infer_id, InferLevel level,
                              bool update_queue) {
      // Range check for cur_infer_id
      ICHECK_GE(cur_infer_id, 0)
          << "cur_infer_id is negative, which is invalid.";
      ICHECK_LT(cur_infer_id, num_infer)
          << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
          << num_infer << ".";

      // Make sure we can safely access infer_list_[cur_infer_id] and
      // thread_var_vec_[cur_infer_id]
      auto &next = infer_list_[cur_infer_id];
      auto iter_var = thread_var_vec_[cur_infer_id];
      auto thread_bounds = thread_bounds_vec_[cur_infer_id];
      // Double-check that 'next' is valid
      ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
                              << "] is null inside run_infer_step.";

      // Check iter_var->dom and dom->extent
      ICHECK(iter_var.defined())
          << "thread_var_vec_[" << cur_infer_id << "] is not defined.";
      ICHECK(iter_var->dom.defined())
          << "iter_var->dom is not defined for infer_list_[" << cur_infer_id
          << "].";
      ICHECK(iter_var->dom->extent.defined())
          << "iter_var->dom->extent is not defined for infer_list_["
          << cur_infer_id << "].";

      const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
      ICHECK(extent_ptr != nullptr)
          << "iter_var->dom->extent is not a constant integer, which is "
             "required for layout inference.";

      // Run InferLayout
      auto updates = next->InferLayout(
          LayoutInferArgs{target_, thread_bounds, layout_map}, level);
      // Process the returned updates
      for (const auto &[buffer, layout] : updates) {
        // Basic validity checks
        ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
        ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";

        if (layout_map.count(buffer)) {
          // If the replicate extent of the new layout is larger than that of
          // the recorded one, prefer the new layout
          if (buffer.scope() == "local.fragment" &&
              level != InferLevel::kStrict) {
            const FragmentNode *dst_layout = layout.as<Fragment>().get();
            const FragmentNode *src_layout =
                layout_map[buffer].as<Fragment>().get();
            if (as_const_int(dst_layout->ReplicateExtent()) &&
                as_const_int(src_layout->ReplicateExtent()) &&
                (*as_const_int(dst_layout->ReplicateExtent()) >
                 *as_const_int(src_layout->ReplicateExtent()))) {
              // update map
              layout_map.Set(buffer, layout);
              continue;
            }
          }
          // If already in map, ensure they are structurally equal
          ICHECK(StructuralEqual()(layout, layout_map[buffer]))
              << "Get different layout for " << buffer
              << "\n current layout: " << layout->DebugOutput()
              << "\n previous layout: " << layout_map[buffer]->DebugOutput();
        } else {
          // Otherwise, update map
          layout_map.Set(buffer, layout);
          if (!update_queue)
            continue;

          // Check if buffer exists in use_list_
          if (!use_list_.count(buffer)) {
            LOG(WARNING) << "Layout inference failed for buffer " << buffer
                         << ". "
                         << "The buffer cannot be inferred with current layout "
                            "inference rules.";
            continue;
          }

          // Push back into BFS queue
          for (int idx : use_list_[buffer]) {
            ICHECK_GE(idx, 0) << "Index in use_list_ for buffer " << buffer
                              << " is negative.";
            ICHECK_LT(idx, num_infer)
                << "Index in use_list_ for buffer " << buffer
                << " out of range: " << idx << " >= " << num_infer << ".";

            if (!in_queue[idx] && idx != cur_infer_id) {
              in_queue[idx] = true;
              q.push(idx);
            }
          }
        }
      }
    };

    auto finish_infer_queue = [&]() {
      while (!q.empty()) {
        int cur_infer_id = q.front();
        q.pop();
        // Range check again, just to be safe
        ICHECK_GE(cur_infer_id, 0);
        ICHECK_LT(cur_infer_id, num_infer);

        in_queue[cur_infer_id] = false;
        run_infer_step(cur_infer_id, InferLevel::kCommon, true);
      }
    };

    // step 1: infer strict layout
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kStrict, false);
    }

    // step 2: infer common layout with BFS
    finish_infer_queue();

    // step 3: relax constraints to free and re-run
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kFree, true);
      finish_infer_queue();
    }

    // Check that all local.fragment buffers have inferred layouts
    for (const auto &[buffer, _] : use_list_) {
      if (buffer.scope() == "local.fragment") {
        ICHECK_NE(layout_map.count(buffer), 0)
            << "The layout for fragment " << buffer
            << " can not be inferred correctly.";
      }
    }

    // Collect layout info for For nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    ICHECK(infer_list_.size() == thread_var_vec_.size())
        << "infer_list_ and thread_var_vec_ size mismatch";
    for (int i = 0; i < infer_list_.size(); i++) {
      std::unique_ptr<Operator> base_infer = std::move(infer_list_[i]);
      auto thread_var = thread_var_vec_[i];

      // Check if base_infer is valid
      ICHECK(base_infer != nullptr) << "Null pointer encountered in "
                                       "infer_list_ while collecting for_map.";
      if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
        // Check that the loop layout is defined
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The Layout for Parallel for cannot be inferred correctly:\n"
            << for_infer->GetRoot();
        for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
        // thread_var_ should be defined if we rely on it
        ICHECK(thread_var.defined())
            << "thread_var is not defined. Cannot retrieve predicate.";

        if (auto predicate = for_infer->GetPredicate(thread_var->var)) {
          predicate_map.Set(for_infer->GetRoot(), predicate.value());
        }
      }
    }

    return {layout_map, for_map, predicate_map};
  }

  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "LayoutInference: requires the target attribute to be set";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
  void VisitExpr_(const CallNode *op) final {
    IRVisitorWithAnalyzer::VisitExpr_(op);
    // Do not analyze call nodes that target global functions.
    if (op->op.as<GlobalVarNode>())
      return;

    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p != nullptr) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
      }
      infer_list_.push_back(std::move(p));
      thread_var_vec_.push_back(thread_var_);
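      // Record the thread bounds from the analyzer's const-int bound on the
      // thread var. For example (illustrative), a 128-thread block bound to
      // [0, 127] yields Range::FromMinExtent(0, 128); without a bound we fall
      // back to a single-thread range below.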
      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto min_value = const_int_bound->min_value;
        auto max_value = const_int_bound->max_value;
        auto extent = max_value - min_value + 1;
        auto dtype = thread_var_->var.dtype();
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    }
  }

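  // Resolve the Buffer referenced by a tvm_access_ptr() argument: args[1] of
  // the builtin holds the backing data Var, which is mapped back to its
  // Buffer via buffer_data_to_buffer_.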
  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    if (call && call->op.same_as(builtin::tvm_access_ptr())) {
      auto var = call->args[1].as<Var>().value();
      return buffer_data_to_buffer_[var];
    }
    return NullOpt;
  }

  void addToUseList(const Buffer &buffer) {
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
    }
    use_list_[buffer].push_back(infer_idx);
  }

  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {
      auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_.push_back(std::move(infer));
      thread_var_vec_.push_back(thread_var_);
      if (thread_var_.defined() &&
          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto dtype = thread_var_->var.dtype();
        auto extent =
            const_int_bound->max_value - const_int_bound->min_value + 1;
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    } else {
      IRVisitorWithAnalyzer::VisitStmt(op->body);
    }
  }

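  // Record block allocations and pick up user-provided layout hints: a block
  // may carry a kLayoutMap annotation mapping buffer vars to layouts (e.g.
  // set from an annotate-layout style frontend API), which seeds
  // annotated_layout_map_ before inference runs.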
  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
      auto map =
          op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
        ICHECK(buffer_data_to_buffer_.count(var))
            << "buffer " << var << " is not found in the block";
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<std::unique_ptr<Operator>> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  // This is a workaround for the CPU backend: we need a default thread_var
  // for serial loops that carry no thread binding.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  std::vector<IterVar> thread_var_vec_;
  std::vector<Range> thread_bounds_vec_;
  Target target_;
  LayoutMap annotated_layout_map_;
  bool skip_thread_partition_{false};
};

class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) {
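    // Overall flow: fuse adjacent parallel loops, run BufferUseDefCollector
    // to infer buffer and loop layouts, then rewrite the body (attach the
    // layout map to blocks, partition/vectorize parallel loops, and wrap them
    // with inferred predicates).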
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector(skip_thread_partition);
    collector.Collect(f);
    auto result = collector.Run();
    LayoutInferencer substituter(result, skip_thread_partition, &analyzer);
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult result,
                   bool skip_thread_partition, arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
        skip_thread_partition_(skip_thread_partition) {}

  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
      if (buffer.scope() == "local.fragment") {
        ICHECK(result_.layout_map.count(buffer))
            << "Cannot inference fragment layout for " << buffer;
      }
    }
    auto block_ptr = block.CopyOnWrite();
    block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
    return block;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto root = GetRef<For>(op);
      // This check is a workaround to support T.Parallel for local buffers.
      // For example:
      //   for i in T.Parallel(1024):
      //     A_local[i] = A_global[i]
      // Here, A_local is a register-local buffer held independently by each
      // thread, so explicit thread binding is not required.
      //
      // We use PostOrderVisit to detect whether the buffer store targets a
      // "local" buffer, which indicates register usage and justifies skipping
      // thread binding.
      bool is_register_store = false;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() == "local") {
            is_register_store = true;
          }
        }
      });

      auto loop_layout = result_.for_map[root];
      bool parallel_loop = !is_register_store && !skip_thread_partition_;
      if (parallel_loop) {
        for_node =
            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      }
      // Vectorize only when the loop touches buffers outside registers
      // (anything that is not "local" or "local.fragment").
      bool has_non_local = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (const auto *load = obj.as<BufferLoadNode>()) {
          String scope = load->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        } else if (const auto *store = obj.as<BufferStoreNode>()) {
          String scope = store->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        }
      });

      if (has_non_local) {
        for_node = VectorizeLoop(for_node);
      }

      if (result_.predicate_map.count(root) && parallel_loop) {
        return IfThenElse(result_.predicate_map[root], for_node);
      } else {
        return for_node;
      }
    }
    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  bool skip_thread_partition_{false};
};

tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
    ThreadBindingCollector collector;
    collector(f->body);
    bool has_thread_binding = collector.thread_binding_.size() > 0;
    bool skip_thread_partition = !has_thread_binding;
    return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}

TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
    .set_body_typed(LayoutInference);
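
// Illustrative usage from the Python side (an assumption about the usual
// tilelang bindings, not verified in this file):
//   mod = tilelang.transform.LayoutInference()(mod)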

} // namespace tl
} // namespace tvm