/*!
 * \file layout_inference.cc
 * \brief infer the fragment/shared memory layout
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Collect the mapping from a thread variable to its thread-binding IterVar
 */
class ThreadBindingCollector : public StmtExprVisitor {
public:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      thread_binding_[iv->var.get()] = iv;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // The thread binding map
  std::unordered_map<const VarNode *, IterVar> thread_binding_;
};

using namespace tir;
using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

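/*!
 * \brief Rewrite T.Parallel loops whose fragment accesses cover only part of
 *        a buffer dimension.
 *
 * When the analyzer proves that a loop index reaches fewer values than the
 * fragment extent, the loop is widened to the full extent, the original index
 * is recovered through an inverse IndexMap, and the body is guarded with an
 * IfThenElse predicate, so later layout inference sees the complete index
 * space of the fragment.
 */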
class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
public:
  static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    ParallelLoopTransformer transformer(&analyzer);
    return transformer.VisitStmt(stmt);
  }

  ParallelLoopTransformer(arith::Analyzer *analyzer)
      : IRMutatorWithAnalyzer(analyzer) {}

  Stmt VisitStmt_(const ForNode *op) final {
    if (op->kind != ForKind::kParallel)
      return StmtMutator::VisitStmt_(op);

    // Collect loop variables and ranges
    auto for_node = GetRef<For>(op);
    Array<Var> loop_vars;
    Array<PrimExpr> loop_extents;
    Stmt body = op->body;

    // Bind the range of outer loop variables
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
    loop_vars.push_back(op->loop_var);
    loop_extents.push_back(op->extent);

    // If there are inner loops, bind their ranges as well
    while (const ForNode *inner = body.as<ForNode>()) {
      analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent));
      loop_vars.push_back(inner->loop_var);
      loop_extents.push_back(inner->extent);
      body = inner->body;
    }

    ICHECK(loop_vars.size() == loop_extents.size())
        << "loop_vars and loop_extents size mismatch";

    // Collect buffer access information
    BufferAccessCollector collector;
    collector(op->body);

    PrimExpr condition;

    for (const auto &[buffer, indices] : collector.buffer_indices) {
      ICHECK(indices.size() == buffer->shape.size())
          << "indices size mismatch with buffer shape";

      for (size_t i = 0; i < indices.size(); ++i) {
        auto index = indices[i];
        auto bound = analyzer_->const_int_bound(index);
        int64_t upper_bound = bound->max_value + 1;
        int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;

        // Collect the variables that are used in the index
        std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
        // post order visit the index
        PostOrderVisit(index, [&](const ObjectRef &obj) {
          if (const VarNode *v = obj.as<VarNode>()) {
            used_vars.insert(GetRef<Var>(v));
          }
        });
        if (used_vars.size() == 0) {
          continue;
        }

        // find related loop vars
        Array<Var> related_loop_vars;
        for (size_t j = 0; j < loop_vars.size(); ++j) {
          auto loop_var = loop_vars[j];
          // if find related, pop the loop_vars and loop_extents
          if (used_vars.count(loop_var)) {
            related_loop_vars.push_back(loop_var);
          }
          ICHECK(related_loop_vars.size() <= 1)
              << "Only one related loop var is supported currently, but got "
              << related_loop_vars
              << ". Supporting multiple loop vars should not be too hard; "
              << "please file an issue if you run into this message.";

          auto bound = analyzer_->const_int_bound(index);
          int64_t upper_bound = bound->max_value + 1;
          int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
          if (upper_bound < shape) {
            PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound));
            condition =
                condition.defined() ? And(condition, predicate) : predicate;

            // replace the buffer index from A[i, r * 2] with A[i, j]
            // where r is the original index, j is the loop_var
            auto index_map = tir::IndexMap({loop_var}, {index});
            auto inverse_index_map = index_map.Inverse(
                {Range::FromMinExtent(0, IntImm(index.dtype(), upper_bound))},
                analyzer_);

            loop_extents.Set(i, IntImm(index.dtype(), shape));
            body = tir::Substitute(body,
                                   {{loop_var, inverse_index_map->MapIndices(
                                                   {loop_var}, analyzer_)[0]}});
          }
        }
      }
    }
    if (condition.defined()) {
      body = IfThenElse(condition, body);
      for (int j = loop_vars.size() - 1; j >= 0; --j) {
        auto loop_var = loop_vars[j];
        auto loop_extent = loop_extents[j];
        body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
      }
      return Downcast<For>(body);
    }
    // Only traverse the outer loop
    return for_node;
  }

private:
  // Helper class for collecting buffer access information, only counts fragment
  // buffer access
  class BufferAccessCollector : public StmtExprVisitor {
  public:
    void VisitExpr_(const BufferLoadNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitExpr_(op);
    }

    void VisitStmt_(const BufferStoreNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitStmt_(op);
    }

    std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
        buffer_indices;
  };
};

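/*!
 * \brief Result of layout inference: the layout chosen for each buffer, the
 *        loop (fragment) layout chosen for each parallel For, and an optional
 *        thread predicate for loops that get partitioned.
 */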
struct LayoutInferenceResult {
  Map<Buffer, Layout> layout_map;
  Map<For, Fragment> for_map;
  Map<For, PrimExpr> predicate_map;
};

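/*!
 * \brief Collect one layout-inference operator per TL call and parallel loop,
 *        record which buffers each operator touches, and run the iterative
 *        layout inference in Run().
 */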
class BufferUseDefCollector : public IRVisitorWithAnalyzer {
public:
  BufferUseDefCollector(bool skip_thread_partition)
      : skip_thread_partition_(skip_thread_partition) {}

  LayoutInferenceResult Run() {
    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
    // same size
    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
           "length.";
    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
           "length.";

    // Additional sanity checks (e.g. that annotated_layout_map_ is non-empty)
    // could be added here if needed.

    // Copy the annotated layout map to local variable
    Map<Buffer, Layout> layout_map = annotated_layout_map_;
    int num_infer = infer_list_.size();

    // Prepare BFS queue for iterative inference
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
    for (int i = 0; i < num_infer; i++) {
      // Check that each infer_list_ entry is valid
      ICHECK(infer_list_[i] != nullptr)
          << "infer_list_[" << i
          << "] is null. The inference object is not allocated properly.";

      // Check that each thread_var_vec_ entry is defined
      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
        thread_var_vec_[i] = thread_var_;
      }
      q.push(i);
    }
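
    // run_infer_step performs layout inference for one operator at the given
    // InferLevel. Whenever a buffer receives a layout for the first time, all
    // other operators that access it (tracked in use_list_) are re-enqueued so
    // the new constraint can propagate.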
    auto run_infer_step = [&](int cur_infer_id, InferLevel level,
                              bool update_queue) {
      // Range check for cur_infer_id
      ICHECK_GE(cur_infer_id, 0)
          << "cur_infer_id is negative, which is invalid.";
      ICHECK_LT(cur_infer_id, num_infer)
          << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
          << num_infer << ".";

      // Make sure we can safely access infer_list_[cur_infer_id] and
      // thread_var_vec_[cur_infer_id]
      auto &next = infer_list_[cur_infer_id];
      auto iter_var = thread_var_vec_[cur_infer_id];
      auto thread_bounds = thread_bounds_vec_[cur_infer_id];

      // Double-check that 'next' is valid
      ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
                              << "] is null inside run_infer_step.";

      // Check iter_var->dom and dom->extent
      ICHECK(iter_var.defined())
          << "thread_var_vec_[" << cur_infer_id << "] is not defined.";
      ICHECK(iter_var->dom.defined())
          << "iter_var->dom is not defined for infer_list_[" << cur_infer_id
          << "].";
      ICHECK(iter_var->dom->extent.defined())
          << "iter_var->dom->extent is not defined for infer_list_["
          << cur_infer_id << "].";

      const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
      ICHECK(extent_ptr != nullptr)
          << "iter_var->dom->extent is not a constant integer, which is "
             "required for layout inference.";

      // Run InferLayout
      auto updates = next->InferLayout(
          LayoutInferArgs{target_, thread_bounds, layout_map}, level);
      // Process the returned updates
      for (const auto &[buffer, layout] : updates) {
        // Basic validity checks
        ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
        ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";

        if (layout_map.count(buffer)) {
          // If already in map, ensure they are structurally equal
          ICHECK(StructuralEqual()(layout, layout_map[buffer]))
              << "Get different layout for " << buffer
              << " current layout: " << layout->DebugOutput()
              << " previous layout: " << layout_map[buffer]->DebugOutput();
        } else {
          // Otherwise, update map
          layout_map.Set(buffer, layout);
          if (!update_queue)
            continue;

          // Check if buffer exists in use_list_
          if (!use_list_.count(buffer)) {
            LOG(WARNING) << "Buffer " << buffer << " not found in use_list_. "
                         << "Potential mismatch between inference updates and "
                         << "use_list_.";
            continue;
          }

          // Push back into BFS queue
          for (int idx : use_list_[buffer]) {
            ICHECK_GE(idx, 0) << "Index in use_list_ for buffer " << buffer
                              << " is negative.";
            ICHECK_LT(idx, num_infer)
                << "Index in use_list_ for buffer " << buffer
                << " out of range: " << idx << " >= " << num_infer << ".";

            if (!in_queue[idx] && idx != cur_infer_id) {
              in_queue[idx] = true;
              q.push(idx);
            }
          }
        }
      }
    };

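    // Drain the worklist: pop operators and run kCommon-level inference until
    // no pending updates remain.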
    auto finish_infer_queue = [&]() {
      while (!q.empty()) {
        int cur_infer_id = q.front();
        q.pop();
        // Range check again, just to be safe
        ICHECK_GE(cur_infer_id, 0);
        ICHECK_LT(cur_infer_id, num_infer);

        in_queue[cur_infer_id] = false;
        run_infer_step(cur_infer_id, InferLevel::kCommon, true);
      }
    };

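    // Inference runs in three rounds with progressively relaxed levels
    // (kStrict, then kCommon via the worklist, then kFree); each round may fix
    // new layouts that are propagated before moving on.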
    // step 1: infer strict layout
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kStrict, false);
    }

    // step 2: infer common layout with BFS
    finish_infer_queue();

    // step 3: relax constraints to free and re-run
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kFree, true);
      finish_infer_queue();
    }

    // Check that all local.fragment buffers have inferred layouts
    for (const auto &[buffer, _] : use_list_) {
      if (buffer.scope() == "local.fragment") {
        ICHECK_NE(layout_map.count(buffer), 0)
            << "The layout for fragment " << buffer
            << " can not be inferred correctly.";
      }
    }

    // Collect layout info for For nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    ICHECK(infer_list_.size() == thread_var_vec_.size())
        << "infer_list_ and thread_var_vec_ size mismatch";
    for (int i = 0; i < infer_list_.size(); i++) {
      std::unique_ptr<Operator> base_infer = std::move(infer_list_[i]);
      auto thread_var = thread_var_vec_[i];

      // Check if base_infer is valid
      ICHECK(base_infer != nullptr) << "Null pointer encountered in "
                                       "infer_list_ while collecting for_map.";
      if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
        // Check that the loop layout is defined
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The Layout for Parallel for cannot be inferred correctly:\n"
            << for_infer->GetRoot();
        for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
        // thread_var_ should be defined if we rely on it
        ICHECK(thread_var.defined())
            << "thread_var is not defined. Cannot retrieve predicate.";

        if (auto predicate = for_infer->GetPredicate(thread_var->var)) {
          predicate_map.Set(for_infer->GetRoot(), predicate.value());
        }
      }
    }

    return {layout_map, for_map, predicate_map};
  }

  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "Layout_Inference: Require the target attribute";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
  void VisitExpr_(const CallNode *op) final {
    IRVisitorWithAnalyzer::VisitExpr_(op);
    // Do not analyze calls to global functions.
    if (op->op.as<GlobalVarNode>())
      return;

    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p != nullptr) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
      }
      infer_list_.push_back(std::move(p));
      thread_var_vec_.push_back(thread_var_);
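      // Record the thread bounds for this operator: use the analyzer's
      // constant bound on the thread variable when available, otherwise fall
      // back to a single-thread range [0, 1).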
      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto min_value = const_int_bound->min_value;
        auto max_value = const_int_bound->max_value;
        auto dtype = thread_var_->var.dtype();
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, min_value), IntImm(dtype, max_value + 1)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    }
  }

  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    if (call && call->op.same_as(builtin::tvm_access_ptr())) {
      auto var = call->args[1].as<Var>().value();
      return buffer_data_to_buffer_[var];
    }
    return NullOpt;
  }

  void addToUseList(const Buffer &buffer) {
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
    }
    use_list_[buffer].push_back(infer_idx);
  }

  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {
      auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_.push_back(std::move(infer));
      thread_var_vec_.push_back(thread_var_);
      if (thread_var_.defined() &&
          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto dtype = thread_var_->var.dtype();
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, const_int_bound->min_value),
            IntImm(dtype, const_int_bound->max_value + 1)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    } else {
      IRVisitorWithAnalyzer::VisitStmt(op->body);
    }
  }

  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
      auto map =
          op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<std::unique_ptr<Operator>> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  // This is a workaround for cpu backend,
  // we need to define a thread_var for the serial loop.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  std::vector<IterVar> thread_var_vec_;
  std::vector<Range> thread_bounds_vec_;
  Target target_;
  LayoutMap annotated_layout_map_;
  bool skip_thread_partition_{false};
};
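
/*!
 * \brief Apply the inference result to the PrimFunc: attach the layout map as
 *        a block annotation, partition parallel loops over the thread
 *        variable, vectorize loops that touch non-register memory, and wrap
 *        partitioned loops in their thread predicates.
 */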

class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector(skip_thread_partition);
    collector.Collect(f);
    auto result = collector.Run();
    LayoutInferencer substituter(result, skip_thread_partition, &analyzer);
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult result,
                   bool skip_thread_partition, arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
        skip_thread_partition_(skip_thread_partition) {}

  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
      if (buffer.scope() == "local.fragment") {
        ICHECK(result_.layout_map.count(buffer))
            << "Cannot inference fragment layout for " << buffer;
      }
    }
    auto block_ptr = block.CopyOnWrite();
    block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
    return block;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto root = GetRef<For>(op);
      // This check is a workaround to support T.Parallel for local buffers.
      // For example:
      //   for i in T.Parallel(1024):
      //     A_local[i] = A_global[i]
      // Here, A_local is a register-local buffer held independently by each
      // thread, so explicit thread binding is not required.
      //
      // We use PostOrderVisit to detect whether the buffer store targets a
      // "local" buffer, which indicates register usage and justifies skipping
      // thread binding.
      bool is_register_store = false;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() == "local") {
            is_register_store = true;
          }
        }
      });

      bool parallel_loop = !is_register_store && !skip_thread_partition_;
      if (parallel_loop) {
        auto loop_layout = result_.for_map[root];
        for_node =
            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      }
      // If no thread bindings are provided (skip_thread_partition_), the
      // loop is left unpartitioned.
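      // Only vectorize loops that access a buffer outside register scope
      // ("local" / "local.fragment"); purely register-level loops are left
      // unvectorized.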
      bool has_non_local = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (const auto *load = obj.as<BufferLoadNode>()) {
          String scope = load->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        } else if (const auto *store = obj.as<BufferStoreNode>()) {
          String scope = store->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        }
      });

      if (has_non_local) {
        for_node = VectorizeLoop(for_node);
      }

      if (result_.predicate_map.count(root) && parallel_loop) {
        return IfThenElse(result_.predicate_map[root], for_node);
      } else {
        return for_node;
      }
    }
    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  bool skip_thread_partition_{false};
};

tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
    ThreadBindingCollector collector;
    collector(f->body);
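    // If the function has no thread_extent bindings (e.g. a host/CPU
    // function), skip thread partitioning during layout inference.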
    bool has_thread_binding = collector.thread_binding_.size() > 0;
    bool skip_thread_partition = !has_thread_binding;
    return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}

TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
    .set_body_typed(LayoutInference);

} // namespace tl
} // namespace tvm