inject_pipeline.cc 44.7 KB
Newer Older
1
2
/*!
 * \file inject_software_pipeline.cc
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 */
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

10
#include <functional>
11
#include <unordered_set>
12
#include <utility>
13
14
15
16
17
18
19
20

#include "support/utils.h"
#include "tir/schedule/utils.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
using namespace tir;
21
using namespace ffi;
22
23
namespace software_pipeline {

24
25
26
27
28
struct LetWrapper {
  Var var;
  PrimExpr value;
};

29
30
31
/*!
 * \brief Create a block and infer the access region with the given body.
 *
32
33
34
 * The result is a opaque block that doesn't contain any block iter vars. In
 * case the body is a block realize without predicate, it is unnecessary to
 * create a new block, the block of the block realize will be returned.
35
36
37
38
39
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
40
41
42
Block MakeBlock(const Stmt &body,
                const Map<Var, Buffer> &buffer_data_to_buffer) {
  if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
43
44
45
46
47
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
48
49
50
51
52
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ body);
  Array<Array<BufferRegion>> access =
      GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode *n = block.CopyOnWrite();
53
54
55
56
57
58
59
60
61
62
  n->reads = access[0];
  n->writes = access[1];
  return block;
}

/*! Structure that represents the provided annotation per block or loop. */
struct PipelineAnnotation {
  int stage;
  int order;
  bool async;
63
64
  // Index of the statement in the original loop body order (SeqStmt order)
  int original_idx = -1;
65
66
};

67
68
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
                                        ObjectPtrHash, ObjectPtrEqual>;
69
70

struct BufferAccessInfo {
71
72
  int def = -1; // the defining stage of the buffer
  int use = -1; // the last using stage of the buffer
73
74
75
};

/*!
 * \brief Rewriter for the body of the software pipeline. This pass inserts
 * `floormod` to indices of the remapped buffer to select the version
 * corresponding to the pipeline stage.
 */
class PipelineBodyRewriter : public StmtExprMutator {
81
public:
82
83
84
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
85
86
87
88
89
90
   * \param buffer_remap The map from original buffer to the buffer with updated
   * shape for multi-versioning in the software pipeline. \param pipeline_loop
   * The original loop to be software pipelined. \param access_all_versions
   * Whether all versions the buffers in the software pipeline are accessed.
   * This will be used to update block access region. In the prologue and
   * epilogue of a two-stage software pipeline, only one version of these
91
   * buffers are accessed.
92
   */
93
94
95
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap,
                       For pipeline_loop, bool access_all_versions)
96
      : buffer_data_to_buffer_(buffer_data_to_buffer),
97
        buffer_remap_(buffer_remap), pipeline_loop_(std::move(pipeline_loop)),
98
        access_all_versions_(access_all_versions) {}
99

100
101
102
private:
  BufferRegion
  RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
103
104
105
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
106
107
108
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension to
      // full extent if access_all_versions == true
109
110
111
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
112
113
114
115
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
116
117
118
119
120
121
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

122
  PrimExpr RewriteBufferAccess(const Call &call,
123
                               const std::vector<int> &arg_indices) {
124
125
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
126
127
128
          [](PrimExpr a, PrimExpr b, Span span) {
            return mul(std::move(a), std::move(b), std::move(span));
          },
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
          make_const(DataType::Int(32), 1), input);
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      const Buffer &buffer =
          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
            old_index +
            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

154
155
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
156
157
158
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
159
160
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
161
162
      return RewritePipelineBufferRegion(buffer_region);
    });
163
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
164
165
      return RewritePipelineBufferRegion(buffer_region);
    });
166
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
167
168
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
169
    return block;
170
171
  }

172
  Stmt VisitStmt_(const BufferStoreNode *op) final {
173
174
175
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
176
      return store;
177
    }
178
179
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
180
    n->buffer = new_buffer;
181
182
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
183
    n->indices.insert(n->indices.begin(), version);
184
    return store;
185
186
  }

187
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
188
189
190
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
191
      return load;
192
    }
193
194
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
195
    n->buffer = new_buffer;
