/*!
 * \file layout_inference.cc
 * \brief infer the fragment/shared memory layout
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Collect the thread bindings (thread_extent attributes) in a statement.
 */
class ThreadBindingCollector : public StmtExprVisitor {
public:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      thread_binding_[iv->var.get()] = iv;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // The thread binding map
  std::unordered_map<const VarNode *, IterVar> thread_binding_;
};

using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

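/*!
 * \brief Normalize parallel loops whose fragment accesses cover only part of
 *        a buffer dimension. When an index has a constant upper bound smaller
 *        than the buffer extent, the loop is widened to the full extent, the
 *        original loop variable is recovered through the inverse index map,
 *        and the body is guarded by a predicate (index < bound), so downstream
 *        layout inference always sees a full-extent loop nest.
 */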
class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
public:
  static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    ParallelLoopTransformer transformer(&analyzer);
    return transformer.VisitStmt(stmt);
  }

  ParallelLoopTransformer(arith::Analyzer *analyzer)
      : IRMutatorWithAnalyzer(analyzer) {}

  Stmt VisitStmt_(const ForNode *op) final {
    if (op->kind != ForKind::kParallel)
      return StmtMutator::VisitStmt_(op);

    // Collect loop variables and ranges
    auto for_node = GetRef<For>(op);
    Array<Var> loop_vars;
    Array<PrimExpr> loop_extents;
    Stmt body = op->body;

    // Bind the range of outer loop variables
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
    loop_vars.push_back(op->loop_var);
    loop_extents.push_back(op->extent);

    // If there are inner loops, bind their ranges as well
    while (const ForNode *inner = body.as<ForNode>()) {
      analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent));
      loop_vars.push_back(inner->loop_var);
      loop_extents.push_back(inner->extent);
      body = inner->body;
    }

    ICHECK(loop_vars.size() == loop_extents.size())
        << "loop_vars and loop_extents size mismatch";

    // Collect buffer access information
    BufferAccessCollector collector;
    collector(op->body);

    PrimExpr condition;

    for (const auto &[buffer, indices] : collector.buffer_indices) {
      ICHECK(indices.size() == buffer->shape.size())
          << "indices size mismatch with buffer shape";

      for (size_t i = 0; i < indices.size(); ++i) {
        auto index = indices[i];

        // Collect the variables used in the index
        std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
        // post order visit the index
        PostOrderVisit(index, [&](const ObjectRef &obj) {
          if (const VarNode *v = obj.as<VarNode>()) {
            used_vars.insert(GetRef<Var>(v));
          }
        });
        if (used_vars.size() == 0) {
          continue;
        }

        // find related loop vars
        Array<Var> related_loop_vars;
        for (size_t j = 0; j < loop_vars.size(); ++j) {
          auto loop_var = loop_vars[j];
          // Record the loop vars that appear in this index
          if (used_vars.count(loop_var)) {
            related_loop_vars.push_back(loop_var);
          }
          ICHECK(related_loop_vars.size() <= 1)
              << "Only one related loop var is supported currently, but got "
              << related_loop_vars
              << "; supporting multiple loop vars should not be too hard, "
              << "please file an issue if you run into this message.";

          auto bound = analyzer_->const_int_bound(index);
          int64_t upper_bound = bound->max_value + 1;
          int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
          if (upper_bound < shape) {
            PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound));
            condition =
                condition.defined() ? And(condition, predicate) : predicate;

            // replace the buffer index from A[i, r * 2] with A[i, j]
            // where r is the original index, j is the loop_var
            auto index_map = tir::IndexMap({loop_var}, {index});
            auto inverse_index_map = index_map.Inverse(
                {Range::FromMinExtent(0, IntImm(index.dtype(), upper_bound))},
                analyzer_);

            loop_extents.Set(i, IntImm(index.dtype(), shape));
            body = tir::Substitute(body,
                                   {{loop_var, inverse_index_map->MapIndices(
                                                   {loop_var}, analyzer_)[0]}});
          }
        }
      }
    }
    if (condition.defined()) {
      body = IfThenElse(condition, body);
      for (int j = loop_vars.size() - 1; j >= 0; --j) {
        auto loop_var = loop_vars[j];
        auto loop_extent = loop_extents[j];
        body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
      }
      return Downcast<For>(body);
    }
    // Nothing to rewrite here; return the outer loop unchanged.
    return for_node;
  }

private:
  // Helper class for collecting buffer access information; only accesses to
  // fragment buffers are recorded.
  class BufferAccessCollector : public StmtExprVisitor {
  public:
    void VisitExpr_(const BufferLoadNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitExpr_(op);
    }

    void VisitStmt_(const BufferStoreNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitStmt_(op);
    }

    std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
        buffer_indices;
  };
};

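/*!
 * \brief Result of layout inference: the inferred buffer layouts, the loop
 *        layout (Fragment) chosen for each parallel For, and the optional
 *        thread predicate attached to each loop.
 */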
struct LayoutInferenceResult {
  Map<Buffer, Layout> layout_map;
  Map<For, Fragment> for_map;
  Map<For, PrimExpr> predicate_map;
};

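/*!
 * \brief Collect the layout-inference operators (tile operators and parallel
 *        loops) of a PrimFunc, recording for each buffer the operators that
 *        touch it as well as the thread bounds observed at each site. Run()
 *        then propagates layout constraints across this use list.
 */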
class BufferUseDefCollector : public IRVisitorWithAnalyzer {
public:
  BufferUseDefCollector(bool skip_thread_partition)
      : skip_thread_partition_(skip_thread_partition) {}

  LayoutInferenceResult Run() {
    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
    // same size
    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
           "length.";
    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
           "length.";

    // Copy the annotated layout map to a local variable
    Map<Buffer, Layout> layout_map = annotated_layout_map_;
    int num_infer = infer_list_.size();

    // Prepare BFS queue for iterative inference
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
    for (int i = 0; i < num_infer; i++) {
      // Check that each infer_list_ entry is valid
      ICHECK(infer_list_[i] != nullptr)
          << "infer_list_[" << i
          << "] is null. The inference object is not allocated properly.";

      // Fall back to the default thread var when none was recorded and
      // thread partitioning is skipped.
      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
        thread_var_vec_[i] = thread_var_;
      }
      q.push(i);
    }
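    // Run layout inference for a single operator at the given level. When
    // update_queue is set, operators that use a buffer whose layout has just
    // been fixed are pushed back onto the BFS queue for re-inference.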
    auto run_infer_step = [&](int cur_infer_id, InferLevel level,
                              bool update_queue) {
      // Range check for cur_infer_id
      ICHECK_GE(cur_infer_id, 0)
          << "cur_infer_id is negative, which is invalid.";
      ICHECK_LT(cur_infer_id, num_infer)
          << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
          << num_infer << ".";

      // Make sure we can safely access infer_list_[cur_infer_id] and
      // thread_var_vec_[cur_infer_id]
      auto &next = infer_list_[cur_infer_id];
      auto iter_var = thread_var_vec_[cur_infer_id];
      auto thread_bounds = thread_bounds_vec_[cur_infer_id];
      // Double-check that 'next' is valid
      ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
                              << "] is null inside run_infer_step.";

      // Check iter_var->dom and dom->extent
      ICHECK(iter_var.defined())
          << "thread_var_vec_[" << cur_infer_id << "] is not defined.";
      ICHECK(iter_var->dom.defined())
          << "iter_var->dom is not defined for infer_list_[" << cur_infer_id
          << "].";
      ICHECK(iter_var->dom->extent.defined())
          << "iter_var->dom->extent is not defined for infer_list_["
          << cur_infer_id << "].";

      const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
      ICHECK(extent_ptr != nullptr)
          << "iter_var->dom->extent is not a constant integer, which is "
             "required for layout inference.";

      // Run InferLayout
      auto updates = next->InferLayout(
          LayoutInferArgs{target_, thread_bounds, layout_map}, level);
      // Process the returned updates
      for (const auto &[buffer, layout] : updates) {
        // Basic validity checks
        ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
        ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";

        if (layout_map.count(buffer)) {
          // If already in map, ensure they are structurally equal
          ICHECK(StructuralEqual()(layout, layout_map[buffer]))
              << "Get different layout for " << buffer
              << " current layout: " << layout->DebugOutput()
              << " previous layout: " << layout_map[buffer]->DebugOutput();
        } else {
          // Otherwise, update map
          layout_map.Set(buffer, layout);
          if (!update_queue)
            continue;

          // Check if buffer exists in use_list_
          if (!use_list_.count(buffer)) {
            LOG(WARNING) << "Buffer " << buffer << " not found in use_list_. "
                         << "Potential mismatch between inference updates and "
                         << "use_list_.";
            continue;
          }

          // Push back into BFS queue
          for (int idx : use_list_[buffer]) {
            ICHECK_GE(idx, 0) << "Index in use_list_ for buffer " << buffer
                              << " is negative.";
            ICHECK_LT(idx, num_infer)
                << "Index in use_list_ for buffer " << buffer
                << " out of range: " << idx << " >= " << num_infer << ".";

            if (!in_queue[idx] && idx != cur_infer_id) {
              in_queue[idx] = true;
              q.push(idx);
            }
          }
        }
      }
    };

    auto finish_infer_queue = [&]() {
      while (!q.empty()) {
        int cur_infer_id = q.front();
        q.pop();
        // Range check again, just to be safe
        ICHECK_GE(cur_infer_id, 0);
        ICHECK_LT(cur_infer_id, num_infer);

        in_queue[cur_infer_id] = false;
        run_infer_step(cur_infer_id, InferLevel::kCommon, true);
      }
    };

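    // Inference runs in three passes, from the most to the least constrained
    // level: kStrict seeds layouts without touching the BFS queue, kCommon
    // propagates constraints between operators that share a buffer, and kFree
    // relaxes whatever remains and re-propagates.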
    // step 1: infer strict layout
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kStrict, false);
    }

    // step 2: infer common layout with BFS
    finish_infer_queue();

    // step 3: relax constraints to free and re-run
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kFree, true);
      finish_infer_queue();
    }

    // Check that all local.fragment buffers have inferred layouts
    for (const auto &[buffer, _] : use_list_) {
      if (buffer.scope() == "local.fragment") {
        ICHECK_NE(layout_map.count(buffer), 0)
            << "The layout for fragment " << buffer
            << " can not be inferred correctly.";
      }
    }

    // Collect layout info for For nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    ICHECK(infer_list_.size() == thread_var_vec_.size())
        << "infer_list_ and thread_var_vec_ size mismatch";
    for (int i = 0; i < infer_list_.size(); i++) {
      std::unique_ptr<Operator> base_infer = std::move(infer_list_[i]);
      auto thread_var = thread_var_vec_[i];

      // Check if base_infer is valid
      ICHECK(base_infer != nullptr) << "Null pointer encountered in "
                                       "infer_list_ while collecting for_map.";
      if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
        // Check that the loop layout is defined
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The layout for the parallel loop cannot be inferred correctly:\n"
            << for_infer->GetRoot();
        for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
        // thread_var_ should be defined if we rely on it
        ICHECK(thread_var.defined())
            << "thread_var is not defined. Cannot retrieve predicate.";

        if (auto predicate = for_infer->GetPredicate(thread_var->var)) {
          predicate_map.Set(for_infer->GetRoot(), predicate.value());
        }
      }
    }

    return {layout_map, for_map, predicate_map};
  }

  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "Layout_Inference: Require the target attribute";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
  void VisitExpr_(const CallNode *op) final {
    IRVisitorWithAnalyzer::VisitExpr_(op);
    // Do not analyze calls to global functions.
    if (op->op.as<GlobalVarNode>())
      return;

    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p != nullptr) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
      }
      infer_list_.push_back(std::move(p));
      thread_var_vec_.push_back(thread_var_);
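      // Record the thread bounds for this operator. For example (illustrative
      // numbers), with blockDim.x = 128 the analyzer bounds threadIdx.x to
      // [0, 127] and we store the half-open range [0, 128); without a known
      // bound we fall back to [0, 1).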
      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto min_value = const_int_bound->min_value;
        auto max_value = const_int_bound->max_value;
        auto extent = max_value - min_value + 1;
        auto dtype = thread_var_->var.dtype();
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    }
  }

  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    if (call && call->op.same_as(builtin::tvm_access_ptr())) {
      auto var = call->args[1].as<Var>().value();
      return buffer_data_to_buffer_[var];
    }
    return NullOpt;
  }

  void addToUseList(const Buffer &buffer) {
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
    }
    use_list_[buffer].push_back(infer_idx);
  }

  void VisitStmt_(const ForNode *op) final {
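    // Each parallel loop becomes a ParallelOp inference operator; record the
    // fragment buffers it touches so layout updates can be propagated to it.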
    if (op->kind == ForKind::kParallel) {
      auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_.push_back(std::move(infer));
      thread_var_vec_.push_back(thread_var_);
      if (thread_var_.defined() &&
          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto dtype = thread_var_->var.dtype();
        auto extent =
            const_int_bound->max_value - const_int_bound->min_value + 1;
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    } else {
      IRVisitorWithAnalyzer::VisitStmt(op->body);
    }
  }

  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
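    // Seed inference with any user-annotated layouts attached to the block
    // via the attr::kLayoutMap annotation.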
    if (op->annotations.count(attr::kLayoutMap)) {
      auto map =
          op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<std::unique_ptr<Operator>> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  // This is a workaround for the CPU backend: we need to define a thread_var
  // for the serial loop.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  std::vector<IterVar> thread_var_vec_;
  std::vector<Range> thread_bounds_vec_;
  Target target_;
  LayoutMap annotated_layout_map_;
  bool skip_thread_partition_{false};
};

class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) {
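    // Pipeline: fuse adjacent parallel loops, collect buffer use-def
    // information and run layout inference, then rewrite the body with the
    // inferred layouts.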
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector(skip_thread_partition);
    collector.Collect(f);
    auto result = collector.Run();
    LayoutInferencer substituter(result, skip_thread_partition, &analyzer);
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult result,
                   bool skip_thread_partition, arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
        skip_thread_partition_(skip_thread_partition){};

  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
      if (buffer.scope() == "local.fragment") {
        ICHECK(result_.layout_map.count(buffer))
            << "Cannot inference fragment layout for " << buffer;
      }
    }
    auto block_ptr = block.CopyOnWrite();
    block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
    return block;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto root = GetRef<For>(op);
      // This check is a workaround to support T.Parallel for local buffers.
      // For example:
      //   for i in T.Parallel(1024):
      //     A_local[i] = A_global[i]
      // Here, A_local is a register-local buffer held independently by each
      // thread, so explicit thread binding is not required.
      //
      // We use PostOrderVisit to detect whether the buffer store targets a
      // "local" buffer, which indicates register usage and justifies skipping
      // thread binding.
      bool is_register_store = false;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() == "local") {
            is_register_store = true;
          }
        }
      });

      auto loop_layout = result_.for_map[root];
      bool parallel_loop = !is_register_store && !skip_thread_partition_;
      if (parallel_loop) {
        for_node =
            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      }
      // Vectorize only when the loop touches buffers outside registers
      // (any scope other than "local" / "local.fragment").
      bool has_non_local = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (const auto *load = obj.as<BufferLoadNode>()) {
          String scope = load->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        } else if (const auto *store = obj.as<BufferStoreNode>()) {
          String scope = store->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        }
      });

      if (has_non_local) {
        for_node = VectorizeLoop(for_node);
      }

      if (result_.predicate_map.count(root) && parallel_loop) {
        return IfThenElse(result_.predicate_map[root], for_node);
      } else {
        return for_node;
      }
    }
    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  bool skip_thread_partition_{false};
};

tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
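    // First normalize partial-coverage parallel loops, then check whether the
    // kernel already has thread bindings; if none exist, inference runs in
    // "skip thread partition" mode (e.g. for the CPU backend).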
    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
    ThreadBindingCollector collector;
    collector(f->body);
    bool has_thread_binding = collector.thread_binding_.size() > 0;
    bool skip_thread_partition = !has_thread_binding;
    return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}

TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
    .set_body_typed(LayoutInference);

} // namespace tl
} // namespace tvm