"vscode:/vscode.git/clone" did not exist on "cc122226f5ee01933c16655863cef861313d7d5f"
layout_inference.cc 20.2 KB
Newer Older
1
2
3
4
5
6
/*!
 * \file layout_inference.cc
 * \brief infer the fragment/shared memory layout
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Collect the mapping from each thread variable to its binding IterVar
 */
class ThreadBindingCollector : public StmtExprVisitor {
public:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      thread_binding_[iv->var.get()] = iv;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // The thread binding map
  std::unordered_map<const VarNode *, IterVar> thread_binding_;
};
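
// Illustrative input shape handled above (a sketch, not produced by this
// pass): an AttrStmt such as
//   AttrStmt(node=IterVar(threadIdx.x, dom=[0, 128)),
//            attr_key="thread_extent", ...)
// records the threadIdx.x Var -> IterVar entry in thread_binding_.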

using arith::IRMutatorWithAnalyzer;

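// ParallelLoopTransformer normalizes parallel loops whose fragment accesses
// do not span the full buffer extent: the loop is widened to the buffer
// shape, the body's index is remapped through an inverse IndexMap, and the
// original bound is enforced with a predicate. Rough sketch of the intent
// (TIR-like pseudocode, illustrative only):
//
//   before: for r in T.Parallel(n):     ... A[i, r * 2] ...
//   after:  for j in T.Parallel(shape):
//             if (index bound holds):   ... A[i, j] ...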
class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
public:
  static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    ParallelLoopTransformer transformer(&analyzer);
    return transformer.VisitStmt(stmt);
  }

  ParallelLoopTransformer(arith::Analyzer *analyzer)
      : IRMutatorWithAnalyzer(analyzer) {}

  Stmt VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {

      // Collect loop variables and ranges
      auto for_node = GetRef<For>(op);
      Array<Var> loop_vars;
      Array<PrimExpr> loop_extents;
      Stmt body = op->body;

      // Bind the range of outer loop variables
      analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
      loop_vars.push_back(op->loop_var);
      loop_extents.push_back(op->extent);

      // If there are inner loops, bind their ranges as well
      while (const ForNode *inner = body.as<ForNode>()) {
        analyzer_->Bind(inner->loop_var,
                        Range::FromMinExtent(0, inner->extent));
        loop_vars.push_back(inner->loop_var);
        loop_extents.push_back(inner->extent);
        body = inner->body;
      }

      ICHECK(loop_vars.size() == loop_extents.size())
          << "loop_vars and loop_extents size mismatch";

      // Collect buffer access information
      BufferAccessCollector collector;
      collector(op->body);

      PrimExpr condition;

      for (const auto &[buffer, indices] : collector.buffer_indices) {
        ICHECK(indices.size() == buffer->shape.size())
            << "indices size mismatch with buffer shape";

        for (size_t i = 0; i < indices.size(); ++i) {
          auto index = indices[i];
          auto bound = analyzer_->const_int_bound(index);
          int64_t upper_bound = bound->max_value + 1;
          int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;

          // Collect the variables that are used in the index
          std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
          // Post-order visit of the index expression
          PostOrderVisit(index, [&](const ObjectRef &obj) {
            if (const VarNode *v = obj.as<VarNode>()) {
              used_vars.insert(GetRef<Var>(v));
            }
          });
          if (used_vars.empty()) {
            continue;
          }

          // Find the loop vars this index depends on
          Array<Var> related_loop_vars;
          for (size_t j = 0; j < loop_vars.size(); ++j) {
            auto loop_var = loop_vars[j];
            if (!used_vars.count(loop_var)) {
              continue;
            }
            related_loop_vars.push_back(loop_var);
            ICHECK(related_loop_vars.size() <= 1)
                << "Only one related loop var is supported currently, but got "
                << related_loop_vars
                << "; supporting multiple loop vars should not be too hard, "
                << "please file an issue if you run into this message.";

            if (upper_bound < shape) {
              PrimExpr predicate =
                  LT(index, IntImm(index.dtype(), upper_bound));
              condition =
                  condition.defined() ? And(condition, predicate) : predicate;

              // Replace the buffer index, e.g. rewrite A[i, r * 2] into
              // A[i, j], where r is the original loop var and j covers the
              // full buffer extent
              auto index_map = tir::IndexMap({loop_var}, {index});
              auto inverse_index_map = index_map.Inverse(
                  {Range::FromMinExtent(0, IntImm(index.dtype(), upper_bound))},
                  analyzer_);

              // Widen the extent of the related loop var to the buffer shape
              loop_extents.Set(j, IntImm(index.dtype(), shape));
              body = tir::Substitute(
                  body, {{loop_var, inverse_index_map->MapIndices(
                                        {loop_var}, analyzer_)[0]}});
            }
          }
        }
      }
      if (condition.defined()) {
        body = IfThenElse(condition, body);
        for (int j = static_cast<int>(loop_vars.size()) - 1; j >= 0; --j) {
          auto loop_var = loop_vars[j];
          auto loop_extent = loop_extents[j];
          body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
        }
        return Downcast<For>(body);
      }
      // No rewrite needed; return the parallel loop unchanged
      return for_node;
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  // Helper class for collecting buffer access indices; only fragment buffer
  // accesses are recorded
  class BufferAccessCollector : public StmtExprVisitor {
  public:
    void VisitExpr_(const BufferLoadNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitExpr_(op);
    }

    void VisitStmt_(const BufferStoreNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
          buffer_indices[op->buffer] = op->indices;
        } else {
          // check equal
          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
              << "indices mismatch for buffer: " << op->buffer;
        }
      }
      StmtExprVisitor::VisitStmt_(op);
    }

    std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
        buffer_indices;
  };
};

struct LayoutInferenceResult {
  Map<Buffer, Layout> layout_map;   // inferred layout for each buffer
  Map<For, Fragment> for_map;       // loop layout for each parallel For
  Map<For, PrimExpr> predicate_map; // thread predicate guarding each For
};

class BufferUseDefCollector : public StmtExprVisitor {
public:
  BufferUseDefCollector(bool skip_thread_partition)
      : skip_thread_partition_(skip_thread_partition) {}

  LayoutInferenceResult Run() {
    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
    // same size
    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
           "length.";


    // Seed the working layout map with any user-annotated layouts
    Map<Buffer, Layout> layout_map = annotated_layout_map_;
    int num_infer = infer_list_.size();

    // Prepare BFS queue for iterative inference
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true); // every entry starts queued
    for (int i = 0; i < num_infer; i++) {
      // Check that each infer_list_ entry is valid
      ICHECK(infer_list_[i] != nullptr)
          << "infer_list_[" << i
          << "] is null. The inference object is not allocated properly.";

      // Patch undefined thread_var entries when thread partitioning is skipped
      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
        // TODO(lei): This is a hack for cpu backend
        if (!thread_var_.defined()) {
        // Fake thread var used to infer predicates for buffers
          thread_var_ = IterVar(Range::FromMinExtent(PrimExpr(0), PrimExpr(1)),
                                Var(""), IterVarType::kDataPar);
        }
        thread_var_vec_[i] = thread_var_;
      }
      q.push(i);
    }
    auto run_infer_step = [&](int cur_infer_id, InferLevel level,
                              bool update_queue) {
      // Range check for cur_infer_id
      ICHECK_GE(cur_infer_id, 0)
          << "cur_infer_id is negative, which is invalid.";
      ICHECK_LT(cur_infer_id, num_infer)
          << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
          << num_infer << ".";

      // Make sure we can safely access infer_list_[cur_infer_id] and
      // thread_var_vec_[cur_infer_id]
      auto &next = infer_list_[cur_infer_id];
      auto iter_var = thread_var_vec_[cur_infer_id];

      // Double-check that 'next' is valid
      ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
                              << "] is null inside run_infer_step.";

      // Check iter_var->dom and dom->extent
      ICHECK(iter_var.defined())
          << "thread_var_vec_[" << cur_infer_id << "] is not defined.";
      ICHECK(iter_var->dom.defined())
          << "iter_var->dom is not defined for infer_list_[" << cur_infer_id
          << "].";
      ICHECK(iter_var->dom->extent.defined())
          << "iter_var->dom->extent is not defined for infer_list_["
          << cur_infer_id << "].";

      const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
      ICHECK(extent_ptr != nullptr)
          << "iter_var->dom->extent is not a constant integer, which is "
             "required for layout inference.";

      // Run InferLayout
      auto updates = next->InferLayout(
          LayoutInferArgs{target_, static_cast<size_t>(*extent_ptr),
                          layout_map},
          level);
      // Process the returned updates
      for (const auto &[buffer, layout] : updates) {
        // Basic validity checks
        ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
        ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";

        if (layout_map.count(buffer)) {
          // If already in map, ensure they are structurally equal
          ICHECK(StructuralEqual()(layout, layout_map[buffer]))
              << "Get different layout for " << buffer
              << " current layout: " << layout->DebugOutput()
              << " previous layout: " << layout_map[buffer]->DebugOutput();
        } else {
          // Otherwise, update map
          layout_map.Set(buffer, layout);
          if (!update_queue)
            continue;

          // Check if buffer exists in use_list_
          if (!use_list_.count(buffer)) {
            LOG(WARNING) << "Buffer " << buffer << " not found in use_list_. "
                         << "Potential mismatch between inference updates and "
                         << "use_list_.";
            continue;
          }

          // Push back into BFS queue
          for (int idx : use_list_[buffer]) {
            ICHECK_GE(idx, 0) << "Index in use_list_ for buffer " << buffer
                              << " is negative.";
            ICHECK_LT(idx, num_infer)
                << "Index in use_list_ for buffer " << buffer
                << " out of range: " << idx << " >= " << num_infer << ".";

            if (!in_queue[idx] && idx != cur_infer_id) {
              in_queue[idx] = true;
              q.push(idx);
            }
          }
        }
      }
    };

    auto finish_infer_queue = [&]() {
      while (!q.empty()) {
        int cur_infer_id = q.front();
        q.pop();
        // Range check again, just to be safe
        ICHECK_GE(cur_infer_id, 0);
        ICHECK_LT(cur_infer_id, num_infer);

        in_queue[cur_infer_id] = false;
        run_infer_step(cur_infer_id, InferLevel::kCommon, true);
      }
    };
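
    // The schedule below runs inference at three increasingly permissive
    // levels (an overview sketch, assuming the InferLevel semantics declared
    // in ../op/parallel.h):
    //   kStrict - only propagate layouts that are forced (e.g. annotations),
    //   kCommon - BFS propagation between operators sharing a buffer,
    //   kFree   - relax remaining unconstrained layouts to a default choice.
    // Each newly fixed buffer layout re-enqueues every operator that touches
    // that buffer until a fixed point is reached.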

    // step 1: infer strict layout
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kStrict, false);
    }

    // step 2: infer common layout with BFS
    finish_infer_queue();

    // step 3: relax constraints to free and re-run
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kFree, true);
      finish_infer_queue();
    }

    // Check that all local.fragment buffers have inferred layouts
    for (const auto &[buffer, _] : use_list_) {
      if (buffer.scope() == "local.fragment") {
        ICHECK_NE(layout_map.count(buffer), 0)
            << "The layout for fragment " << buffer
            << " can not be inferred correctly.";
      }
    }

    // Collect layout info for For nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    ICHECK(infer_list_.size() == thread_var_vec_.size())
        << "infer_list_ and thread_var_vec_ size mismatch";
    for (size_t i = 0; i < infer_list_.size(); i++) {
      std::unique_ptr<Operator> base_infer = std::move(infer_list_[i]);
      auto thread_var = thread_var_vec_[i];

      // Check if base_infer is valid
      ICHECK(base_infer != nullptr) << "Null pointer encountered in "
                                       "infer_list_ while collecting for_map.";
      if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
        // Check that the loop layout is defined
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The Layout for Parallel for cannot be inferred correctly:\n"
            << for_infer->GetRoot();
        for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
        // thread_var must be defined before we can query the loop predicate
        ICHECK(thread_var.defined())
            << "thread_var is not defined. Cannot retrieve predicate.";

        if (auto predicate = for_infer->GetPredicate(thread_var->var)) {
          predicate_map.Set(for_infer->GetRoot(), predicate.value());
        }
      }
    }

    return {layout_map, for_map, predicate_map};
  }

  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "LayoutInference: requires the target attribute";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
  void VisitExpr_(const CallNode *op) final {
    StmtExprVisitor::VisitExpr_(op);
    // Do not analyze calls to global functions
    if (op->op.as<GlobalVarNode>())
      return;

    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p != nullptr) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
      }
      infer_list_.push_back(std::move(p));
      thread_var_vec_.push_back(thread_var_);
    }
  }

  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    // tvm_access_ptr(type, data, offset, extent, rw_mask): args[1] holds the
    // buffer's data Var
    if (call && call->op.same_as(builtin::tvm_access_ptr())) {
      auto var = call->args[1].as<Var>().value();
      return buffer_data_to_buffer_[var];
    }
    }
    return NullOpt;
  }

  void addToUseList(const Buffer &buffer) {
    int infer_idx = static_cast<int>(infer_list_.size());
    // operator[] default-constructs the entry on first use
    use_list_[buffer].push_back(infer_idx);
  }

  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {
      auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_.push_back(std::move(infer));
      thread_var_vec_.push_back(thread_var_);
    } else {
      StmtExprVisitor::VisitStmt(op->body);
    }
  }

  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
      auto map =
          op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape))
            << "Annotated layout input shape does not match buffer shape for "
            << buffer;
        annotated_layout_map_.Set(buffer, layout);
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }
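
  // Illustrative annotation consumed above (sketch): a block annotated with
  //   {attr::kLayoutMap: Map<Var, Layout>}
  // maps each allocated buffer's data Var to a user-provided Layout, which
  // seeds annotated_layout_map_.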

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<std::unique_ptr<Operator>> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  IterVar thread_var_;
  std::vector<IterVar> thread_var_vec_;
  Target target_;
  LayoutMap annotated_layout_map_;
  bool skip_thread_partition_{false};
};

class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector(skip_thread_partition);
    collector.Collect(f);
    auto result = collector.Run();
    LayoutInferencer substituter(result, skip_thread_partition, &analyzer);
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult &result,
                   bool skip_thread_partition, arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
        skip_thread_partition_(skip_thread_partition) {}

  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
      if (buffer.scope() == "local.fragment") {
        ICHECK(result_.layout_map.count(buffer))
            << "Cannot infer fragment layout for " << buffer;
      }
    }
    auto block_ptr = block.CopyOnWrite();
    block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
    return block;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto loop_layout = result_.for_map[GetRef<For>(op)];
      if (!skip_thread_partition_) {
        // If no thread binding is provided, partition the loop across threads
        for_node =
            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      }
      for_node = VectorizeLoop(for_node);

      if (result_.predicate_map.count(GetRef<For>(op))) {
        return IfThenElse(result_.predicate_map[GetRef<For>(op)], for_node);
      } else {
        return for_node;
      }
    }
    return for_node;
  }
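
  // Rough sketch of the effect (illustrative only): given an inferred loop
  // layout, a parallel loop such as
  //   for i in T.Parallel(256): C_frag[i] = ...
  // is partitioned over threadIdx.x, vectorized where legal, and, if the
  // layout covers only part of the threads, wrapped in
  // IfThenElse(pred(threadIdx.x), loop).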

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_;
  bool skip_thread_partition_{false};
};

tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
    ThreadBindingCollector collector;
    collector(f->body);
    bool has_thread_binding = collector.thread_binding_.size() > 0;
    bool skip_thread_partition = !has_thread_binding;
    return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}

TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
    .set_body_typed(LayoutInference);
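
// Typical use from the Python side (illustrative; the exact wrapper path
// depends on how the bindings expose this registered global):
//   mod = tl.transform.LayoutInference()(mod)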

} // namespace tl
} // namespace tvm