196
197
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
198
    n->indices.insert(n->indices.begin(), version);
199
    return load;
200
201
  }

202
  PrimExpr VisitExpr_(const CallNode *op) final {
203
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
204
205
206
207
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
208
209
210
211
212
213
214
215
216
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
};

/*!
 * \brief Rewriter for the software pipeline that rewrites a loop into a
 * pipelined one.
 */
class PipelineRewriter : public StmtExprMutator {
221
public:
222
223
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
224
225
                   const For &pipeline_loop, const PipelineInfo &pipeline_info,
                   const std::vector<LetWrapper> &loop_var_let_wrappers)
226
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
227
        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
228
229
        pipeline_info_(pipeline_info),
        loop_var_let_wrappers_(loop_var_let_wrappers) {}
230
231

  Stmt BuildPipeline() {
232
233
234
235
236
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
    // number of versions need to maintain for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos = GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
237
238
239
240
241
242
      int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
      }
    }
    ordered_stmts_.resize(pipeline_info_.size());
243
244
    for (const auto &[block, anno] : pipeline_info_) {
      ordered_stmts_.Set(anno.order, block);
245
246
    }

247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        state.producer_head = pipeline_loop_->min - 1;
        for (auto write_region : block->writes) {
          auto buffer = write_region->buffer;
          state.dst_buffers.insert(buffer.get());
          if (buffer_remap_.count(buffer))
            state.dst_buffers.insert(buffer_remap_[buffer].get());
        }
      }
    }
    std::unordered_set<int> consumed;
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        if (state.commit_groups.empty() || consumed.count(stage)) {
          state.commit_groups.push_back({});
        }
        state.commit_groups.back().push_back(pipeline_info_[block].order);
        consumed.erase(stage);
        for (auto write_region : block->writes) {
          auto buffer = buffer_remap_.count(write_region->buffer)
                            ? buffer_remap_[write_region->buffer]
                            : write_region->buffer;
          state.buffer_to_commit_group_[buffer.get()] =
              state.commit_groups.size() - 1;
        }
      }

      for (auto read_region : block->reads) {
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          if (producer_stage_id <= stage &&
              producer_state.writes(read_region->buffer)) {
            consumed.insert(producer_stage_id);
          }
        }
286
287
      }
    }
288
289

    // Step 2: Emit the pipeline prologue, body and epilogue.
290
291
292
293
294
295
296
297
298
299
300
    Stmt prologue =
        EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true,
                 true, false);
    Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
                         pipeline_loop_->min + pipeline_loop_->extent, false,
                         false, false);

    Stmt epilogue =
        EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
                 pipeline_loop_->min + pipeline_loop_->extent + max_stage_,
                 true, true, true);
301
302
    SeqStmt stmt = SeqStmt({prologue, body, epilogue});

303
304
    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
305
    Array<Buffer> alloc_buffers;
306
    for (const auto &alloc : pipeline_allocs_) {
307
308
309
310
311
312
313
314
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    return BlockRealize({}, Bool(true), block);
  }

315
private:
316
317
318
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
319
320
321
   * This method check the 'define' and 'use' stage of the buffers in the
   * software pipeline, which can be used to compute the number of versions
   * needed to maintain after rewriting.
322
323
324
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
325
326
327
328
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
329
330
331
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

332
      for (const BufferRegion &write : block->writes) {
333
334
335
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
336
        auto &info = infos.at(write->buffer);
337
338
339
340
341
342
343
        if (info.def == -1) {
          info.def = stage;
        } else {
          info.def = std::min(info.def, stage);
        }
      }

344
      for (const BufferRegion &read : block->reads) {
345
346
347
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
348
        auto &info = infos.at(read->buffer);
349
350
351
352
353
354
355
356
357
358
359
360
        info.use = std::max(info.use, stage);
      }
    }
    return infos;
  }

  /*!
   * \brief Check whether two regions have intersections.
   * \param region1 The first region.
   * \param region2 The second region.
   * \return Whether region1 and region2 have intersections.
   */
