inject_pipeline.cc 45.7 KB
Newer Older
1
2
/*!
 * \file inject_software_pipeline.cc
3
4
 * \brief Transform annotated loops into pipelined one that parallelize
 * producers and consumers
5
6
7
8
9
 */
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

10
#include <functional>
11
#include <unordered_set>
12
#include <utility>
13
14
15
16
17
18
19
20

#include "support/utils.h"
#include "tir/schedule/utils.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
using namespace tir;
21
using namespace ffi;
22
23
namespace software_pipeline {

24
25
26
27
28
// A Let-binding (var = value) that wrapped the original pipeline body.
// PipelineRewriter re-inserts these bindings into each rewritten block,
// substituting the pipeline loop var in `value` per block.
struct LetWrapper {
  Var var;        // the bound variable
  PrimExpr value; // the bound value; may reference the pipeline loop var
};

29
30
31
/*!
 * \brief Create a block and infer the access region with the given body.
 *
32
33
34
 * The result is a opaque block that doesn't contain any block iter vars. In
 * case the body is a block realize without predicate, it is unnecessary to
 * create a new block, the block of the block realize will be returned.
35
36
37
38
39
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
40
41
42
Block MakeBlock(const Stmt &body,
                const Map<Var, Buffer> &buffer_data_to_buffer) {
  // When the body already is a block realize with a trivially-true predicate,
  // reuse its block instead of wrapping it again.
  const BlockRealizeNode *realize = body.as<BlockRealizeNode>();
  if (realize != nullptr && is_one(realize->predicate)) {
    return realize->block;
  }
  // Otherwise wrap the body in a fresh opaque block (no block iter vars) and
  // infer its read/write regions from the body.
  Block opaque(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
               /*body=*/body);
  Array<Array<BufferRegion>> regions =
      GetBlockReadWriteRegion(opaque, buffer_data_to_buffer);
  BlockNode *node = opaque.CopyOnWrite();
  node->reads = regions[0];
  node->writes = regions[1];
  return opaque;
}

/*! Structure that represents the provided annotation per block or loop. */
struct PipelineAnnotation {
  int stage;  // pipeline stage this statement is assigned to
  int order;  // execution order of this statement within one pipelined iteration
  bool async; // whether this statement belongs to an async pipeline stage
  // Index of the statement in the original loop body order (SeqStmt order)
  int original_idx = -1;
};

67
68
// Map from each pipelined block to its stage/order/async annotation.
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
                                        ObjectPtrHash, ObjectPtrEqual>;

// Liveness summary of a buffer across pipeline stages, used to compute the
// number of versions the buffer needs after pipelining.
struct BufferAccessInfo {
  int def = -1; // the defining stage of the buffer
  int use = -1; // the last using stage of the buffer
};

/*!
76
77
78
 * \brief Rewriter for the body of the software pipeline. This pass inserts
 * `floormod` to indices of the remapped buffer to select the version
 * corresponding to the pipeline stage.
79
80
 */
class PipelineBodyRewriter : public StmtExprMutator {
81
public:
82
83
84
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
85
86
87
88
89
90
   * \param buffer_remap The map from original buffer to the buffer with updated
   * shape for multi-versioning in the software pipeline. \param pipeline_loop
   * The original loop to be software pipelined. \param access_all_versions
   * Whether all versions the buffers in the software pipeline are accessed.
   * This will be used to update block access region. In the prologue and
   * epilogue of a two-stage software pipeline, only one version of these
91
   * buffers are accessed.
92
   */
93
94
95
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap,
                       For pipeline_loop, bool access_all_versions)
96
      : buffer_data_to_buffer_(buffer_data_to_buffer),
97
        buffer_remap_(buffer_remap), pipeline_loop_(std::move(pipeline_loop)),
98
        access_all_versions_(access_all_versions) {}
99

100
101
102
private:
  BufferRegion
  RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
103
104
105
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
106
107
108
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension to
      // full extent if access_all_versions == true
109
110
111
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
112
113
114
115
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
116
117
118
119
120
121
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

122
  PrimExpr RewriteBufferAccess(const Call &call,
123
                               const std::vector<int> &arg_indices) {
124
125
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
126
127
128
          [](PrimExpr a, PrimExpr b, Span span) {
            return mul(std::move(a), std::move(b), std::move(span));
          },
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
          make_const(DataType::Int(32), 1), input);
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      const Buffer &buffer =
          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
            old_index +
            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
        new_args.Set(i + 1, new_index);
      }
    }
