/*!
 * \file layout_inference.cc
 * \brief infer the fragment/shared memory layout
 */

#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../layout/utils.h"
#include "../op/copy.h"
#include "../op/parallel.h"
#include "../op/region.h"

#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "common/loop_parallel_transform_utils.h"
#include "common/union_find.h"
#include "layout_reducer.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Collect the thread bindings (thread_extent attributes) in a statement.
 */
class ThreadBindingCollector : public StmtExprVisitor {
public:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      thread_binding_[iv->var.get()] = iv;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // The thread binding map
  std::unordered_map<const VarNode *, IterVar> thread_binding_;
};

using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

struct LayoutInferenceResult {
  Map<Buffer, Layout> layout_map;
  Map<For, Fragment> for_map;
  Map<For, PrimExpr> predicate_map;
};

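/*!
 * \brief Collect tile operators and parallel loops that participate in layout
 *        inference.
 *
 * Walks a PrimFunc body, records every parsed TileOperator / ParallelOp into
 * infer_list_, the buffers each of them touches into use_list_, and the
 * per-operator thread bounds. Run() then performs the actual inference in
 * three passes (strict -> common -> free).
 */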
class BufferUseDefCollector : public IRVisitorWithAnalyzer {
public:
  BufferUseDefCollector(bool skip_thread_partition)
      : skip_thread_partition_(skip_thread_partition) {}

  using arith::IRVisitorWithAnalyzer::IRVisitorWithAnalyzer;

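  /*!
   * \brief Run a single inference step for the operator at cur_infer_id.
   *
   * Calls InferLayout on that operator at the given level, merges the returned
   * buffer layouts into layout_map (checking structural equality against
   * already-known layouts), and, if update_queue is set, re-enqueues every
   * other operator that uses an updated buffer so FinishInferQueue can
   * propagate the change.
   */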
  void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
                    LayoutMap &layout_map, const LayoutMap &strict_layout_map,
                    std::queue<int> &q, std::vector<bool> &in_queue) {
    auto num_infer = infer_list_.size();

    // Range check for cur_infer_id
    ICHECK_GE(cur_infer_id, 0) << "cur_infer_id is negative, which is invalid.";
    ICHECK_LT(cur_infer_id, num_infer)
        << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
        << num_infer << ".";

    // Make sure we can safely access infer_list_[cur_infer_id] and
    // thread_var_vec_[cur_infer_id]
    auto &next = infer_list_[cur_infer_id];
    auto iter_var = thread_var_vec_[cur_infer_id];
    auto thread_bounds = thread_bounds_vec_[cur_infer_id];
    auto buffer_oob = buffer_oob_vec_[cur_infer_id];
    // Double-check that 'next' is valid
    ICHECK(next.defined()) << "infer_list_[" << cur_infer_id
                           << "] is null inside RunInferStep.";

    // Check iter_var->dom and dom->extent
    ICHECK(iter_var.defined())
        << "thread_var_vec_[" << cur_infer_id << "] is not defined.";
    ICHECK(iter_var->dom.defined())
        << "iter_var->dom is not defined for infer_list_[" << cur_infer_id
        << "].";
    ICHECK(iter_var->dom->extent.defined())
        << "iter_var->dom->extent is not defined for infer_list_["
        << cur_infer_id << "].";

    const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
    ICHECK(extent_ptr != nullptr)
        << "iter_var->dom->extent is not a constant integer, which is "
           "required for layout inference.";

    // Run InferLayout
    DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
    auto updates =
        next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
                                          &analyzer_, buffer_oob},
                          level);
    // Process the returned updates
    for (const auto &[buffer, layout] : updates) {
      DLOG(INFO) << "    consider update " << buffer << " as "
                 << layout->DebugOutput() << '\n';

      // Basic validity checks
      ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
      ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";

      if (layout_map.count(buffer)) {
        // If new layout contains the old one, update map
        if (buffer.scope() == "local.fragment" &&
            level != InferLevel::kStrict && !strict_layout_map.count(buffer)) {
          // Actually this test has been done in ParallelOp::InferLayout
          // already. Just do it again to avoid missing implementations in other
          // `TileOperator`s.
          auto dst_layout = layout.as<Fragment>().value();
          auto src_layout = layout_map[buffer].as<Fragment>().value();
          ICHECK(dst_layout->InputDim() == src_layout->InputDim());
          Array<PrimExpr> indices;
          indices.reserve(dst_layout->InputDim());
          arith::Analyzer inner_analyzer;
          for (int i = 0; i < dst_layout->InputDim(); ++i) {
            auto x = InputPlaceholder(i);
            indices.push_back(x);
            // should be literal - literal = 0, any analyzer will work
            ICHECK(is_zero(inner_analyzer.Simplify(
                dst_layout->InputShape()[i] - src_layout->InputShape()[i])));
            inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
          }
          if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
                                    inner_analyzer)) {
            layout_map.Set(buffer, layout);
            DLOG(INFO) << "    layout broadcast from "
                       << src_layout->DebugOutput() << ", accepted" << '\n';
            continue;
          }
        }
        // If already in map, ensure they are structurally equal
        ICHECK(StructuralEqual()(layout, layout_map[buffer]))
            << "Get different layout for " << buffer
            << "\n current layout: " << layout->DebugOutput()
            << "\n previous layout: " << layout_map[buffer]->DebugOutput();
      } else {
        // Otherwise, update map
        layout_map.Set(buffer, layout);
        DLOG(INFO) << "    new layout accepted" << '\n';

        if (!update_queue)
          continue;

        // Check if buffer exists in use_list_
        if (!use_list_.count(buffer)) {
          LOG(WARNING) << "Layout inference failed for buffer " << buffer
                       << ". "
                       << "The buffer cannot be inferred with current layout "
                          "inference rules.";
          continue;
        }

        // Push back into BFS queue
        for (int idx : use_list_[buffer]) {
          ICHECK_GE(idx, 0)
              << "Index in use_list_ for buffer " << buffer << " is negative.";
          ICHECK_LT(idx, num_infer)
              << "Index in use_list_ for buffer " << buffer
              << " out of range: " << idx << " >= " << num_infer << ".";

          if (!in_queue[idx] && idx != cur_infer_id) {
            in_queue[idx] = true;
            q.push(idx);
          }
        }
      }
    }
  };

  void FinishInferQueue(InferLevel level, LayoutMap &layout_map,
                        const LayoutMap &strict_layout_map, std::queue<int> &q,
                        std::vector<bool> &in_queue) {
    auto num_infer = infer_list_.size();
    while (!q.empty()) {
      int cur_infer_id = q.front();
      q.pop();
      // Range check again, just to be safe
      ICHECK_GE(cur_infer_id, 0);
      ICHECK_LT(cur_infer_id, num_infer);

      in_queue[cur_infer_id] = false;
      RunInferStep(cur_infer_id, level, true, layout_map, strict_layout_map, q,
                   in_queue);
    }
  };

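  /*!
   * \brief Drive the whole layout-inference procedure.
   *
   * Step 1 runs every operator once at kStrict level, step 2 propagates
   * layouts at kCommon level with a BFS work queue, and step 3 relaxes the
   * remaining degrees of freedom via InferInFreeMode. The result contains the
   * final buffer layouts plus the loop layouts and predicates for parallel
   * For nodes.
   */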
  LayoutInferenceResult Run() {
    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
    // same size
    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
           "length.";
    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
           "length.";
    ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size())
        << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
           "length.";

    DLOG(INFO) << "[InferLayout] all participating operators:" << '\n';
    for (int i = 0; i < infer_list_stmt_.size(); ++i) {
      DLOG(INFO) << "    op " << i << ":" << infer_list_stmt_[i] << '\n';
    }

    // If needed, you can also check that annotated_layout_map_ is not empty, or
    // anything else relevant to your setup.

    // Copy the annotated layout map to local variable
    Map<Buffer, Layout> layout_map = annotated_layout_map_;
    Map<Buffer, Layout> strict_layout_map;
    int num_infer = infer_list_.size();

    // Prepare BFS queue for iterative inference
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
    for (int i = 0; i < num_infer; i++) {
      // Check that each infer_list_ entry is valid
      ICHECK(infer_list_[i].defined())
          << "infer_list_[" << i
          << "] is null. The inference object is not allocated properly.";

      // Check that each thread_var_vec_ entry is defined
      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
        thread_var_vec_[i] = thread_var_;
      }
      q.push(i);
    }

    // step 1: infer strict layout
    for (int i = 0; i < num_infer; i++) {
      RunInferStep(i, InferLevel::kStrict, false, layout_map, strict_layout_map,
                   q, in_queue);
    }

    for (const auto &[buffer, layout] : layout_map) {
      strict_layout_map.Set(buffer, layout);
    }

    // step 2: infer common layout with BFS
    FinishInferQueue(InferLevel::kCommon, layout_map, strict_layout_map, q,
                     in_queue);

    // step 3: relax constraints to free and re-run
    InferInFreeMode(layout_map, strict_layout_map);

    // Check that all local.fragment buffers have inferred layouts
    for (const auto &[buffer, _] : use_list_) {
      if (buffer.scope() == "local.fragment") {
        ICHECK_NE(layout_map.count(buffer), 0)
            << "The layout for fragment " << buffer
            << " cannot be inferred correctly.";
      }
    }

    // Collect layout info for For nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    ICHECK(infer_list_.size() == thread_var_vec_.size())
        << "infer_list_ and thread_var_vec_ size mismatch";
    for (int i = 0; i < infer_list_.size(); i++) {
      TileOperator base_infer = std::move(infer_list_[i]);
      auto thread_var = thread_var_vec_[i];

      // Check if base_infer is valid
      ICHECK(base_infer.defined()) << "Null pointer encountered in "
                                      "infer_list_ while collecting for_map.";
      if (auto for_infer = base_infer.as<ParallelOpNode>()) {
        // Check that the loop layout is defined
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The Layout for Parallel for cannot be inferred correctly:\n"
            << for_infer->GetRoot();
        for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
        // thread_var_ should be defined if we rely on it
        ICHECK(thread_var.defined())
            << "thread_var is not defined. Cannot retrieve predicate.";

        if (auto predicate = for_infer->GetPredicate(thread_var->var)) {
          predicate_map.Set(for_infer->GetRoot(), predicate.value());
        }
      }
    }

    return {layout_map, for_map, predicate_map};
  }

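  /*!
   * \brief Entry point: record the function's buffer map and target, then
   *        visit the body to populate infer_list_ and use_list_.
   */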
  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "Layout_Inference: Require the target attribute";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
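  /*!
   * \brief Parse tile-operator calls and record their buffer uses, the current
   *        thread bounds, and (for copies) whether any access may be
   *        out-of-bounds.
   */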
  void VisitExpr_(const CallNode *op) final {
    IRVisitorWithAnalyzer::VisitExpr_(op);
    // Do not analyze calls to global functions.
    if (op->op.as<GlobalVarNode>())
      return;

    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p.defined()) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
      }
      // Compute thread_var_ and thread_bounds_
      thread_var_vec_.push_back(thread_var_);
      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto min_value = const_int_bound->min_value;
        auto max_value = const_int_bound->max_value;
        auto extent = max_value - min_value + 1;
        auto dtype = thread_var_->var.dtype();
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }

      // Compute buffer oob for each buffer in the op
      if (const auto *copy = p.as<CopyNode>()) {
        auto src_tensor = copy->src;
        auto dst_tensor = copy->dst;
        auto src_range = copy->src_range;
        auto dst_range = copy->dst_range;
        bool src_oob = false;
        bool dst_oob = false;
        for (size_t i = 0; i < src_range.size(); i++) {
          if (!analyzer_.CanProve(src_range[i]->min + src_range[i]->extent <=
                                      src_tensor->shape[i],
                                  arith::ProofStrength::kSymbolicBound)) {
            src_oob = true;
            break;
          }
        }
        for (size_t i = 0; i < dst_range.size(); i++) {
          if (!analyzer_.CanProve(dst_range[i]->min + dst_range[i]->extent <=
                                      dst_tensor->shape[i],
                                  arith::ProofStrength::kSymbolicBound)) {
            dst_oob = true;
            break;
          }
        }
        buffer_oob_vec_.push_back(src_oob || dst_oob);
      } else {
        buffer_oob_vec_.push_back(false);
      }

      // Add the tile operator to infer_list_
      infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
      infer_list_.push_back(std::move(p));
    }
  }

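  /*!
   * \brief Extract the buffer referenced by an operator argument, either a
   *        tvm_access_ptr() call or a RegionOp region descriptor.
   */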
  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    if (!call) {
      return std::nullopt;
    }
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      auto var = call->args[1].as<Var>().value();
      return buffer_data_to_buffer_[var];
    } else if (call->op.same_as(RegionOp::Get())) {
      return call->args[0].as<BufferLoadNode>()->buffer;
    }
    return std::nullopt;
  }

  void addToUseList(const Buffer &buffer) {
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
    }
    use_list_[buffer].push_back(infer_idx);
  }

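  /*!
   * \brief Treat every parallel For loop as a ParallelOp so its loop/fragment
   *        layout can be inferred together with the tile operators it touches.
   */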
  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {
      auto infer = ParallelOp(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
      infer_list_.push_back(std::move(infer));
      thread_var_vec_.push_back(thread_var_);
      if (thread_var_.defined() &&
          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto dtype = thread_var_->var.dtype();
        auto extent =
            const_int_bound->max_value - const_int_bound->min_value + 1;
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
      buffer_oob_vec_.push_back(false);
    } else {
      IRVisitorWithAnalyzer::VisitStmt(op->body);
    }
  }

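  /*!
   * \brief Register block-allocated buffers and pick up any layouts already
   *        annotated on the block via attr::kLayoutMap.
   */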
  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
      // Check if the layout map is Map<Var, Layout>
      auto map =
          op->annotations.Get(attr::kLayoutMap)->as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
        ICHECK(buffer_data_to_buffer_.count(var))
            << "buffer " << var << " is not found in the block";
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<ObjectRef> infer_list_stmt_;
  std::vector<TileOperator> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  // This is a workaround for cpu backend,
  // we need to define a thread_var for the serial loop.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  std::vector<IterVar> thread_var_vec_;
  std::vector<Range> thread_bounds_vec_;
  std::vector<bool> buffer_oob_vec_;
  Target target_;
  LayoutMap annotated_layout_map_;
  bool skip_thread_partition_{false};

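  /*!
   * \brief Deep-copy infer_list_ so a free-mode attempt can be rolled back if
   *        it runs into a layout conflict.
   */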
  std::vector<TileOperator> BackupInferList() {
    std::vector<TileOperator> back_infer_list;
    back_infer_list.reserve(infer_list_.size());
    for (auto &&p : infer_list_) {
      back_infer_list.push_back(p->Clone());
    }
    return back_infer_list;
  }
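  /*!
   * \brief Free-mode inference: operators are grouped into connected
   *        components via union-find on shared buffers; within each component
   *        every member is tried as the inference root and the plan that needs
   *        the fewest fragment registers is kept.
   */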

  void InferInFreeMode(LayoutMap &layout_map,
                       const LayoutMap &strict_layout_map) {

    DLOG(INFO) << "Enforced layout maps:" << '\n';
    for (auto &&[k, v] : layout_map) {
      DLOG(INFO) << "    " << k << ": " << v->DebugOutput() << '\n';
    }
    DLOG(INFO) << '\n';

    // Group operators into connected components
    UnionFind<int> uf;
    for (int i = 0; i < infer_list_.size(); i++) {
      uf.MakeSet(i);
    }
    for (const auto &[buffer, infer_indices] : use_list_) {
      if (infer_indices.empty())
        continue;

      // Union all infer_list_ indices that share the same buffer
      int first_idx = infer_indices[0];
      for (size_t i = 1; i < infer_indices.size(); i++) {
        uf.Union(first_idx, infer_indices[i]);
      }
    }
    std::unordered_map<int, std::vector<int>> components;
    for (int i = 0; i < infer_list_.size(); i++) {
      int root = uf.Find(i);
      components[root].push_back(i);
    }
    // Create a map from root to buffers
    std::unordered_map<int, std::vector<Buffer>> components_buffers;
    for (const auto &[buffer, infer_indices] : use_list_) {
      int root = uf.Find(infer_indices[0]);
      components_buffers[root].push_back(buffer);
    }
    // Keep components_buffers for debug purpose
    (void)components_buffers;

    // For each component, try each op as root, and determine the least
    // replicated one
    std::queue<int> q;
    std::vector<bool> in_queue(infer_list_.size(), false);
    for (auto &&[root, members] : components) {
      DLOG(INFO) << "======================= processing component " << root
                 << '\n';
      decltype(infer_list_) best_infer_list;
      LayoutMap best_layout_map;
      int64_t min_reg_num = INT64_MAX;
      int min_reg_num_infer_root = -1;

      // Try each member as the root of inference for this component
      for (int attempt_infer_root : members) {
        DLOG(INFO) << "----------------------- try root " << attempt_infer_root
                   << '\n';
        // Backup the current infer_list_ state
        auto back_infer_list = BackupInferList();
        // Copy the current layout_map for temporary use
        LayoutMap tmp_layout_map = layout_map;
        bool do_update = true;
        try {
          // Run inference starting from attempt_infer_root
          RunInferStep(attempt_infer_root, InferLevel::kFree, true,
                       tmp_layout_map, strict_layout_map, q, in_queue);
          FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
                           q, in_queue);

          // After the first search, run inference for all other members in
          // order
          for (int other_infer_root : members) {
            if (other_infer_root != attempt_infer_root) {
              RunInferStep(other_infer_root, InferLevel::kFree, true,
                           tmp_layout_map, strict_layout_map, q, in_queue);
              FinishInferQueue(InferLevel::kFree, tmp_layout_map,
                               strict_layout_map, q, in_queue);
            }
          }
        } catch (const LayoutConflictException &e) {
          do_update = false;
          DLOG(INFO) << "attempt failed due to LayoutConflictException "
                     << e.what() << '\n';
        } catch (const NormalizeIterException &e) {
          do_update = false;
          DLOG(INFO) << "attempt failed due to NormalizeIterException "
                     << e.what() << '\n';
        }

        if (do_update) {
          // Compute the total register number for this layout
          int64_t reg_num = 0;
          for (const auto &[buffer, layout] : tmp_layout_map) {
            if (auto frag = layout.as<Fragment>()) {
              int64_t frag_reg_num = 1;
              for (auto i : frag.value()->OutputShape()) {
                auto pci = as_const_int(i);
                ICHECK(pci != nullptr);
                frag_reg_num *= *pci;
              }
              reg_num += frag_reg_num;
            }
          }
          // Update the best plan if this one uses fewer registers
          if (reg_num < min_reg_num) {
            best_infer_list =
                BackupInferList(); // Use backup to avoid moving out infer_list_
            best_layout_map = tmp_layout_map;
            min_reg_num = reg_num;
            min_reg_num_infer_root = attempt_infer_root;
          }
        }
        // Restore infer_list_ state for the next attempt
        infer_list_ = std::move(back_infer_list);
      }
      ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n';
      // Apply the best plan for this component
      infer_list_ = std::move(best_infer_list);
      layout_map = best_layout_map;
      DLOG(INFO) << "[InferInFreeMode] Final selection is attempt_infer_root = "
                 << min_reg_num_infer_root << '\n';
    }
  }
};

class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
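  /*!
   * \brief Fuse parallel loops, run BufferUseDefCollector to infer all
   *        layouts, then rewrite the body with the inferred layout map, loop
   *        partitioning, vectorization and predicates.
   */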
  static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector(skip_thread_partition);
    collector.Collect(f);
    auto result = collector.Run();
    LayoutInferencer substituter(result, skip_thread_partition, &analyzer);
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult &result,
                   bool skip_thread_partition, arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
        skip_thread_partition_(skip_thread_partition){};
  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

  /**
   * @brief Visit and mutate a Block node to attach inferred layout information.
   *
   * Converts the visited Block via the base visitor, asserts that every buffer
   * allocated with scope "local.fragment" has an inferred layout in
   * result_.layout_map, and attaches result_.layout_map to the Block's
   * annotations under attr::kLayoutMap.
   *
   * If any "local.fragment" buffer lacks an entry in result_.layout_map, an
   * ICHECK will fail with the offending buffer printed.
   *
   * @return Stmt The (possibly modified) Block statement with the layout-map
   * annotation set.
   */
  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
      if (buffer.scope() == "local.fragment") {
        ICHECK(result_.layout_map.count(buffer))
            << "Cannot inference fragment layout for " << buffer;
      }
    }
    auto block_ptr = block.CopyOnWrite();
    block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
    return block;
  }

  /**
   * @brief Visit and transform For nodes according to inferred layout
   * information.
   *
   * If the For node is present in result_.for_map, this method applies
   * loop-level layout-driven transformations: it optionally partitions the loop
   * across the thread index, vectorizes the loop body, and wraps the loop with
   * a predicate if one was inferred for the loop root.
   *
   * Detailed behavior:
   * - Reads reducer information from the For node's attr::kReducerInfo
   * annotation (if present) to detect reduction targets.
   * - Detects register-local buffer stores (buffers with scope "local") in the
   *   original loop body; if only register-local stores are present the loop is
   *   treated as a register-local scenario and is not partitioned across
   * threads.
   * - Obtains the loop layout from result_.for_map[root] and, unless the loop
   * is register-local or skip_thread_partition_ is set, partitions the loop via
   *   PartitionLoop using thread_var_ and analyzer_.
   * - Scans the transformed loop body to determine whether it accesses any
   *   non-local buffers (scopes other than "local" or "local.fragment").
   * - Scans the transformed loop body to detect reducers (based on
   * reducer_info). If a reducer is present the loop is NOT vectorized
   * (reduction axes are excluded from vectorization as a conservative
   * workaround).
   * - If the loop has non-local accesses and no reducer, the loop is vectorized
   *   via VectorizeLoop.
   * - If a predicate exists in result_.predicate_map for the loop root and the
   *   loop was partitioned, the method returns an IfThenElse surrounding the
   *   (possibly partitioned/vectorized) loop with that predicate; otherwise it
   *   returns the transformed For.
   *
   * @return The possibly transformed For statement (or an IfThenElse wrapping
   * it)
   */
  Stmt VisitStmt_(const ForNode *op) final {
    Map<Var, ReducerInfo> reducer_info;
    if (op->annotations.count(attr::kReducerInfo))
      reducer_info = op->annotations.Get(attr::kReducerInfo)
                         ->as<Map<Var, ReducerInfo>>()
                         .value();

    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto root = GetRef<For>(op);
      // This check is a workaround to support T.Parallel for local buffers.
      // For example:
      //   for i in T.Parallel(1024):
      //     A_local[i] = A_global[i]
      // Here, A_local is a register-local buffer held independently by each
      // thread, so explicit thread binding is not required.
      //
      // We use PostOrderVisit to detect whether the loop only manipulates
      // "local" buffers, which indicates register usage and justifies skipping
      // thread binding.
      bool local_register_only = true;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() != "local") {
            local_register_only = false;
          }
        } else if (const auto *load = obj.as<BufferLoadNode>()) {
          if (load->buffer.scope() != "local") {
            local_register_only = false;
          }
        }
      });

      auto loop_layout = result_.for_map[root];
      // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
      bool parallel_loop = !skip_thread_partition_ && !local_register_only;
      if (parallel_loop) {
        for_node =
            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      }
      // If no thread bindings were provided (skip_thread_partition_), the loop
      // is left unpartitioned.
      bool has_non_local = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (const auto *load = obj.as<BufferLoadNode>()) {
          String scope = load->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        } else if (const auto *store = obj.as<BufferStoreNode>()) {
          String scope = store->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        }
      });
      // Workaround: if a reducer is present, don't vectorize the loop.
      // A better solution would be to isolate the reduction axis from
      // vectorization.
      bool has_reducer = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (!has_reducer)
          if (const auto *store = obj.as<BufferStoreNode>()) {
            has_reducer = reducer_info.count(store->buffer->data) != 0;
          }
      });
      if (has_non_local && !has_reducer) {
        for_node = VectorizeLoop(for_node);
      }
      if (result_.predicate_map.count(root) && parallel_loop) {
        return IfThenElse(result_.predicate_map[root], for_node);
      } else {
        return for_node;
      }
    }
    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  bool skip_thread_partition_{false};
};
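/*!
 * \brief Create the LayoutInference pass.
 *
 * A minimal usage sketch (assuming the module already contains lowered
 * TileLang PrimFuncs; everything except the pass itself is illustrative):
 *
 * \code
 *   IRModule mod = ...;                     // module with TL PrimFuncs
 *   mod = tvm::tl::LayoutInference()(mod);  // infer fragment/shared layouts
 * \endcode
 */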

tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
    ThreadBindingCollector collector;
    collector(f->body);
    bool has_thread_binding = !collector.thread_binding_.empty();
    bool skip_thread_partition = !has_thread_binding;
    return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference);
});
} // namespace tl
} // namespace tvm