361
  bool MayConflict(const Region &region1, const Region &region2) {
362
363
364
365
366
367
368
369
370
371
372
373
374
375
    ICHECK(region1.size() == region2.size());
    for (size_t i = 0; i < region1.size(); i++) {
      Range dim1 = region1[i];
      Range dim2 = region2[i];
      auto int_set1 = arith::IntSet::FromRange(dim1);
      auto int_set2 = arith::IntSet::FromRange(dim2);
      if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
        return false;
      }
    }
    return true;
  }

  /*!
376
377
   * \brief Compute the number of versions need to maintain for buffer accessed
   * in the software pipeline.
378
   *
379
380
381
382
383
384
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions need to maintain during the software pipeline.
   * Annotation `attr::double_buffer_scope` is handled here which provides a way
   * to override the result of the analysis. Additional double buffering in the
   * software pipeline can be useful to eliminate synchronizations in GPU
   * devices.
385
386
387
388
389
   *
   * \param buffer The target buffer
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
390
391
  int ComputeBufferVersions(const Buffer &buffer,
                            const BufferAccessInfo &buffer_info) {
392
    if (buffer_info.def == -1) {
393
394
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
395
396
397
398
      return 1;
    }

    // `use - def + 1` is a upper bound of the needed versions
399
400
    // We optimize a few case where the number of versions can be smaller than
    // the upper bound
401
    int num_versions = buffer_info.use - buffer_info.def + 1;
402
    if (num_versions >= 2) {
403
404
405
406
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when these exists a reader block_i and a writer
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
407
      bool need_multi_version = false;
408
409
410
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;
411

412
413
414
        auto it1 = std::find_if(writer_block->writes.begin(),
                                writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
415
416
417
418
419
420
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

421
422
423
424
425
426
427
428
        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(
              reader_block->reads.begin(), reader_block->reads.end(),
              [&](const BufferRegion &buffer_region) {
                return buffer_region->buffer.same_as(buffer);
              });
429
430
431
          if (it2 == reader_block->reads.end()) {
            continue;
          }
432
433
          if (writer_info.order < reader_info.order &&
              writer_info.stage < reader_info.stage &&
434
435
436
437
438
439
440
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
        }
      }
      if (!need_multi_version) {
441
        num_versions--;
442
443
444
445
446
447
      }
    }
    return num_versions;
  }

  /*!
448
449
   * \brief Rewrite buffer allocation to keep multiple versions of original
   * buffer for pipelined accesses. \param buffer The buffer to be resized.
450
451
452
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
453
  Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
454
455
    ObjectPtr<BufferNode> new_buffer =
        tvm::ffi::make_object<BufferNode>(*(buffer.get()));
456
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
457
    if (!new_buffer->strides.empty()) {
458
459
460
461
462
463
464
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

465
466
  // Per-stage states that need to be tracked across pipeline prologue, body,
  // and epilogue.
467
468
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
469
470
471
472
473
474
475
    std::unordered_set<const BufferNode *> dst_buffers;
    // An imaginary index that the latest async operation associated with this
    // stage has written into. Only valid if all associated predicates are true,
    // so that we can count the number of async invocations exactly. When it is
    // valid, it is the "sum of extents of loops that have been executed" - 1,
    // e.g. for epilogue it is prologue extent + body extent - 1. This is only
    // needed to compute wait count for epilogue without async producers.
476
477
478
    PrimExpr producer_head;
    std::vector<std::vector<int>> commit_groups;
    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
479
480
481
    bool writes(const Buffer &buf) const {
      return dst_buffers.count(buf.get()) > 0;
    }
482
483
  };

484
485
  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
486
  struct AsyncStateLocal {
487
    struct PendingWait {
488
489
      // The index into a list of blocks, where async_wait_queue should be
      // attached at the beginning.
490
      int insert_before;
491
492
      // in_flight_count would be a more precise name, but the implementation
      // uses wait_count for brevity.
493
494
495
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
496
497
498
    };

    std::vector<PendingWait> pending_waits;
499

500
501
    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
502
    Optional<PrimExpr> producer_head;
503
504
    // the commit block's predicate
    PrimExpr commit_predicate{nullptr};
505
506
507
508
509
  };

  /*! Structure holding intermediate information for pipeline loop rewriting. */
  struct RewrittenBlockInfo {
    int stage;
510
    int order;
511
512
    PrimExpr start;
    PrimExpr end;
513
514
515
516
517
518
    PrimExpr predicate;
    Block block;
    PrimExpr access_index;
    bool is_async;
  };