qisan's avatar
qisan committed
151
    LOG(INFO) << "Rewriting buffer access " << call << " to " << Call(call->dtype, call->op, new_args, call->span);
152
153
154
    return Call(call->dtype, call->op, new_args, call->span);
  }

155
156
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
157
158
159
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
160
161
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
162
163
      return RewritePipelineBufferRegion(buffer_region);
    });
164
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
165
166
      return RewritePipelineBufferRegion(buffer_region);
    });
167
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
168
169
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
qisan's avatar
qisan committed
170
    LOG(INFO) << "Rewriting block " << GetRef<Block>(op) << " to " << GetRef<Block>(n);
171
    return block;
172
173
  }

174
  Stmt VisitStmt_(const BufferStoreNode *op) final {
175
176
177
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
178
      return store;
179
    }
180
181
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
182
    n->buffer = new_buffer;
183
184
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
185
    n->indices.insert(n->indices.begin(), version);
186
    return store;
187
188
  }

189
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
190
191
192
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
193
      return load;
194
    }
195
196
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
197
    n->buffer = new_buffer;
198
199
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
200
    n->indices.insert(n->indices.begin(), version);
201
    return load;
202
203
  }

204
  PrimExpr VisitExpr_(const CallNode *op) final {
205
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
206
207
208
209
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
210
211
212
213
214
215
216
217
218
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
};

/*!
219
220
 * \brief Rewriter for the software pipeline that rewrite a loop into a
 * pipelined one.
221
222
 */
class PipelineRewriter : public StmtExprMutator {
223
public:
224
225
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
226
227
                   const For &pipeline_loop, const PipelineInfo &pipeline_info,
                   const std::vector<LetWrapper> &loop_var_let_wrappers)
228
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
229
        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
230
231
        pipeline_info_(pipeline_info),
        loop_var_let_wrappers_(loop_var_let_wrappers) {}
232
233

  Stmt BuildPipeline() {
234
235
236
237
238
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
    // number of versions need to maintain for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos = GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
239
240
241
242
243
244
      int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
      }
    }
    ordered_stmts_.resize(pipeline_info_.size());
245
246
    for (const auto &[block, anno] : pipeline_info_) {
      ordered_stmts_.Set(anno.order, block);
247
248
    }

249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        state.producer_head = pipeline_loop_->min - 1;
        for (auto write_region : block->writes) {
          auto buffer = write_region->buffer;
          state.dst_buffers.insert(buffer.get());
          if (buffer_remap_.count(buffer))
            state.dst_buffers.insert(buffer_remap_[buffer].get());
        }
      }
    }
    std::unordered_set<int> consumed;
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        if (state.commit_groups.empty() || consumed.count(stage)) {
          state.commit_groups.push_back({});
        }
        state.commit_groups.back().push_back(pipeline_info_[block].order);
        consumed.erase(stage);
        for (auto write_region : block->writes) {
          auto buffer = buffer_remap_.count(write_region->buffer)
                            ? buffer_remap_[write_region->buffer]
                            : write_region->buffer;
          state.buffer_to_commit_group_[buffer.get()] =
              state.commit_groups.size() - 1;
        }
      }

      for (auto read_region : block->reads) {
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          if (producer_stage_id <= stage &&
              producer_state.writes(read_region->buffer)) {
            consumed.insert(producer_stage_id);
          }
        }
288
289
      }
    }
290
291

    // Step 2: Emit the pipeline prologue, body and epilogue.
292
293
294
295
296
297
298
299
300
301
302
    Stmt prologue =
        EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true,
                 true, false);
    Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
                         pipeline_loop_->min + pipeline_loop_->extent, false,
                         false, false);

    Stmt epilogue =
        EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
                 pipeline_loop_->min + pipeline_loop_->extent + max_stage_,
                 true, true, true);
303
304
    SeqStmt stmt = SeqStmt({prologue, body, epilogue});

305
306
    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
307
    Array<Buffer> alloc_buffers;
308
    for (const auto &alloc : pipeline_allocs_) {
309
310
311
312
313
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
qisan's avatar
qisan committed
314
    LOG(INFO) << "Final rewritten pipeline block: " << block;
315
316
317
    return BlockRealize({}, Bool(true), block);
  }

