/*!
 * \file layout_inference.cc
 * \brief infer the fragment/shared memory layout
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Collect the thread bindings (tir::attr::thread_extent) attached to a
 * statement.
 */
class ThreadBindingCollector : public StmtExprVisitor {
public:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      thread_binding_[iv->var.get()] = iv;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // The thread binding map
  std::unordered_map<const VarNode *, IterVar> thread_binding_;
};

using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

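// ParallelLoopTransformer normalizes T.Parallel loop nests whose fragment
// accesses only cover part of a buffer dimension: it widens the loop extent to
// the full buffer shape, rewrites the body through the inverse index map, and
// guards it with a bound predicate so the extra iterations become no-ops.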
class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
public:
  static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    ParallelLoopTransformer transformer(&analyzer);
    return transformer.VisitStmt(stmt);
  }

  ParallelLoopTransformer(arith::Analyzer *analyzer)
      : IRMutatorWithAnalyzer(analyzer) {}

  Stmt VisitStmt_(const ForNode *op) final {
    if (op->kind != ForKind::kParallel)
      return StmtMutator::VisitStmt_(op);

    // Collect loop variables and ranges
    auto for_node = GetRef<For>(op);
    Array<Var> loop_vars;
    Array<PrimExpr> loop_extents;
    Stmt body = op->body;

    // Bind the range of outer loop variables
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
    loop_vars.push_back(op->loop_var);
    loop_extents.push_back(op->extent);

    // If there are inner loops, bind their ranges as well
    while (const ForNode *inner = body.as<ForNode>()) {
      analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent));
      loop_vars.push_back(inner->loop_var);
      loop_extents.push_back(inner->extent);
      body = inner->body;
    }

    ICHECK(loop_vars.size() == loop_extents.size())
        << "loop_vars and loop_extents size mismatch";

    // Collect buffer access information
    BufferAccessCollector collector;
    collector(op->body);

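    // Conjunction of the per-index bound predicates accumulated below; stays
    // undefined when no access needs guarding.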
    PrimExpr condition;

    for (const auto &[buffer, indices] : collector.buffer_indices) {
      ICHECK(indices.size() == buffer->shape.size())
          << "indices size mismatch with buffer shape";

      for (size_t i = 0; i < indices.size(); ++i) {
        auto index = indices[i];

        // Collect the variables used in this index
        std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
        // post order visit the index
        PostOrderVisit(index, [&](const ObjectRef &obj) {
          if (const VarNode *v = obj.as<VarNode>()) {
            used_vars.insert(GetRef<Var>(v));
          }
        });
        if (used_vars.size() == 0) {
          continue;
        }

        // find related loop vars
        Array<Var> related_loop_vars;
        for (size_t j = 0; j < loop_vars.size(); ++j) {
          auto loop_var = loop_vars[j];
          // record the loop vars that appear in this index
          if (used_vars.count(loop_var)) {
            related_loop_vars.push_back(loop_var);
          }
          ICHECK(related_loop_vars.size() <= 1)
              << "Only one related loop var is supported currently, but got "
              << related_loop_vars
              << "; supporting multiple loop vars should not be hard, so "
              << "please open an issue if you run into this message.";

          auto bound = analyzer_->const_int_bound(index);
          int64_t upper_bound = bound->max_value + 1;
          int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
          if (upper_bound < shape) {
            PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound));
            condition =
                condition.defined() ? And(condition, predicate) : predicate;

            // replace the buffer index from A[i, r * 2] with A[i, j]
            // where r is the original index, j is the loop_var
            auto index_map = tir::IndexMap({loop_var}, {index});
            auto inverse_index_map = index_map.Inverse(
                {Range::FromMinExtent(0, IntImm(index.dtype(), upper_bound))},
                analyzer_);

            loop_extents.Set(i, IntImm(index.dtype(), shape));
            body = tir::Substitute(body,
                                   {{loop_var, inverse_index_map->MapIndices(
                                                   {loop_var}, analyzer_)[0]}});
          }
        }
      }
    }
    if (condition.defined()) {
      body = IfThenElse(condition, body);
      for (int j = loop_vars.size() - 1; j >= 0; --j) {
        auto loop_var = loop_vars[j];
        auto loop_extent = loop_extents[j];
        body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
      }
      return Downcast<For>(body);
    }
    // Only the outermost parallel loop is inspected; return it unchanged.
    return for_node;
  }

private:
  // Helper class for collecting buffer access information; it only records
  // accesses to local.fragment buffers.
  class BufferAccessCollector : public StmtExprVisitor {
  public:
    void VisitExpr_(const BufferLoadNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitExpr_(op);
    }

    void VisitStmt_(const BufferStoreNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitStmt_(op);
    }

    std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
        buffer_indices;
  };
};

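// Aggregated result of layout inference: the inferred layout for each buffer,
// the loop layout chosen for each parallel For, and the thread predicate (if
// any) that must guard that For.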
struct LayoutInferenceResult {
  Map<Buffer, Layout> layout_map;
  Map<For, Fragment> for_map;
  Map<For, PrimExpr> predicate_map;
};

class BufferUseDefCollector : public IRVisitorWithAnalyzer {
public:
  BufferUseDefCollector(bool skip_thread_partition)
      : skip_thread_partition_(skip_thread_partition) {}

  LayoutInferenceResult Run() {
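    // Inference runs in three rounds: (1) strict inference for each operator,
    // (2) common-level propagation via BFS over operators that share buffers,
    // and (3) a final pass at the free level to settle anything still open.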
    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
    // same size
    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
           "length.";
    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
           "length.";

    // If needed, you can also check that annotated_layout_map_ is not empty, or
    // anything else relevant to your setup.

    // Copy the annotated layout map to local variable
    Map<Buffer, Layout> layout_map = annotated_layout_map_;
    int num_infer = infer_list_.size();

    // Prepare BFS queue for iterative inference
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
    for (int i = 0; i < num_infer; i++) {
      // Check that each infer_list_ entry is valid
      ICHECK(infer_list_[i] != nullptr)
          << "infer_list_[" << i
          << "] is null. The inference object is not allocated properly.";

      // Check that each thread_var_vec_ entry is defined
      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
        thread_var_vec_[i] = thread_var_;
      }
      q.push(i);
    }
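    // Run layout inference for a single operator at the given level, merge the
    // returned buffer layouts into layout_map, and (optionally) re-queue every
    // operator that uses a buffer whose layout was newly determined.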
    auto run_infer_step = [&](int cur_infer_id, InferLevel level,
                              bool update_queue) {
      // Range check for cur_infer_id
      ICHECK_GE(cur_infer_id, 0)
          << "cur_infer_id is negative, which is invalid.";
      ICHECK_LT(cur_infer_id, num_infer)
          << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
          << num_infer << ".";

      // Make sure we can safely access infer_list_[cur_infer_id] and
      // thread_var_vec_[cur_infer_id]
      auto &next = infer_list_[cur_infer_id];
      auto iter_var = thread_var_vec_[cur_infer_id];
      auto thread_bounds = thread_bounds_vec_[cur_infer_id];
      // Double-check that 'next' is valid
      ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
                              << "] is null inside run_infer_step.";

      // Check iter_var->dom and dom->extent
      ICHECK(iter_var.defined())
          << "thread_var_vec_[" << cur_infer_id << "] is not defined.";
      ICHECK(iter_var->dom.defined())
          << "iter_var->dom is not defined for infer_list_[" << cur_infer_id
          << "].";
      ICHECK(iter_var->dom->extent.defined())
          << "iter_var->dom->extent is not defined for infer_list_["
          << cur_infer_id << "].";

      const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
      ICHECK(extent_ptr != nullptr)
          << "iter_var->dom->extent is not a constant integer, which is "
             "required for layout inference.";

      // Run InferLayout
      auto updates = next->InferLayout(
          LayoutInferArgs{target_, thread_bounds, layout_map}, level);
      // Process the returned updates
      for (const auto &[buffer, layout] : updates) {
        // Basic validity checks
        ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
        ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";

        if (layout_map.count(buffer)) {
          // If already in map, ensure they are structurally equal
          ICHECK(StructuralEqual()(layout, layout_map[buffer]))
              << "Got a different layout for " << buffer
              << "\n current layout: " << layout->DebugOutput()
              << "\n previous layout: " << layout_map[buffer]->DebugOutput();
        } else {
          // Otherwise, update map
          layout_map.Set(buffer, layout);
          if (!update_queue)
            continue;

          // Check if buffer exists in use_list_
          if (!use_list_.count(buffer)) {
            LOG(WARNING) << "Layout inference failed for buffer " << buffer
                         << ". "
                         << "The buffer cannot be inferred with current layout "
                            "inference rules.";
            continue;
          }

          // Push back into BFS queue
          for (int idx : use_list_[buffer]) {
            ICHECK_GE(idx, 0) << "Index in use_list_ for buffer " << buffer
                              << " is negative.";
            ICHECK_LT(idx, num_infer)
                << "Index in use_list_ for buffer " << buffer
                << " out of range: " << idx << " >= " << num_infer << ".";

            if (!in_queue[idx] && idx != cur_infer_id) {
              in_queue[idx] = true;
              q.push(idx);
            }
          }
        }
      }
    };

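    // Drain the BFS queue, running common-level inference for each pending
    // operator until the queue is empty.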
    auto finish_infer_queue = [&]() {
      while (!q.empty()) {
        int cur_infer_id = q.front();
        q.pop();
        // Range check again, just to be safe
        ICHECK_GE(cur_infer_id, 0);
        ICHECK_LT(cur_infer_id, num_infer);

        in_queue[cur_infer_id] = false;
        run_infer_step(cur_infer_id, InferLevel::kCommon, true);
      }
    };

    // step 1: infer strict layout
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kStrict, false);
    }

    // step 2: infer common layout with BFS
    finish_infer_queue();

    // step 3: relax constraints to free and re-run
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kFree, true);
      finish_infer_queue();
    }

    // Check that all local.fragment buffers have inferred layouts
    for (const auto &[buffer, _] : use_list_) {
      if (buffer.scope() == "local.fragment") {
        ICHECK_NE(layout_map.count(buffer), 0)
            << "The layout for fragment " << buffer
            << " can not be inferred correctly.";
      }
    }

    // Collect layout info for For nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    ICHECK(infer_list_.size() == thread_var_vec_.size())
        << "infer_list_ and thread_var_vec_ size mismatch";
    for (int i = 0; i < infer_list_.size(); i++) {
      std::unique_ptr<Operator> base_infer = std::move(infer_list_[i]);
      auto thread_var = thread_var_vec_[i];

      // Check if base_infer is valid
      ICHECK(base_infer != nullptr) << "Null pointer encountered in "
                                       "infer_list_ while collecting for_map.";
      if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
        // Check that the loop layout is defined
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The layout for this parallel For cannot be inferred correctly:\n"
            << for_infer->GetRoot();
        for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
        // thread_var_ should be defined if we rely on it
        ICHECK(thread_var.defined())
            << "thread_var is not defined. Cannot retrieve predicate.";

        if (auto predicate = for_infer->GetPredicate(thread_var->var)) {
          predicate_map.Set(for_infer->GetRoot(), predicate.value());
        }
      }
    }

    return {layout_map, for_map, predicate_map};
  }

  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "Layout_Inference: Require the target attribute";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
  void VisitExpr_(const CallNode *op) final {
    IRVisitorWithAnalyzer::VisitExpr_(op);
    // Do not analyze calls to global functions.
    if (op->op.as<GlobalVarNode>())
      return;

    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p != nullptr) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
      }
      infer_list_.push_back(std::move(p));
      thread_var_vec_.push_back(thread_var_);
      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto min_value = const_int_bound->min_value;
        auto max_value = const_int_bound->max_value;
        auto extent = max_value - min_value + 1;
        auto dtype = thread_var_->var.dtype();
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    }
  }

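  // Extract the Buffer referenced by a builtin::tvm_access_ptr() argument;
  // returns NullOpt for any other expression.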
  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    if (call && call->op.same_as(builtin::tvm_access_ptr())) {
      auto var = call->args[1].as<Var>().value();
      return buffer_data_to_buffer_[var];
    }
    return NullOpt;
  }

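  // Record that the operator about to be appended to infer_list_ reads or
  // writes `buffer`, so layout updates on that buffer re-trigger its inference.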
  void addToUseList(const Buffer &buffer) {
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
    }
    use_list_[buffer].push_back(infer_idx);
  }

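  // Every T.Parallel loop becomes its own inference unit (ParallelOp), recorded
  // together with the thread binding and thread bounds active at that point.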
  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {
      auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_.push_back(std::move(infer));
      thread_var_vec_.push_back(thread_var_);
      if (thread_var_.defined() &&
          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto dtype = thread_var_->var.dtype();
        auto extent =
            const_int_bound->max_value - const_int_bound->min_value + 1;
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    } else {
      IRVisitorWithAnalyzer::VisitStmt(op->body);
    }
  }

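  // Register the block's allocation buffers and pick up any user-annotated
  // layouts attached via the kLayoutMap annotation.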
  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
      auto map =
          op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<std::unique_ptr<Operator>> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  // This is a workaround for the CPU backend: we need to define a thread_var
  // for the serial loop.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  std::vector<IterVar> thread_var_vec_;
  std::vector<Range> thread_bounds_vec_;
  Target target_;
  LayoutMap annotated_layout_map_;
  bool skip_thread_partition_{false};
};
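
// LayoutInferencer consumes the LayoutInferenceResult: it attaches the buffer
// layout map to blocks, partitions each inferred parallel loop over
// threadIdx.x, vectorizes loops that touch non-register buffers, and wraps
// partitioned loops with their thread predicates.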

class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector(skip_thread_partition);
    collector.Collect(f);
    auto result = collector.Run();
    LayoutInferencer substituter(result, skip_thread_partition, &analyzer);
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult result,
                   bool skip_thread_partition, arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
        skip_thread_partition_(skip_thread_partition) {}

  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
      if (buffer.scope() == "local.fragment") {
        ICHECK(result_.layout_map.count(buffer))
            << "Cannot inference fragment layout for " << buffer;
      }
    }
    auto block_ptr = block.CopyOnWrite();
    block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
    return block;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto root = GetRef<For>(op);
      // This check is a workaround to support T.Parallel for local buffers.
      // For example:
      //   for i in T.Parallel(1024):
      //     A_local[i] = A_global[i]
      // Here, A_local is a register-local buffer held independently by each
      // thread, so explicit thread binding is not required.
      //
      // We use PostOrderVisit to detect whether the buffer store targets a
      // "local" buffer, which indicates register usage and justifies skipping
      // thread binding.
      bool is_register_store = false;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() == "local") {
            is_register_store = true;
          }
        }
      });

      auto loop_layout = result_.for_map[root];
      bool parallel_loop = !is_register_store && !skip_thread_partition_;
      if (parallel_loop) {
        for_node =
            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      }
      // When no thread binding is available (skip_thread_partition_) or the
      // loop only writes registers, the loop above is left unpartitioned.
      // Vectorize only loops that touch a buffer outside the register scopes
      // (local / local.fragment).
      bool has_non_local = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (const auto *load = obj.as<BufferLoadNode>()) {
          String scope = load->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        } else if (const auto *store = obj.as<BufferStoreNode>()) {
          String scope = store->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        }
      });

      if (has_non_local) {
        for_node = VectorizeLoop(for_node);
      }

      if (result_.predicate_map.count(root) && parallel_loop) {
        return IfThenElse(result_.predicate_map[root], for_node);
      } else {
        return for_node;
      }
    }
    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  bool skip_thread_partition_{false};
};

tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
    ThreadBindingCollector collector;
    collector(f->body);
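    // If the kernel has no explicit thread bindings (e.g. the CPU backend),
    // skip thread partitioning so parallel loops are not split over threads.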
    bool has_thread_binding = collector.thread_binding_.size() > 0;
    bool skip_thread_partition = !has_thread_binding;
    return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}

TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
    .set_body_typed(LayoutInference);

} // namespace tl
} // namespace tvm