519
  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
520
521
522
523
524
525
526
527
528
                          std::map<int, AsyncStateLocal> *async_states_local,
                          bool is_epilogue = false) {
    // Precompute which orders are present in this emit, and their access_index
    std::unordered_map<int, PrimExpr> order_to_access_index;
    std::unordered_set<int> present_orders;
    for (const auto &nb : new_blocks) {
      order_to_access_index[nb.order] = nb.access_index;
      present_orders.insert(nb.order);
    }
529
    for (size_t i = 0; i < new_blocks.size(); ++i) {
530
      // 1. Find the unique async producer stage
531
      int producer_stage_idx = -1;
532
      for (const auto &read_region : new_blocks[i].block->reads) {
533
534
535
        for (const auto &[stage, state] : async_states) {
          if (stage <= new_blocks[i].stage &&
              state.writes(read_region->buffer)) {
536
            // Currently only a single async stage dependency is supported
537
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
538
                << "A dependency on multiple async stages is not supported";
539
            producer_stage_idx = stage;
540
541
542
          }
        }
      }
543
544
      if (producer_stage_idx == -1) {
        // This block does not depend on any async producer
545
        continue;
546
      }
547
      const auto &state = async_states[producer_stage_idx];
548

549
      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
550

551
552
      // 2. Use buffer_to_commit_group_ to find all actually dependent commit
      // groups
553
554
      std::unordered_set<int> dependent_groups;
      for (const auto &read_region : new_blocks[i].block->reads) {
555
556
557
558
        auto it = state.buffer_to_commit_group_.find(read_region->buffer.get());
        if (it != state.buffer_to_commit_group_.end()) {
          dependent_groups.insert(it->second);
        }
559
      }
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612

      // If there is no dependent commit group, no wait needs to be inserted
      if (dependent_groups.empty()) {
        continue;
      }

      // 3. Compute wait = max_g max(0, t_consumer - committed_before[g])
      PrimExpr t_consumer = new_blocks[i].access_index;
      PrimExpr wait_expr = make_zero(t_consumer.dtype());

      PrimExpr current_head = dep_local_state.producer_head.defined()
                                  ? dep_local_state.producer_head.value()
                                  : state.producer_head;
      int consumer_order = new_blocks[i].order;

      for (int g : dependent_groups) {
        const auto &group = state.commit_groups[g];
        if (group.empty())
          continue;
        int commit_order = group.back();
        bool commit_present = present_orders.count(commit_order) > 0;

        PrimExpr committed_before;
        if (commit_present && commit_order <= consumer_order) {
          // Commit point is in this iteration and earlier than the current
          // consumer; this iteration's head is visible
          auto commit_predicate = dep_local_state.commit_predicate;
          if (analyzer_.CanProve(!commit_predicate,
                                 arith::ProofStrength::kSymbolicBound)) {
            // it means the commit block is not executed in this iteration
            committed_before = new_blocks[i].start - 1;
          } else if (is_epilogue) {
            committed_before = new_blocks[i].start - 1;
          } else {
            committed_before = order_to_access_index.at(commit_order);
          }
        } else {
          // Commit point is later than the current consumer or not in this
          // iteration; only the previous iteration's head is visible
          if (dep_local_state.producer_head.defined()) {
            auto commit_predicate = dep_local_state.commit_predicate;
            if (analyzer_.CanProve(!commit_predicate,
                                   arith::ProofStrength::kSymbolicBound)) {
              committed_before = new_blocks[i].start - 1;
            } else if (is_epilogue) {
              committed_before = new_blocks[i].start - 1;
            } else {
              committed_before = current_head - 1;
            }
          }
        }

        wait_expr = analyzer_.Simplify(committed_before - t_consumer);
613
      }
614
615
616

      wait_expr = analyzer_.Simplify(wait_expr);
      dep_local_state.pending_waits.push_back({static_cast<int>(i), wait_expr});
617
618
619
    }
  }