318
private:
319
320
321
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
322
323
324
   * This method check the 'define' and 'use' stage of the buffers in the
   * software pipeline, which can be used to compute the number of versions
   * needed to maintain after rewriting.
325
326
327
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
328
329
330
331
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
332
333
334
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

335
      for (const BufferRegion &write : block->writes) {
336
337
338
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
339
        auto &info = infos.at(write->buffer);
340
341
342
343
344
345
346
        if (info.def == -1) {
          info.def = stage;
        } else {
          info.def = std::min(info.def, stage);
        }
      }

347
      for (const BufferRegion &read : block->reads) {
348
349
350
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
351
        auto &info = infos.at(read->buffer);
352
353
354
355
356
357
358
359
360
361
362
363
        info.use = std::max(info.use, stage);
      }
    }
    return infos;
  }

  /*!
   * \brief Check whether two regions have intersections.
   * \param region1 The first region.
   * \param region2 The second region.
   * \return Whether region1 and region2 have intersections.
   */
364
  bool MayConflict(const Region &region1, const Region &region2) {
365
366
367
368
369
370
371
372
373
374
375
376
377
378
    ICHECK(region1.size() == region2.size());
    for (size_t i = 0; i < region1.size(); i++) {
      Range dim1 = region1[i];
      Range dim2 = region2[i];
      auto int_set1 = arith::IntSet::FromRange(dim1);
      auto int_set2 = arith::IntSet::FromRange(dim2);
      if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
        return false;
      }
    }
    return true;
  }

  /*!
379
380
   * \brief Compute the number of versions need to maintain for buffer accessed
   * in the software pipeline.
381
   *
382
383
384
385
386
387
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions need to maintain during the software pipeline.
   * Annotation `attr::double_buffer_scope` is handled here which provides a way
   * to override the result of the analysis. Additional double buffering in the
   * software pipeline can be useful to eliminate synchronizations in GPU
   * devices.
388
389
390
391
392
   *
   * \param buffer The target buffer
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
393
394
  int ComputeBufferVersions(const Buffer &buffer,
                            const BufferAccessInfo &buffer_info) {
395
    if (buffer_info.def == -1) {
396
397
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
398
399
400
401
      return 1;
    }

    // `use - def + 1` is a upper bound of the needed versions
402
403
    // We optimize a few case where the number of versions can be smaller than
    // the upper bound
404
    int num_versions = buffer_info.use - buffer_info.def + 1;
405
    if (num_versions >= 2) {
406
407
408
409
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when these exists a reader block_i and a writer
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
410
      bool need_multi_version = false;
411
412
413
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;
414

415
416
417
        auto it1 = std::find_if(writer_block->writes.begin(),
                                writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
418
419
420
421
422
423
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

424
425
426
427
428
429
430
431
        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(
              reader_block->reads.begin(), reader_block->reads.end(),
              [&](const BufferRegion &buffer_region) {
                return buffer_region->buffer.same_as(buffer);
              });
432
433
434
          if (it2 == reader_block->reads.end()) {
            continue;
          }
435
436
          if (writer_info.order < reader_info.order &&
              writer_info.stage < reader_info.stage &&
437
438
439
440
441
442
443
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
        }
      }
      if (!need_multi_version) {
444
        num_versions--;
445
446
447
448
449
450
      }
    }
    return num_versions;
  }

  /*!
451
452
   * \brief Rewrite buffer allocation to keep multiple versions of original
   * buffer for pipelined accesses. \param buffer The buffer to be resized.
453
454
455
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
456
  Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
457
458
    ObjectPtr<BufferNode> new_buffer =
        tvm::ffi::make_object<BufferNode>(*(buffer.get()));
459
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
460
    if (!new_buffer->strides.empty()) {
461
462
463
464
465
466
467
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

468
469
  // Per-stage states that need to be tracked across pipeline prologue, body,
  // and epilogue.
470
471
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
472
473
474
475
476
477
478
    std::unordered_set<const BufferNode *> dst_buffers;
    // An imaginary index that the latest async operation associated with this
    // stage has written into. Only valid if all associated predicates are true,
    // so that we can count the number of async invocations exactly. When it is
    // valid, it is the "sum of extents of loops that have been executed" - 1,
    // e.g. for epilogue it is prologue extent + body extent - 1. This is only
    // needed to compute wait count for epilogue without async producers.
479
480
481
    PrimExpr producer_head;
    std::vector<std::vector<int>> commit_groups;
    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
482
483
484
    bool writes(const Buffer &buf) const {
      return dst_buffers.count(buf.get()) > 0;
    }
485
486
  };

487
488
  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
489
  struct AsyncStateLocal {
490
    struct PendingWait {
491
492
      // The index into a list of blocks, where async_wait_queue should be
      // attached at the beginning.
493
      int insert_before;
494
495
      // in_flight_count would be a more precise name, but the implementation
      // uses wait_count for brevity.
496
497
498
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
499
500
501
    };

    std::vector<PendingWait> pending_waits;
502

503
504
    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
505
    Optional<PrimExpr> producer_head;
506
507
    // the commit block's predicate
    PrimExpr commit_predicate{nullptr};
508
509
510
511
512
  };

  /*! Structure holding intermediate information for pipeline loop rewriting. */
  struct RewrittenBlockInfo {
    int stage;
513
    int order;
514
515
    PrimExpr start;
    PrimExpr end;
516
517
518
519
520
521
    PrimExpr predicate;
    Block block;
    PrimExpr access_index;
    bool is_async;
  };

