/*!
 * \file layout_inference.cc
 * \brief infer the fragment/shared memory layout
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Collect the mapping from thread variables to their thread bindings.
 */
class ThreadBindingCollector : public StmtExprVisitor {
public:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      thread_binding_[iv->var.get()] = iv;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // The thread binding map
  std::unordered_map<const VarNode *, IterVar> thread_binding_;
};

using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;
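
/*!
 * \brief Canonicalize parallel loop nests whose fragment-buffer accesses only
 *        cover part of a buffer dimension.
 *
 * For an access index such as the `r * 2` in `A[i, r * 2]`, whose constant
 * upper bound is smaller than the buffer extent, the loop extent is widened to
 * the buffer shape, the index is remapped through the inverse index map, and
 * the rebuilt loop body is guarded by a predicate so that out-of-range
 * iterations are skipped.
 */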
class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
public:
  static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    ParallelLoopTransformer transformer(&analyzer);
    return transformer.VisitStmt(stmt);
  }

  ParallelLoopTransformer(arith::Analyzer *analyzer)
      : IRMutatorWithAnalyzer(analyzer) {}

  Stmt VisitStmt_(const ForNode *op) final {
    if (op->kind != ForKind::kParallel)
      return StmtMutator::VisitStmt_(op);

    // Collect loop variables and ranges
    auto for_node = GetRef<For>(op);
    Array<Var> loop_vars;
    Array<PrimExpr> loop_extents;
    Stmt body = op->body;

    // Bind the range of outer loop variables
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
    loop_vars.push_back(op->loop_var);
    loop_extents.push_back(op->extent);

    // If there are inner loops, bind their ranges as well
    while (const ForNode *inner = body.as<ForNode>()) {
      analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent));
      loop_vars.push_back(inner->loop_var);
      loop_extents.push_back(inner->extent);
      body = inner->body;
    }

    ICHECK(loop_vars.size() == loop_extents.size())
        << "loop_vars and loop_extents size mismatch";

    // Collect buffer access information
    BufferAccessCollector collector;
    collector(op->body);

    PrimExpr condition;

    for (const auto &[buffer, indices] : collector.buffer_indices) {
      ICHECK(indices.size() == buffer->shape.size())
          << "indices size mismatch with buffer shape";

      for (size_t i = 0; i < indices.size(); ++i) {
        auto index = indices[i];
        auto bound = analyzer_->const_int_bound(index);
        int64_t upper_bound = bound->max_value + 1;
        int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
        // Collect the variables that are used in the index
        std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
        // post order visit the index
        PostOrderVisit(index, [&](const ObjectRef &obj) {
          if (const VarNode *v = obj.as<VarNode>()) {
            used_vars.insert(GetRef<Var>(v));
          }
        });
        if (used_vars.size() == 0) {
          continue;
        }
        // find related loop vars
        Array<Var> related_loop_vars;
        for (size_t j = 0; j < loop_vars.size(); ++j) {
          auto loop_var = loop_vars[j];
          // record the loop vars that appear in this index
          if (used_vars.count(loop_var)) {
            related_loop_vars.push_back(loop_var);
          }
          ICHECK(related_loop_vars.size() <= 1)
              << "Only one related loop var is supported currently, but got "
              << related_loop_vars
              << ". Supporting multiple loop vars may not be too hard; "
              << "please file an issue if you run into this message.";
          auto bound = analyzer_->const_int_bound(index);
          int64_t upper_bound = bound->max_value + 1;
          int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
          if (upper_bound < shape) {
            PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound));
            condition =
                condition.defined() ? And(condition, predicate) : predicate;

            // rewrite the buffer index, e.g. from A[i, r * 2] to A[i, j],
            // where r * 2 is the original index and j is the loop_var
            auto index_map = tir::IndexMap({loop_var}, {index});
            auto inverse_index_map = index_map.Inverse(
                {Range::FromMinExtent(0, IntImm(index.dtype(), upper_bound))},
                analyzer_);

            loop_extents.Set(i, IntImm(index.dtype(), shape));
            body = tir::Substitute(body,
                                   {{loop_var, inverse_index_map->MapIndices(
                                                   {loop_var}, analyzer_)[0]}});
          }
        }
      }
    }
    if (condition.defined()) {
      body = IfThenElse(condition, body);
      for (int j = loop_vars.size() - 1; j >= 0; --j) {
        auto loop_var = loop_vars[j];
        auto loop_extent = loop_extents[j];
        body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
      }
      return Downcast<For>(body);
    }
    // Only traverse the outer loop
    return for_node;
  }

private:
  // Helper class for collecting buffer access information; only fragment
  // buffer accesses are recorded.
  class BufferAccessCollector : public StmtExprVisitor {
  public:
    void VisitExpr_(const BufferLoadNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitExpr_(op);
    }

    void VisitStmt_(const BufferStoreNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitStmt_(op);
    }

    std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
        buffer_indices;
  };
};

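// The result of layout inference: per-buffer layouts, the loop fragment chosen
// for each parallel For node, and an optional thread predicate per For node.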
struct LayoutInferenceResult {
  Map<Buffer, Layout> layout_map;
  Map<For, Fragment> for_map;
  Map<For, PrimExpr> predicate_map;
};

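/*!
 * \brief Collect buffer uses/defs for layout-inferable operators and run the
 *        actual layout inference.
 *
 * Inference proceeds in three levels: strict constraints first, then
 * common-level constraints propagated over a BFS work queue of operators that
 * share a buffer, and finally free-level inference for anything still
 * undetermined.
 */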
class BufferUseDefCollector : public IRVisitorWithAnalyzer {
public:
  BufferUseDefCollector(bool skip_thread_partition)
      : skip_thread_partition_(skip_thread_partition) {}

  LayoutInferenceResult Run() {
    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
    // same size
    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
           "length.";
    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
           "length.";

    // If needed, you can also check that annotated_layout_map_ is not empty, or
    // anything else relevant to your setup.

    // Copy the annotated layout map to local variable
    Map<Buffer, Layout> layout_map = annotated_layout_map_;
    Map<Buffer, Layout> strict_layout_map;
    int num_infer = infer_list_.size();

    // Prepare BFS queue for iterative inference
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
    for (int i = 0; i < num_infer; i++) {
      // Check that each infer_list_ entry is valid
      ICHECK(infer_list_[i] != nullptr)
          << "infer_list_[" << i
          << "] is null. The inference object is not allocated properly.";

      // Check that each thread_var_vec_ entry is defined
      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
        thread_var_vec_[i] = thread_var_;
      }
      q.push(i);
    }

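    // Run layout inference for one operator at the given inference level,
    // merge the returned buffer layouts into layout_map, and (when
    // update_queue is set) enqueue every operator that uses a newly inferred
    // buffer.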
    auto run_infer_step = [&](int cur_infer_id, InferLevel level,
                              bool update_queue) {
      // Range check for cur_infer_id
      ICHECK_GE(cur_infer_id, 0)
          << "cur_infer_id is negative, which is invalid.";
      ICHECK_LT(cur_infer_id, num_infer)
          << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
          << num_infer << ".";

      // Make sure we can safely access infer_list_[cur_infer_id] and
      // thread_var_vec_[cur_infer_id]
      auto &next = infer_list_[cur_infer_id];
      auto iter_var = thread_var_vec_[cur_infer_id];
      auto thread_bounds = thread_bounds_vec_[cur_infer_id];

      // Double-check that 'next' is valid
      ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
                              << "] is null inside run_infer_step.";

      // Check iter_var->dom and dom->extent
      ICHECK(iter_var.defined())
          << "thread_var_vec_[" << cur_infer_id << "] is not defined.";
      ICHECK(iter_var->dom.defined())
          << "iter_var->dom is not defined for infer_list_[" << cur_infer_id
          << "].";
      ICHECK(iter_var->dom->extent.defined())
          << "iter_var->dom->extent is not defined for infer_list_["
          << cur_infer_id << "].";

      const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
      ICHECK(extent_ptr != nullptr)
          << "iter_var->dom->extent is not a constant integer, which is "
             "required for layout inference.";

      // Run InferLayout
      auto updates = next->InferLayout(
          LayoutInferArgs{target_, thread_bounds, layout_map}, level);
      // Process the returned updates
      for (const auto &[buffer, layout] : updates) {
        // Basic validity checks
        ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
        ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";

        if (layout_map.count(buffer)) {
          // If the replicate size of this buffer is greater than the old one
          if (buffer.scope() == "local.fragment" &&
              level != InferLevel::kStrict &&
              !strict_layout_map.count(buffer)) {
            const FragmentNode *dst_layout = layout.as<Fragment>().get();
            const FragmentNode *src_layout =
                layout_map[buffer].as<Fragment>().get();
            if (as_const_int(dst_layout->ReplicateExtent()) &&
                as_const_int(src_layout->ReplicateExtent()) &&
                (*as_const_int(dst_layout->ReplicateExtent()) >
                 *as_const_int(src_layout->ReplicateExtent()))) {
              // update map
              layout_map.Set(buffer, layout);
              continue;
            }
          }
          // If already in map, ensure they are structurally equal
          ICHECK(StructuralEqual()(layout, layout_map[buffer]))
              << "Got different layouts for " << buffer
              << "\n current layout: " << layout->DebugOutput()
              << "\n previous layout: " << layout_map[buffer]->DebugOutput();
        } else {
          // Otherwise, update map
          layout_map.Set(buffer, layout);
          if (!update_queue)
            continue;

          // Check if buffer exists in use_list_
          if (!use_list_.count(buffer)) {
            LOG(WARNING) << "Layout inference failed for buffer " << buffer
                         << ". "
                         << "The buffer cannot be inferred with current layout "
                            "inference rules.";
            continue;
          }

          // Push back into BFS queue
          for (int idx : use_list_[buffer]) {
            ICHECK_GE(idx, 0) << "Index in use_list_ for buffer " << buffer
                              << " is negative.";
            ICHECK_LT(idx, num_infer)
                << "Index in use_list_ for buffer " << buffer
                << " out of range: " << idx << " >= " << num_infer << ".";

            if (!in_queue[idx] && idx != cur_infer_id) {
              in_queue[idx] = true;
              q.push(idx);
            }
          }
        }
      }
    };
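    // Drain the BFS queue, running common-level inference until no further
    // updates propagate.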
    auto finish_infer_queue = [&]() {
      while (!q.empty()) {
        int cur_infer_id = q.front();
        q.pop();
        // Range check again, just to be safe
        ICHECK_GE(cur_infer_id, 0);
        ICHECK_LT(cur_infer_id, num_infer);

        in_queue[cur_infer_id] = false;
        run_infer_step(cur_infer_id, InferLevel::kCommon, true);
      }
    };

    // step 1: infer strict layout
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kStrict, false);
    }

    for (const auto &[buffer, layout] : layout_map) {
      strict_layout_map.Set(buffer, layout);
    }

    // step 2: infer common layout with BFS
    finish_infer_queue();

    // step 3: relax constraints to free and re-run
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kFree, true);
      finish_infer_queue();
    }
    // Check that all local.fragment buffers have inferred layouts
    for (const auto &[buffer, _] : use_list_) {
      if (buffer.scope() == "local.fragment") {
        ICHECK_NE(layout_map.count(buffer), 0)
            << "The layout for fragment " << buffer
            << " can not be inferred correctly.";
      }
    }

    // Collect layout info for For nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    ICHECK(infer_list_.size() == thread_var_vec_.size())
        << "infer_list_ and thread_var_vec_ size mismatch";
    for (int i = 0; i < infer_list_.size(); i++) {
      std::unique_ptr<Operator> base_infer = std::move(infer_list_[i]);
      auto thread_var = thread_var_vec_[i];

      // Check if base_infer is valid
      ICHECK(base_infer != nullptr) << "Null pointer encountered in "
                                       "infer_list_ while collecting for_map.";
      if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
        // Check that the loop layout is defined
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The layout for the parallel loop cannot be inferred correctly:\n"
            << for_infer->GetRoot();
        for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
        // thread_var_ should be defined if we rely on it
        ICHECK(thread_var.defined())
            << "thread_var is not defined. Cannot retrieve predicate.";

        if (auto predicate = for_infer->GetPredicate(thread_var->var)) {
          predicate_map.Set(for_infer->GetRoot(), predicate.value());
        }
      }
    }

    return {layout_map, for_map, predicate_map};
  }

  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "Layout_Inference: Require the target attribute";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
  void VisitExpr_(const CallNode *op) final {
    IRVisitorWithAnalyzer::VisitExpr_(op);
    // Do not analyze calls to global functions.
    if (op->op.as<GlobalVarNode>())
      return;

    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p != nullptr) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
      }
      infer_list_.push_back(std::move(p));
      thread_var_vec_.push_back(thread_var_);
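      // Record the thread bounds known at this point; fall back to [0, 1)
      // when the extent of threadIdx.x is not bound in the analyzer.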
      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto min_value = const_int_bound->min_value;
        auto max_value = const_int_bound->max_value;
        auto extent = max_value - min_value + 1;
        auto dtype = thread_var_->var.dtype();
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    }
  }

  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    if (call && call->op.same_as(builtin::tvm_access_ptr())) {
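      // For tvm_access_ptr(type, data, offset, extent, rw_mask), args[1] is
      // the buffer's data var.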
      auto var = call->args[1].as<Var>().value();
      return buffer_data_to_buffer_[var];
    }
    return NullOpt;
  }

  void addToUseList(const Buffer &buffer) {
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
    }
    use_list_[buffer].push_back(infer_idx);
  }

  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {
      auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_.push_back(std::move(infer));
      thread_var_vec_.push_back(thread_var_);
      if (thread_var_.defined() &&
          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto dtype = thread_var_->var.dtype();
        auto extent =
            const_int_bound->max_value - const_int_bound->min_value + 1;
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    } else {
      IRVisitorWithAnalyzer::VisitStmt(op->body);
    }
  }

  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
      // Check if the layout map is Map<Var, Layout>
      auto map = op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>();
      ICHECK(map.defined()) << "layout map is not defined";
      ICHECK(map.value().defined()) << "layout map is not defined";

      for (const auto &[var, layout] : map.value()) {
        ICHECK(buffer_data_to_buffer_.count(var))
            << "buffer " << var << " is not found in the block";
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<std::unique_ptr<Operator>> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  // This is a workaround for the CPU backend: we need to define a thread_var
  // for the serial loop.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  std::vector<IterVar> thread_var_vec_;
  std::vector<Range> thread_bounds_vec_;
  Target target_;
  LayoutMap annotated_layout_map_;
  bool skip_thread_partition_{false};
};

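/*!
 * \brief Apply the inferred layouts to the IR: annotate blocks with the layout
 *        map, partition parallel loops across threads, vectorize loops that
 *        touch non-local buffers, and guard partitioned loops with their
 *        predicates.
 */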
class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector(skip_thread_partition);
    collector.Collect(f);
    auto result = collector.Run();
    LayoutInferencer substituter(result, skip_thread_partition, &analyzer);
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult result,
                   bool skip_thread_partition, arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
        skip_thread_partition_(skip_thread_partition) {}

  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
      if (buffer.scope() == "local.fragment") {
        ICHECK(result_.layout_map.count(buffer))
            << "Cannot infer fragment layout for " << buffer;
      }
    }
    auto block_ptr = block.CopyOnWrite();
    block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
    return block;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto root = GetRef<For>(op);
      // This check is a workaround to support T.Parallel for local buffers.
      // For example:
      //   for i in T.Parallel(1024):
      //     A_local[i] = A_global[i]
      // Here, A_local is a register-local buffer held independently by each
      // thread, so explicit thread binding is not required.
      //
      // We use PostOrderVisit to detect whether the buffer store targets a
      // "local" buffer, which indicates register usage and justifies skipping
      // thread binding.
      bool is_register_store = false;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() == "local") {
            is_register_store = true;
          }
        }
      });

      auto loop_layout = result_.for_map[root];
      // Partition the loop across threads only when thread bindings are
      // provided and the loop body is not a purely register-local store.
      bool parallel_loop = !is_register_store && !skip_thread_partition_;
      if (parallel_loop) {
        for_node =
            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      }
      // Vectorize the loop if it touches any non-local (shared/global) buffer.
      bool has_non_local = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (const auto *load = obj.as<BufferLoadNode>()) {
          String scope = load->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        } else if (const auto *store = obj.as<BufferStoreNode>()) {
          String scope = store->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        }
      });

      if (has_non_local) {
        for_node = VectorizeLoop(for_node);
      }
      if (result_.predicate_map.count(root) && parallel_loop) {
        return IfThenElse(result_.predicate_map[root], for_node);
      } else {
        return for_node;
      }
    }
    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  bool skip_thread_partition_{false};
};

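/*!
 * \brief Create the LayoutInference pass.
 *
 * The pass first canonicalizes parallel loops, then checks whether the body
 * contains any thread bindings; if none are present (e.g. on a CPU target),
 * thread partitioning is skipped during layout substitution.
 */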
tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
    ThreadBindingCollector collector;
    collector(f->body);
    bool has_thread_binding = collector.thread_binding_.size() > 0;
    bool skip_thread_partition = !has_thread_binding;
    return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}

TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
    .set_body_typed(LayoutInference);

} // namespace tl
} // namespace tvm