620
621
  // Given pipelined blocks and async-related information, generate final loop
  // statements with async scopes (if any).
622
  Array<Stmt> CompletePipelineLoopStatements(
623
      const std::vector<RewrittenBlockInfo> &blocks,
624
      const std::map<int, AsyncStateLocal> &async_states_local) const {
625
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
626
    for (const auto &[stage_id, state] : async_states_local) {
627
628
629
630
631
632
633
      for (const auto &pw : state.pending_waits) {
        auto &block = new_blocks[pw.insert_before].block;
        BlockNode *n = block.CopyOnWrite();
        auto zero = make_zero(DataType::Int(32));
        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                    pw.wait_count, n->body));
634
      }
635
    }
636

637
638
639
640
641
    // mark the last async stmt as commit
    std::unordered_set<int> commit_group_indices;
    for (const auto &[stage_id, state] : async_states) {
      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
        commit_group_indices.insert(state.commit_groups[i].back());
642
643
644
645
646
      }
    }

    Array<Stmt> stmts;

647
648
649
650
651
652
653
    for (size_t i = 0; i < new_blocks.size(); i++) {
      Block block = new_blocks[i].block;
      if (commit_group_indices.count(new_blocks[i].order)) {
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope,
                                           new_blocks[i].stage, block->body);
        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
654
      }
655
      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
656
657
658
659
660
661
662
663
664
665
666
667
    }

    return stmts;
  }

  /*!
   * \brief Emit the pipeline loop in the given range.
   * \param start The start of the range
   * \param end The end of the range
   * \param unroll_loop Whether the loop should be unrolled.
   * \return The result loop.
   */
668
  Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop,
669
                bool need_bound_check, bool is_epilogue = false) {
670
671
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;
672
673
674
    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };
675
676
677

    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
678
      new_loop_var = start; // use constants as the loop var for unit loops
679
680
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
681
682
683
684
685
686
687
688
689
690
691
692
693
694
      // Bind the iteration domain [start, end) to strengthen analyzer facts.
      analyzer_.Bind(Downcast<Var>(new_loop_var),
                     Range::FromMinExtent(start, end - start));
    }
    // Keep the bound constraints active for all analysis below.
    // Only meaningful when the loop var is symbolic (non-unit loop).
    std::unique_ptr<With<arith::ConstraintContext>> ctx_lb_guard;
    std::unique_ptr<With<arith::ConstraintContext>> ctx_ub_guard;
    if (!is_unit_loop) {
      Var loop_iter = Downcast<Var>(new_loop_var);
      ctx_lb_guard.reset(
          new With<arith::ConstraintContext>(&analyzer_, loop_iter >= start));
      ctx_ub_guard.reset(
          new With<arith::ConstraintContext>(&analyzer_, loop_iter < end));
695
696
697
698
699
700
701
    }

    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;

702
    for (const Block &block : ordered_stmts_) {
703
      int stage = pipeline_info_.at(block).stage;
704
      int order = pipeline_info_.at(block).order;
705

706
      PrimExpr inbound = Bool(true);
707
      PrimExpr skewed_loop_var = new_loop_var - stage;
708
      if (need_bound_check)
709
710
711
712
        inbound = And(
            pipeline_loop_->min <= skewed_loop_var,
            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent));

713
714
715
      Block new_block = Downcast<Block>(
          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                               pipeline_loop_, max_stage_ != 1)(block));
716
717
718
719

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
720
721
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
722
      PrimExpr normalized_access_index =
723
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
724

725
726
      normalized_access_index = analyzer_.Simplify(normalized_access_index);

727
728
      // Adjust the block predicate and the body according to the final loop
      // bound
729
730
731
732
733
      //  [pipeline_loop_->min, extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }
734
735
      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
736

737
738
739
740
741
742
743
744
745
746
747
748
749
750
      // If there were Let-wrappers outside the original pipeline body that
      // depended on the pipeline loop var, push them into each rewritten
      // block with the correct per-block substitution.
      if (!loop_var_let_wrappers_.empty()) {
        BlockNode *n = new_block.CopyOnWrite();
        Stmt inner = n->body;
        for (const auto &lw : loop_var_let_wrappers_) {
          PrimExpr substituted = Substitute(
              lw.value, {{pipeline_loop_->loop_var, normalized_access_index}});
          inner = LetStmt(lw.var, substituted, inner);
        }
        n->body = inner;
      }

751
      if (pipeline_info_[block].async) {
752
        auto &local_state = async_states_local[stage];
753
        local_state.producer_head = normalized_access_index;
754
        local_state.commit_predicate = inbound;
755
756
757
        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                           1, n->body);
758
759
      }

760
      new_blocks.push_back({stage, order, start, end, inbound, new_block,
761
                            normalized_access_index,
762
                            pipeline_info_[block].async});
763
    }
764

765
    PopulateWaitCounts(new_blocks, &async_states_local, is_epilogue);
766
767

    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
768

769
770
    Stmt new_loop{nullptr};

771
    if (stmts.empty()) {
772
773
      return make_nop();
    }
774

775
776
    if (stmts.size() == 1) {
      new_loop = stmts[0];
777
    } else {
778
      new_loop = SeqStmt(stmts);
779
780
781
    }

    if (!is_unit_loop) {
782
783
784
785
786
787
788
789
790
      Map<String, Any> preserved_annotations;
      for (const auto &kv : pipeline_loop_->annotations) {
        const String &key = kv.first;
        if (kv.first != tir::attr::software_pipeline_stage &&
            kv.first != tir::attr::software_pipeline_order &&
            kv.first != tir::attr::software_pipeline_async_stages) {
          preserved_annotations.Set(key, kv.second);
        }
      }
791
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
792
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
793
                     std::move(new_loop), std::nullopt, preserved_annotations);
794
795
    }
    // Update producer heads in the global async states.
796
797
    for (const auto &[stage_id, state] : async_states_local) {
      async_states[stage_id].producer_head += extent;
798
799
    }

800
    return BlockRealize({}, Bool(true),
801
                        MakeBlock(new_loop, buffer_data_to_buffer_));
802
803
804
805
806
807
808
809
810
811
812
  }

  arith::Analyzer analyzer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Array<Buffer> pipeline_allocs_;
  For pipeline_loop_;
  PipelineInfo pipeline_info_;
  int max_stage_ = -1;
  Map<Buffer, Buffer> buffer_remap_;
  Array<Block> ordered_stmts_;
  std::map<int, AsyncStateGlobal> async_states;
813
  std::vector<LetWrapper> loop_var_let_wrappers_;
814
815
816
817
818
};

/*!
 * \brief Build the dependency graph among an array of blocks.
 * \param[in] blocks The array of blocks.
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination.
 * \param[out] dep_dst2src Optional, a map to store dependency edges from the
 * destination to the source.
 */
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
828
829
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;
830
831
832

  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
833
834
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
835
        for (const Block &writer : it->second) {
836
837
838
839
840
841
842
843
844
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
          if (dep_dst2src != nullptr) {
            (*dep_dst2src)[block].push_back(writer);
          }
        }
      }
    }
845
    for (const BufferRegion &write : block->writes) {
846
847
848
849
850
851
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
852
853
public:
  static Stmt Inject(const PrimFunc &func) {
854
855
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
856
857
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
858
859
860
861
862
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }

863
864
private:
  explicit PipelineInjector(Optional<String> global_symbol)
865
      : global_symbol_(std::move(global_symbol)) {}
866
867
868
869

  /*!
   * \brief Check the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
870
871
872
873
   * 2. Reordering of statements doesn't break buffer access dependencies.
   * Specifically, for dependency (e.g. read-after-write) from statement A to
   * statement B, it requires: case 1: stage(A) < stage(B) case 2: stage(A) ==
   * stage(B) and order(A) < order(B)
874
   */
875
876
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
877
878
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
879
880
881
882
    std::unordered_map<int, const Block *> order_to_block;
    std::unordered_map<const Block *, int> block_to_stage;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
883
884
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
885
886
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
887
888
889
      used_orders.insert(order);
    }