522
  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
523
524
525
526
527
528
529
530
531
                          std::map<int, AsyncStateLocal> *async_states_local,
                          bool is_epilogue = false) {
    // Precompute which orders are present in this emit, and their access_index
    std::unordered_map<int, PrimExpr> order_to_access_index;
    std::unordered_set<int> present_orders;
    for (const auto &nb : new_blocks) {
      order_to_access_index[nb.order] = nb.access_index;
      present_orders.insert(nb.order);
    }
532
    for (size_t i = 0; i < new_blocks.size(); ++i) {
533
      // 1. Find the unique async producer stage
534
      int producer_stage_idx = -1;
535
      for (const auto &read_region : new_blocks[i].block->reads) {
536
537
538
        for (const auto &[stage, state] : async_states) {
          if (stage <= new_blocks[i].stage &&
              state.writes(read_region->buffer)) {
539
            // Currently only a single async stage dependency is supported
540
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
541
                << "A dependency on multiple async stages is not supported";
542
            producer_stage_idx = stage;
543
544
545
          }
        }
      }
546
547
      if (producer_stage_idx == -1) {
        // This block does not depend on any async producer
548
        continue;
549
      }
550
      const auto &state = async_states[producer_stage_idx];
551

552
      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
553

554
555
      // 2. Use buffer_to_commit_group_ to find all actually dependent commit
      // groups
556
557
      std::unordered_set<int> dependent_groups;
      for (const auto &read_region : new_blocks[i].block->reads) {
558
559
560
561
        auto it = state.buffer_to_commit_group_.find(read_region->buffer.get());
        if (it != state.buffer_to_commit_group_.end()) {
          dependent_groups.insert(it->second);
        }
562
      }
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615

      // If there is no dependent commit group, no wait needs to be inserted
      if (dependent_groups.empty()) {
        continue;
      }

      // 3. Compute wait = max_g max(0, t_consumer - committed_before[g])
      PrimExpr t_consumer = new_blocks[i].access_index;
      PrimExpr wait_expr = make_zero(t_consumer.dtype());

      PrimExpr current_head = dep_local_state.producer_head.defined()
                                  ? dep_local_state.producer_head.value()
                                  : state.producer_head;
      int consumer_order = new_blocks[i].order;

      for (int g : dependent_groups) {
        const auto &group = state.commit_groups[g];
        if (group.empty())
          continue;
        int commit_order = group.back();
        bool commit_present = present_orders.count(commit_order) > 0;

        PrimExpr committed_before;
        if (commit_present && commit_order <= consumer_order) {
          // Commit point is in this iteration and earlier than the current
          // consumer; this iteration's head is visible
          auto commit_predicate = dep_local_state.commit_predicate;
          if (analyzer_.CanProve(!commit_predicate,
                                 arith::ProofStrength::kSymbolicBound)) {
            // it means the commit block is not executed in this iteration
            committed_before = new_blocks[i].start - 1;
          } else if (is_epilogue) {
            committed_before = new_blocks[i].start - 1;
          } else {
            committed_before = order_to_access_index.at(commit_order);
          }
        } else {
          // Commit point is later than the current consumer or not in this
          // iteration; only the previous iteration's head is visible
          if (dep_local_state.producer_head.defined()) {
            auto commit_predicate = dep_local_state.commit_predicate;
            if (analyzer_.CanProve(!commit_predicate,
                                   arith::ProofStrength::kSymbolicBound)) {
              committed_before = new_blocks[i].start - 1;
            } else if (is_epilogue) {
              committed_before = new_blocks[i].start - 1;
            } else {
              committed_before = current_head - 1;
            }
          }
        }

        wait_expr = analyzer_.Simplify(committed_before - t_consumer);
616
      }
617
618
619

      wait_expr = analyzer_.Simplify(wait_expr);
      dep_local_state.pending_waits.push_back({static_cast<int>(i), wait_expr});
620
621
622
    }
  }

