/*!
 * \file layout_inference.cc
 * \brief infer the fragment/shared memory layout
 */
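//
// Overview (informal): the pass runs in two stages.
//   1. BufferUseDefCollector walks the PrimFunc, builds one layout-inference
//      operator per TL intrinsic call / parallel loop, and propagates buffer
//      layouts across them at three levels (kStrict -> kCommon -> kFree).
//   2. LayoutInferencer rewrites the IR with the results: it annotates blocks
//      with the inferred layout map, partitions parallel loops over the
//      thread variable, vectorizes loops that touch non-local buffers, and
//      wraps partitioned loops in their thread predicates when needed.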

#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "common/loop_parallel_transform_utils.h"

#include "loop_partition.h"
#include "loop_vectorize.h"

#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Collect the thread bindings (thread_extent attributes) present in
 *        the function body.
 */
class ThreadBindingCollector : public StmtExprVisitor {
public:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      thread_binding_[iv->var.get()] = iv;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // The thread binding map
  std::unordered_map<const VarNode *, IterVar> thread_binding_;
};

using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;
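
// Result of running BufferUseDefCollector::Run():
//  - layout_map:    inferred layout for each fragment/shared buffer
//  - for_map:       loop layout (Fragment) chosen for each parallel For
//  - predicate_map: thread predicate guarding each partitioned For, if any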
struct LayoutInferenceResult {
  Map<Buffer, Layout> layout_map;
  Map<For, Fragment> for_map;
  Map<For, PrimExpr> predicate_map;
};
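
// BufferUseDefCollector scans the function, records which layout-inference
// operators (TL calls and parallel loops) read or write each buffer
// (use_list_), and then iteratively runs InferLayout on them until the
// buffer/loop layouts reach a fixed point.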

class BufferUseDefCollector : public IRVisitorWithAnalyzer {
public:
  BufferUseDefCollector(bool skip_thread_partition)
      : skip_thread_partition_(skip_thread_partition) {}

  LayoutInferenceResult Run() {
    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
    // same size
    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
           "length.";
    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
           "length.";

    // Start from the user-annotated layout map; inference extends this local
    // copy below.
    Map<Buffer, Layout> layout_map = annotated_layout_map_;
    Map<Buffer, Layout> strict_layout_map;
    int num_infer = infer_list_.size();

    // Prepare BFS queue for iterative inference
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
    for (int i = 0; i < num_infer; i++) {
      // Check that each infer_list_ entry is valid
      ICHECK(infer_list_[i] != nullptr)
          << "infer_list_[" << i
          << "] is null. The inference object is not allocated properly.";

      // If no thread binding was recorded for this operator and thread
      // partitioning is skipped (e.g. CPU backend), fall back to the
      // placeholder thread var.
      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
        thread_var_vec_[i] = thread_var_;
      }
      q.push(i);
    }
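
    // run_infer_step: run InferLayout on one operator at the given level,
    // merge the returned buffer layouts into layout_map, and (optionally)
    // re-enqueue every other operator that touches an updated buffer.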
    auto run_infer_step = [&](int cur_infer_id, InferLevel level,
                              bool update_queue) {
      // Range check for cur_infer_id
      ICHECK_GE(cur_infer_id, 0)
          << "cur_infer_id is negative, which is invalid.";
      ICHECK_LT(cur_infer_id, num_infer)
          << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
          << num_infer << ".";

      // Make sure we can safely access infer_list_[cur_infer_id] and
      // thread_var_vec_[cur_infer_id]
      auto &next = infer_list_[cur_infer_id];
      auto iter_var = thread_var_vec_[cur_infer_id];
      auto thread_bounds = thread_bounds_vec_[cur_infer_id];
      // Double-check that 'next' is valid
      ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
                              << "] is null inside run_infer_step.";

      // Check iter_var->dom and dom->extent
      ICHECK(iter_var.defined())
          << "thread_var_vec_[" << cur_infer_id << "] is not defined.";
      ICHECK(iter_var->dom.defined())
          << "iter_var->dom is not defined for infer_list_[" << cur_infer_id
          << "].";
      ICHECK(iter_var->dom->extent.defined())
          << "iter_var->dom->extent is not defined for infer_list_["
          << cur_infer_id << "].";

      const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
      ICHECK(extent_ptr != nullptr)
          << "iter_var->dom->extent is not a constant integer, which is "
             "required for layout inference.";

      // Run InferLayout
      auto updates = next->InferLayout(
          LayoutInferArgs{target_, thread_bounds, layout_map}, level);
      // Process the returned updates
      for (const auto &[buffer, layout] : updates) {
        // Basic validity checks
        ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
        ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";

        if (layout_map.count(buffer)) {
          // If the new layout replicates this buffer more than the recorded
          // one, prefer the more replicated layout (non-strict levels only).
          if (buffer.scope() == "local.fragment" &&
              level != InferLevel::kStrict) {
            const FragmentNode *dst_layout = layout.as<FragmentNode>();
            const FragmentNode *src_layout =
                layout_map[buffer].as<FragmentNode>();
            if (as_const_int(dst_layout->ReplicateExtent()) &&
                as_const_int(src_layout->ReplicateExtent()) &&
                (*as_const_int(dst_layout->ReplicateExtent()) >
                 *as_const_int(src_layout->ReplicateExtent()))) {
              // update map
              layout_map.Set(buffer, layout);
              continue;
            }
          }
          // If already in map, ensure the layouts are structurally equal.
          // (zhengju) We cannot modify the strict layout map when the current
          // level is not strict. This check is only performed when the level
          // is strict or no strict layout was recorded for the buffer, since
          // the strict layout map is not updated above for non-strict levels.
          if (level == InferLevel::kStrict ||
              !strict_layout_map.count(buffer)) {
            ICHECK(StructuralEqual()(layout, layout_map[buffer]))
                << "Get different layout for " << buffer
                << "\n current layout: " << layout->DebugOutput()
                << "\n previous layout: " << layout_map[buffer]->DebugOutput();
          }
        } else {
          // Otherwise, update map
          layout_map.Set(buffer, layout);
          if (!update_queue)
            continue;

          // Check if buffer exists in use_list_
          if (!use_list_.count(buffer)) {
            LOG(WARNING) << "Layout inference failed for buffer " << buffer
                         << ". "
                         << "The buffer cannot be inferred with current layout "
                            "inference rules.";
            continue;
          }

          // Push back into BFS queue
          for (int idx : use_list_[buffer]) {
            ICHECK_GE(idx, 0) << "Index in use_list_ for buffer " << buffer
                              << " is negative.";
            ICHECK_LT(idx, num_infer)
                << "Index in use_list_ for buffer " << buffer
                << " out of range: " << idx << " >= " << num_infer << ".";

            if (!in_queue[idx] && idx != cur_infer_id) {
              in_queue[idx] = true;
              q.push(idx);
            }
          }
        }
      }
    };
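
    // finish_infer_queue: drain the BFS queue at the common level until no
    // operator produces a new layout update.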
    auto finish_infer_queue = [&]() {
      while (!q.empty()) {
        int cur_infer_id = q.front();
        q.pop();
        // Range check again, just to be safe
        ICHECK_GE(cur_infer_id, 0);
        ICHECK_LT(cur_infer_id, num_infer);

        in_queue[cur_infer_id] = false;
        run_infer_step(cur_infer_id, InferLevel::kCommon, true);
      }
    };

    // step 1: infer strict layout
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kStrict, false);
    }
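
    // Snapshot the layouts fixed at the strict level; non-strict updates must
    // not silently diverge from these (see the check in run_infer_step).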

    for (const auto &[buffer, layout] : layout_map) {
      strict_layout_map.Set(buffer, layout);
    }

    // step 2: infer common layout with BFS
    finish_infer_queue();

    // step 3: relax constraints to free and re-run
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kFree, true);
      finish_infer_queue();
    }
    // Check that all local.fragment buffers have inferred layouts
    for (const auto &[buffer, _] : use_list_) {
      if (buffer.scope() == "local.fragment") {
        ICHECK_NE(layout_map.count(buffer), 0)
            << "The layout for fragment " << buffer
            << " cannot be inferred correctly.";
      }
    }

    // Collect layout info for For nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
    ICHECK(infer_list_.size() == thread_var_vec_.size())
        << "infer_list_ and thread_var_vec_ size mismatch";
    for (int i = 0; i < num_infer; i++) {
      std::unique_ptr<Operator> base_infer = std::move(infer_list_[i]);
      auto thread_var = thread_var_vec_[i];

      // Check if base_infer is valid
      ICHECK(base_infer != nullptr) << "Null pointer encountered in "
                                       "infer_list_ while collecting for_map.";
      if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
        // Check that the loop layout is defined
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The layout for the parallel for loop cannot be inferred "
               "correctly:\n"
            << for_infer->GetRoot();
        for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
        // thread_var_ must be defined before querying the loop predicate.
        ICHECK(thread_var.defined())
            << "thread_var is not defined. Cannot retrieve predicate.";

        if (auto predicate = for_infer->GetPredicate(thread_var->var)) {
          predicate_map.Set(for_infer->GetRoot(), predicate.value());
        }
      }
    }

    return {layout_map, for_map, predicate_map};
  }
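
  // Collect: seed buffer_data_to_buffer_ from the function's buffer map,
  // record the target, and walk the body to populate infer_list_ / use_list_.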

  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "LayoutInference: requires the target attribute";
    target_ = target.value();
    this->operator()(f->body);
  }

private:
  void VisitExpr_(const CallNode *op) final {
    IRVisitorWithAnalyzer::VisitExpr_(op);
    // Do not analyze calls to global functions.
    if (op->op.as<GlobalVarNode>())
      return;

    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p != nullptr) {
      for (const auto &arg : op->args) {
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
      }
      infer_list_.push_back(std::move(p));
      thread_var_vec_.push_back(thread_var_);
      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto min_value = const_int_bound->min_value;
        auto max_value = const_int_bound->max_value;
        auto extent = max_value - min_value + 1;
        auto dtype = thread_var_->var.dtype();
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    }
  }
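
  // Returns the buffer behind a tvm_access_ptr() argument, if any.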

  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
    auto call = expr.as<CallNode>();
    if (call && call->op.same_as(builtin::tvm_access_ptr())) {
      auto var = call->args[1].as<Var>().value();
      return buffer_data_to_buffer_[var];
    }
    return std::nullopt;
  }
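
  // Record that the operator currently being constructed (the next entry to
  // be pushed into infer_list_) accesses `buffer`.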

  void addToUseList(const Buffer &buffer) {
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
    }
    use_list_[buffer].push_back(infer_idx);
  }
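
  // Parallel loops become ParallelOp inference operators; the thread bounds
  // are taken from the analyzer's constant bound of the thread variable when
  // available, and default to [0, 1) otherwise.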

  void VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kParallel) {
      auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
        addToUseList(buffer);
      }
      infer_list_.push_back(std::move(infer));
      thread_var_vec_.push_back(thread_var_);
      if (thread_var_.defined() &&
          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
        auto dtype = thread_var_->var.dtype();
        auto extent =
            const_int_bound->max_value - const_int_bound->min_value + 1;
        thread_bounds_vec_.push_back(Range::FromMinExtent(
            IntImm(dtype, const_int_bound->min_value), IntImm(dtype, extent)));
      } else {
        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
      }
    } else {
      IRVisitorWithAnalyzer::VisitStmt(op->body);
    }
  }

  void VisitStmt_(const BlockNode *op) final {
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
      // Check if the layout map is Map<Var, Layout>
      auto map =
          op->annotations.Get(attr::kLayoutMap)->as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
        ICHECK(buffer_data_to_buffer_.count(var))
            << "buffer " << var << " is not found in the block";
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    IRVisitorWithAnalyzer::VisitStmt_(op);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<std::unique_ptr<Operator>> infer_list_;
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
  // This is a workaround for the CPU backend:
  // we need to define a thread_var for the serial loop.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  std::vector<IterVar> thread_var_vec_;
  std::vector<Range> thread_bounds_vec_;
  Target target_;
  LayoutMap annotated_layout_map_;
  bool skip_thread_partition_{false};
};
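
// LayoutInferencer rewrites the IR according to the inference result:
// it attaches the layout map to blocks via attr::kLayoutMap, partitions
// parallel loops over threadIdx.x, vectorizes loops that access non-local
// buffers, and guards partitioned loops with their thread predicates.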

class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) {
    arith::Analyzer analyzer;
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector(skip_thread_partition);
    collector.Collect(f);
    auto result = collector.Run();
    LayoutInferencer substituter(result, skip_thread_partition, &analyzer);
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  LayoutInferencer(const LayoutInferenceResult result,
                   bool skip_thread_partition, arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
        skip_thread_partition_(skip_thread_partition) {}

  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
      if (buffer.scope() == "local.fragment") {
        ICHECK(result_.layout_map.count(buffer))
            << "Cannot infer fragment layout for " << buffer;
      }
    }
    auto block_ptr = block.CopyOnWrite();
    block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
    return block;
  }
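
  // For each parallel loop with an inferred loop layout: optionally partition
  // it over the thread variable, vectorize it when it touches non-local
  // buffers, and wrap it in the inferred thread predicate (if present).
  // Roughly, for a copy such as
  //   for i in T.Parallel(N):
  //     B_shared[i] = A_global[i]
  // each thread ends up with a (possibly vectorized) slice of the iteration
  // space, guarded by a predicate when N does not evenly cover the threads.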

  Stmt VisitStmt_(const ForNode *op) final {
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto root = GetRef<For>(op);
      // This check is a workaround to support T.Parallel for local buffers.
      // For example:
      //   for i in T.Parallel(1024):
      //     A_local[i] = A_global[i]
      // Here, A_local is a register-local buffer held independently by each
      // thread, so explicit thread binding is not required.
      //
      // We use PostOrderVisit to detect whether the buffer store targets a
      // "local" buffer, which indicates register usage and justifies skipping
      // thread binding.
      bool is_register_store = false;
      PostOrderVisit(root, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          if (store->buffer.scope() == "local") {
            is_register_store = true;
          }
        }
      });

      auto loop_layout = result_.for_map[root];
      bool parallel_loop = !is_register_store && !skip_thread_partition_;

      // Partition the loop across threads unless it only touches register
      // buffers or thread bindings are absent.
      if (parallel_loop) {
        for_node =
            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
      }
      // Only vectorize loops that access non-local (e.g. shared or global)
      // buffers; purely register-level loops are left as-is.
      bool has_non_local = false;
      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
        if (const auto *load = obj.as<BufferLoadNode>()) {
          String scope = load->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        } else if (const auto *store = obj.as<BufferStoreNode>()) {
          String scope = store->buffer.scope();
          if (scope != "local" && scope != "local.fragment") {
            has_non_local = true;
          }
        }
      });

      if (has_non_local) {
        for_node = VectorizeLoop(for_node);
      }

      if (result_.predicate_map.count(root) && parallel_loop) {
        return IfThenElse(result_.predicate_map[root], for_node);
      } else {
        return for_node;
      }
    }
    return for_node;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

private:
  const LayoutInferenceResult result_;
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  bool skip_thread_partition_{false};
};
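
// Pass entry point. A parallel-loop transform is applied first; then the
// function is scanned for existing thread bindings, and thread partitioning
// is skipped when none are found (e.g. a CPU-style serial kernel). The pass
// is registered below as "tl.transform.LayoutInference".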

tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
    ThreadBindingCollector collector;
    collector(f->body);
    bool has_thread_binding = collector.thread_binding_.size() > 0;
    bool skip_thread_partition = !has_thread_binding;
    return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference);
});

} // namespace tl
} // namespace tvm