890
891
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
892
893
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

894
895
896
897
898
899
    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
900
901
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
902
903
            << " cannot depends on statement " << src << " in a later stage "
            << src_info.stage;
904
        if (src_info.stage == dst_info.stage) {
905
906
907
908
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
909
910
911
912
913
        }
      }
    }
  }

914
  Stmt VisitStmt_(const ForNode *op) final {
915
916
917
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
918
      return for_node;
919
    }
920
921
922
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be direct child of the for-loop. If the for-loop has BlockRealize as
    // its child, the pipeline body will be the child of the block.
923
924
    Stmt pipeline_body_root{nullptr};
    bool pipeline_body_from_block = false;
925
    Array<Buffer> pipeline_allocs;
926
927
928
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
929
930
931
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
932
      pipeline_body_root = block->body;
933
      pipeline_allocs = block->alloc_buffers;
934
      pipeline_body_from_block = true;
935
    } else {
936
937
938
939
940
      pipeline_body_root = for_node->body;
    }

    const SeqStmtNode *pipeline_body_seq = nullptr;
    std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
941
    std::vector<LetWrapper> loop_var_let_wrappers;
942
    auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
943
      Any node = attr->node;
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
      String attr_key = attr->attr_key;
      PrimExpr value = attr->value;
      Span span = attr->span;
      rewrap_fns.emplace_back(
          [node = std::move(node), attr_key = std::move(attr_key),
           value = std::move(value), span](Stmt body) -> Stmt {
            return AttrStmt(node, attr_key, value, body, span);
          });
    };
    {
      Stmt current = pipeline_body_root;
      while (true) {
        if (const auto *seq_stmt = current.as<SeqStmtNode>()) {
          pipeline_body_seq = seq_stmt;
          break;
        }
        if (const auto *if_then_else = current.as<IfThenElseNode>()) {
          ICHECK(!if_then_else->else_case.defined())
              << "InjectSoftwarePipeline: Can't handle the body of the loop "
                 "because the IfThenElse node has an else branch";
          PrimExpr condition = if_then_else->condition;
          Span span = if_then_else->span;
          rewrap_fns.emplace_back(
              [condition = std::move(condition), span](Stmt body) -> Stmt {
                return IfThenElse(condition, body, Stmt(), span);
              });
          current = if_then_else->then_case;
          continue;
        }
        if (const auto *let_stmt = current.as<LetStmtNode>()) {
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
          // If this Let value uses the pipeline loop var, record it and push
          // inside each rewritten block later so the loop var can be
          // substituted with the correct per-iteration index. Otherwise, keep
          // it as a normal wrapper.
          bool uses_loop_var = UsesVar(
              let_stmt->value,
              [v = op->loop_var.get()](const VarNode *vn) { return vn == v; });
          if (uses_loop_var) {
            loop_var_let_wrappers.push_back({let_stmt->var, let_stmt->value});
          } else {
            Var var = let_stmt->var;
            PrimExpr value = let_stmt->value;
            Span span = let_stmt->span;
            rewrap_fns.emplace_back([var = std::move(var),
                                     value = std::move(value),
                                     span](Stmt body) -> Stmt {
              return LetStmt(var, value, body, span);
            });
          }
993
994
995
996
997
998
999
1000
1001
1002
1003
          current = let_stmt->body;
          continue;
        }
        if (const auto *attr = current.as<AttrStmtNode>()) {
          append_attr_wrapper(attr);
          current = attr->body;
          continue;
        }
        LOG(FATAL) << "ValueError: The body of the software pipeline should be "
                   << "SeqStmt, got " << current->GetTypeKey();
      }
1004
    }
1005
    ICHECK(pipeline_body_seq != nullptr);