623
624
  // Given pipelined blocks and async-related information, generate final loop
  // statements with async scopes (if any).
625
  Array<Stmt> CompletePipelineLoopStatements(
626
      const std::vector<RewrittenBlockInfo> &blocks,
627
      const std::map<int, AsyncStateLocal> &async_states_local) const {
628
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
629
    for (const auto &[stage_id, state] : async_states_local) {
630
631
632
633
634
635
636
      for (const auto &pw : state.pending_waits) {
        auto &block = new_blocks[pw.insert_before].block;
        BlockNode *n = block.CopyOnWrite();
        auto zero = make_zero(DataType::Int(32));
        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                    pw.wait_count, n->body));
qisan's avatar
qisan committed
637
638
639
        LOG(INFO) << "Inserting async_wait with count " << pw.wait_count
                  << " before block with order " << new_blocks[pw.insert_before].order
                  << " for async stage " << stage_id;
640
      }
641
    }
642

643
644
645
646
647
    // mark the last async stmt as commit
    std::unordered_set<int> commit_group_indices;
    for (const auto &[stage_id, state] : async_states) {
      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
        commit_group_indices.insert(state.commit_groups[i].back());
648
649
650
651
652
      }
    }

    Array<Stmt> stmts;

653
654
655
656
657
658
659
    for (size_t i = 0; i < new_blocks.size(); i++) {
      Block block = new_blocks[i].block;
      if (commit_group_indices.count(new_blocks[i].order)) {
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope,
                                           new_blocks[i].stage, block->body);
        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
660
      }
661
      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
662
663
664
665
666
667
668
669
670
671
672
673
    }

    return stmts;
  }

  /*!
   * \brief Emit the pipeline loop in the given range.
   * \param start The start of the range
   * \param end The end of the range
   * \param unroll_loop Whether the loop should be unrolled.
   * \return The result loop.
   */
674
  Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop,
675
                bool need_bound_check, bool is_epilogue = false) {
676
677
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;
678
679
680
    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };
681
682
683

    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
684
      new_loop_var = start; // use constants as the loop var for unit loops
685
686
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
687
688
689
690
691
692
693
694
695
696
697
698
699
700
      // Bind the iteration domain [start, end) to strengthen analyzer facts.
      analyzer_.Bind(Downcast<Var>(new_loop_var),
                     Range::FromMinExtent(start, end - start));
    }
    // Keep the bound constraints active for all analysis below.
    // Only meaningful when the loop var is symbolic (non-unit loop).
    std::unique_ptr<With<arith::ConstraintContext>> ctx_lb_guard;
    std::unique_ptr<With<arith::ConstraintContext>> ctx_ub_guard;
    if (!is_unit_loop) {
      Var loop_iter = Downcast<Var>(new_loop_var);
      ctx_lb_guard.reset(
          new With<arith::ConstraintContext>(&analyzer_, loop_iter >= start));
      ctx_ub_guard.reset(
          new With<arith::ConstraintContext>(&analyzer_, loop_iter < end));
701
702
703
704
705
706
707
    }

    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;

708
    for (const Block &block : ordered_stmts_) {
709
      int stage = pipeline_info_.at(block).stage;
710
      int order = pipeline_info_.at(block).order;
711

712
      PrimExpr inbound = Bool(true);
713
      PrimExpr skewed_loop_var = new_loop_var - stage;
714
      if (need_bound_check)
715
716
717
718
        inbound = And(
            pipeline_loop_->min <= skewed_loop_var,
            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent));

719
720
721
      Block new_block = Downcast<Block>(
          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                               pipeline_loop_, max_stage_ != 1)(block));
722
723
724
725

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
726
727
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
728
      PrimExpr normalized_access_index =
729
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
730

731
732
      normalized_access_index = analyzer_.Simplify(normalized_access_index);

733
734
      // Adjust the block predicate and the body according to the final loop
      // bound
735
736
737
738
739
      //  [pipeline_loop_->min, extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }
740
741
      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
742

743
744
745
746
747
748
749
750
751
752
753
754
755
756
      // If there were Let-wrappers outside the original pipeline body that
      // depended on the pipeline loop var, push them into each rewritten
      // block with the correct per-block substitution.
      if (!loop_var_let_wrappers_.empty()) {
        BlockNode *n = new_block.CopyOnWrite();
        Stmt inner = n->body;
        for (const auto &lw : loop_var_let_wrappers_) {
          PrimExpr substituted = Substitute(
              lw.value, {{pipeline_loop_->loop_var, normalized_access_index}});
          inner = LetStmt(lw.var, substituted, inner);
        }
        n->body = inner;
      }

757
      if (pipeline_info_[block].async) {
758
        auto &local_state = async_states_local[stage];
759
        local_state.producer_head = normalized_access_index;
760
        local_state.commit_predicate = inbound;
761
762
763
        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                           1, n->body);
764
765
      }

766
      new_blocks.push_back({stage, order, start, end, inbound, new_block,
767
                            normalized_access_index,
768
                            pipeline_info_[block].async});
769
    }
770

771
    PopulateWaitCounts(new_blocks, &async_states_local, is_epilogue);
772
773

    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
774

775
776
    Stmt new_loop{nullptr};

777
    if (stmts.empty()) {
778
779
      return make_nop();
    }
780

781
782
    if (stmts.size() == 1) {
      new_loop = stmts[0];
783
    } else {
784
      new_loop = SeqStmt(stmts);
785
786
787
    }

    if (!is_unit_loop) {
788
789
790
791
792
793
794
795
796
      Map<String, Any> preserved_annotations;
      for (const auto &kv : pipeline_loop_->annotations) {
        const String &key = kv.first;
        if (kv.first != tir::attr::software_pipeline_stage &&
            kv.first != tir::attr::software_pipeline_order &&
            kv.first != tir::attr::software_pipeline_async_stages) {
          preserved_annotations.Set(key, kv.second);
        }
      }
797
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
798
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
799
                     std::move(new_loop), std::nullopt, preserved_annotations);
800
801
    }
    // Update producer heads in the global async states.
802
803
    for (const auto &[stage_id, state] : async_states_local) {
      async_states[stage_id].producer_head += extent;
804
805
    }

806
    return BlockRealize({}, Bool(true),
807
                        MakeBlock(new_loop, buffer_data_to_buffer_));
808
809
810
811
812
813
814
815
816
817
818
  }

  arith::Analyzer analyzer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Array<Buffer> pipeline_allocs_;
  For pipeline_loop_;
  PipelineInfo pipeline_info_;
  int max_stage_ = -1;
  Map<Buffer, Buffer> buffer_remap_;
  Array<Block> ordered_stmts_;
  std::map<int, AsyncStateGlobal> async_states;
819
  std::vector<LetWrapper> loop_var_let_wrappers_;
820
821
822
823
824
};

/*!
 * \brief Build the dependency graph among a array of blocks.
 * \param[in] blocks The array of blocks.
825
826
827
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination. \param[out] dep_dst2src Optional, a map to store
 * dependency edges from the destination to the source.
828
 */
829
830
831
832
833
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
834
835
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;
836
837
838

  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
839
840
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
841
        for (const Block &writer : it->second) {
842
843
844
845
846
847
848
849
850
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
          if (dep_dst2src != nullptr) {
            (*dep_dst2src)[block].push_back(writer);
          }
        }
      }
    }