1006

1007
1008
    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
1009
    PipelineInfo pipeline_info;
1010
    Array<Block> original_order; // pipeline body blocks in the original order
1011

1012
    auto f_add_child = [&](const Stmt &child) {
1013
1014
1015
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
1016
1017
      const Stmt &child = pipeline_body_seq->seq[i];
      const auto *nested_block_realize = child.as<BlockRealizeNode>();
1018
1019
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
1020
1021
1022
1023
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
1024
1025
1026
1027
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
      }
1028
      f_add_child(child);
1029
1030
    }

1031
1032
1033
1034
    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
1035
1036
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
1037
1038
1039
1040
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
1041
1042
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
1043
1044
1045
1046
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";
1047
1048

    std::unordered_set<int> pipeline_async_stages;
1049
1050
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
1051
      for (auto s : Downcast<Array<Integer>>(annot.value())) {
1052
1053
1054
1055
1056
1057
        pipeline_async_stages.insert(s->value);
      }
    }

    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
1058
1059
1060
1061
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{
          stage,
1062
1063
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async,
          /*original_idx=*/static_cast<int>(i)};
1064
1065
1066
1067
1068
1069
      pipeline_info.emplace(original_order[i], stage_order);
    }

    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
1070
    Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
1071
1072
                                     tvm::ffi::GetRef<For>(op), pipeline_info,
                                     loop_var_let_wrappers)
1073
                        .BuildPipeline();
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
    auto apply_wrappers = [&](Stmt stmt) {
      for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
        stmt = (*it)(stmt);
      }
      return stmt;
    };
    if (!rewrap_fns.empty()) {
      if (pipeline_body_from_block) {
        BlockRealize pipeline_realize = Downcast<BlockRealize>(pipeline);
        Block pipeline_block = pipeline_realize->block;
        {
          BlockNode *block_node = pipeline_block.CopyOnWrite();
          block_node->body = apply_wrappers(block_node->body);
        }
        pipeline = BlockRealize(pipeline_realize->iter_values,
                                pipeline_realize->predicate, pipeline_block,
                                pipeline_realize->span);
      } else {
        pipeline = apply_wrappers(pipeline);
      }
    }
1095

1096
1097
1098
    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
1099
1100
1101
1102
1103
1104
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }

1105
1106
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
1107
1108
1109
1110
1111
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

1112
1113
1114
1115
1116
1117
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    BlockNode *n = block.CopyOnWrite();
    n->reads = access[0];
    n->writes = access[1];

1118
    for (const auto &buffer : op->alloc_buffers) {
1119
1120
      buffer_data_to_buffer_.erase(buffer->data);
    }
1121
    return block;
1122
1123
  }

1124
  bool HasPipelineAnnotation(const ForNode *op) const {
1125
1126
1127
1128
1129
1130
1131
1132
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
    bool has_order = it2 != op->annotations.end();
    if (has_stage && has_order) {
      return true;
    }
    if (has_stage) {
1133
      LOG(FATAL)
1134
          << "ValueError: Stage of the software pipeline is not defined.";
1135
1136
    }
    if (has_order) {
1137
      LOG(FATAL)
1138
          << "ValueError: Order of the software pipeline is not defined.";
1139
1140
1141
1142
1143
1144
1145
    }
    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Optional<String> global_symbol_;
};
1146
1147
} // namespace software_pipeline

/*!
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 * \return The IR transform pass.
 */
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  // Rewrite the function body, then restore SSA form for any variables the
  // pipeline rewriter duplicated.
  auto transform = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    PrimFuncNode *node = f.CopyOnWrite();
    node->body = software_pipeline::PipelineInjector::Inject(f);
    node->body = ConvertSSA(std::move(node->body));
    return f;
  };
  return CreatePrimFuncPass(transform, 0, "tl.InjectSoftwarePipeline", {});
}

1163
TVM_FFI_STATIC_INIT_BLOCK() {
1164
1165
1166
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
                        InjectSoftwarePipeline);
1167
}
1168

1169
1170
} // namespace tl
} // namespace tvm