851
    for (const BufferRegion &write : block->writes) {
852
853
854
855
856
857
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
858
859
public:
  static Stmt Inject(const PrimFunc &func) {
860
861
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
862
863
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
864
865
866
867
868
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }

869
870
private:
  explicit PipelineInjector(Optional<String> global_symbol)
871
      : global_symbol_(std::move(global_symbol)) {}
872
873
874
875

  /*!
   * \brief Check the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
876
877
878
879
   * 2. Reordering of statements doesn't break buffer access dependencies.
   * Specifically, for dependency (e.g. read-after-write) from statement A to
   * statement B, it requires: case 1: stage(A) < stage(B) case 2: stage(A) ==
   * stage(B) and order(A) < order(B)
880
   */
881
882
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
883
884
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
885
886
887
888
    std::unordered_map<int, const Block *> order_to_block;
    std::unordered_map<const Block *, int> block_to_stage;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
889
890
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
891
892
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
893
894
895
      used_orders.insert(order);
    }

896
897
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
898
899
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

900
901
902
903
904
905
    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
906
907
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
908
909
            << " cannot depends on statement " << src << " in a later stage "
            << src_info.stage;
910
        if (src_info.stage == dst_info.stage) {
911
912
913
914
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
915
916
917
918
919
        }
      }
    }
  }

920
  Stmt VisitStmt_(const ForNode *op) final {
921
922
923
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
924
      return for_node;
925
    }
926
927
928
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be direct child of the for-loop. If the for-loop has BlockRealize as
    // its child, the pipeline body will be the child of the block.
929
930
    Stmt pipeline_body_root{nullptr};
    bool pipeline_body_from_block = false;
931
    Array<Buffer> pipeline_allocs;
932
933
934
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
935
936
937
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
938
      pipeline_body_root = block->body;
939
      pipeline_allocs = block->alloc_buffers;
940
      pipeline_body_from_block = true;
941
    } else {
942
943
944
945
946
      pipeline_body_root = for_node->body;
    }

    const SeqStmtNode *pipeline_body_seq = nullptr;
    std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
947
    std::vector<LetWrapper> loop_var_let_wrappers;
948
    auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
949
      Any node = attr->node;
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
      String attr_key = attr->attr_key;
      PrimExpr value = attr->value;
      Span span = attr->span;
      rewrap_fns.emplace_back(
          [node = std::move(node), attr_key = std::move(attr_key),
           value = std::move(value), span](Stmt body) -> Stmt {
            return AttrStmt(node, attr_key, value, body, span);
          });
    };
    {
      Stmt current = pipeline_body_root;
      while (true) {
        if (const auto *seq_stmt = current.as<SeqStmtNode>()) {
          pipeline_body_seq = seq_stmt;
          break;
        }
        if (const auto *if_then_else = current.as<IfThenElseNode>()) {
          ICHECK(!if_then_else->else_case.defined())
              << "InjectSoftwarePipeline: Can't handle the body of the loop "
                 "because the IfThenElse node has an else branch";
          PrimExpr condition = if_then_else->condition;
          Span span = if_then_else->span;
          rewrap_fns.emplace_back(
              [condition = std::move(condition), span](Stmt body) -> Stmt {
                return IfThenElse(condition, body, Stmt(), span);
              });
          current = if_then_else->then_case;
          continue;
        }
        if (const auto *let_stmt = current.as<LetStmtNode>()) {
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
          // If this Let value uses the pipeline loop var, record it and push
          // inside each rewritten block later so the loop var can be
          // substituted with the correct per-iteration index. Otherwise, keep
          // it as a normal wrapper.
          bool uses_loop_var = UsesVar(
              let_stmt->value,
              [v = op->loop_var.get()](const VarNode *vn) { return vn == v; });
          if (uses_loop_var) {
            loop_var_let_wrappers.push_back({let_stmt->var, let_stmt->value});
          } else {
            Var var = let_stmt->var;
            PrimExpr value = let_stmt->value;
            Span span = let_stmt->span;
            rewrap_fns.emplace_back([var = std::move(var),
                                     value = std::move(value),
                                     span](Stmt body) -> Stmt {
              return LetStmt(var, value, body, span);
            });
          }
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
          current = let_stmt->body;
          continue;
        }
        if (const auto *attr = current.as<AttrStmtNode>()) {
          append_attr_wrapper(attr);
          current = attr->body;
          continue;
        }
        LOG(FATAL) << "ValueError: The body of the software pipeline should be "
                   << "SeqStmt, got " << current->GetTypeKey();
      }
1010
    }
1011
    ICHECK(pipeline_body_seq != nullptr);
1012

1013
1014
    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
1015
    PipelineInfo pipeline_info;
1016
    Array<Block> original_order; // pipeline body blocks in the original order
1017

1018
    auto f_add_child = [&](const Stmt &child) {
1019
1020
1021
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
1022
1023
      const Stmt &child = pipeline_body_seq->seq[i];
      const auto *nested_block_realize = child.as<BlockRealizeNode>();
1024
1025
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
1026
1027
1028
1029
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
1030
1031
1032
1033
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
      }
1034
      f_add_child(child);
1035
1036
    }

1037
1038
1039
1040
    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
1041
1042
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
1043
1044
1045
1046
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
1047
1048
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
1049
1050
1051
1052
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";
1053
1054

    std::unordered_set<int> pipeline_async_stages;
1055
1056
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
1057
      for (auto s : Downcast<Array<Integer>>(annot.value())) {
1058
1059
1060
1061
1062
1063
        pipeline_async_stages.insert(s->value);
      }
    }

    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
1064
1065
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
1066
1067
1068
      printf("Block %s assigned to stage %d with order %d%s\n", original_order[i]->name_hint.c_str(),
            stage, static_cast<int>(pipeline_orders[i]->value),
            is_async ? " (async)" : " sync");
1069
1070
      PipelineAnnotation stage_order{
          stage,
1071
1072
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async,
          /*original_idx=*/static_cast<int>(i)};
1073
1074
1075
1076
1077
1078
      pipeline_info.emplace(original_order[i], stage_order);
    }

    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
1079
    Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
1080
1081
                                     tvm::ffi::GetRef<For>(op), pipeline_info,
                                     loop_var_let_wrappers)
1082
                        .BuildPipeline();
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
    auto apply_wrappers = [&](Stmt stmt) {
      for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
        stmt = (*it)(stmt);
      }
      return stmt;
    };
    if (!rewrap_fns.empty()) {
      if (pipeline_body_from_block) {
        BlockRealize pipeline_realize = Downcast<BlockRealize>(pipeline);
        Block pipeline_block = pipeline_realize->block;
        {
          BlockNode *block_node = pipeline_block.CopyOnWrite();
          block_node->body = apply_wrappers(block_node->body);
        }
        pipeline = BlockRealize(pipeline_realize->iter_values,
                                pipeline_realize->predicate, pipeline_block,
                                pipeline_realize->span);
      } else {
        pipeline = apply_wrappers(pipeline);
      }
    }
1104

1105
1106
1107
    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
1108
1109
1110
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
qisan's avatar
qisan committed
1111
    LOG(INFO) << "Finished rewriting the pipeline loop with body:\n" << pipeline;
1112
1113
1114
    return pipeline;
  }

1115
1116
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
1117
1118
1119
1120
1121
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

1122
1123
1124
1125
1126
1127
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    BlockNode *n = block.CopyOnWrite();
    n->reads = access[0];
    n->writes = access[1];

1128
    for (const auto &buffer : op->alloc_buffers) {
1129
1130
      buffer_data_to_buffer_.erase(buffer->data);
    }
qisan's avatar
qisan committed
1131
    LOG(INFO) << "Rewriting blockddd " << block;
1132
    return block;
1133
1134
  }

1135
  bool HasPipelineAnnotation(const ForNode *op) const {
1136
1137
1138
1139
1140
1141
1142
1143
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
    bool has_order = it2 != op->annotations.end();
    if (has_stage && has_order) {
      return true;
    }
    if (has_stage) {
1144
      LOG(FATAL)
1145
          << "ValueError: Stage of the software pipeline is not defined.";
1146
1147
    }
    if (has_order) {
1148
      LOG(FATAL)
1149
          << "ValueError: Order of the software pipeline is not defined.";
1150
1151
1152
1153
1154
1155
1156
    }
    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Optional<String> global_symbol_;
};
1157
1158
} // namespace software_pipeline

1159
/*!
 * \brief Transform annotated loops into pipelined one that parallelize
 * producers and consumers. \return The IR transform pass.
 */
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    auto *fptr = f.CopyOnWrite();
    fptr->body = software_pipeline::PipelineInjector::Inject(f);
    // Re-establish SSA form on the rewritten body before returning it.
    fptr->body = ConvertSSA(std::move(fptr->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
}

1176
TVM_FFI_STATIC_INIT_BLOCK() {
1177
1178
1179
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
                        InjectSoftwarePipeline);
1180
}
1181

1182
1183
} // namespace tl
} // namespace tvm