inject_pipeline.cc 41 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_pipeline.cc
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 */
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

29
#include <functional>
30
#include <unordered_set>
31
#include <utility>
32
33
34
35
36
37
38
39

#include "support/utils.h"
#include "tir/schedule/utils.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
using namespace tir;
40
using namespace ffi;
41
42
namespace software_pipeline {

43
44
45
/*!
 * \brief Create a block and infer the access region with the given body.
 *
46
47
48
 * The result is a opaque block that doesn't contain any block iter vars. In
 * case the body is a block realize without predicate, it is unnecessary to
 * create a new block, the block of the block realize will be returned.
49
50
51
52
53
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
54
55
56
Block MakeBlock(const Stmt &body,
                const Map<Var, Buffer> &buffer_data_to_buffer) {
  if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
57
58
59
60
61
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
62
63
64
65
66
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ body);
  Array<Array<BufferRegion>> access =
      GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode *n = block.CopyOnWrite();
67
68
69
70
71
72
73
74
75
76
77
78
  n->reads = access[0];
  n->writes = access[1];
  return block;
}

/*! Structure that represents the provided annotation per block or loop. */
struct PipelineAnnotation {
  /*! \brief The pipeline stage the block is assigned to. */
  int stage = 0;
  /*! \brief The position of the block inside the pipelined loop body. */
  int order = 0;
  /*! \brief Whether the block is emitted as an asynchronous producer. */
  bool async = false;
};

79
80
/*! \brief Pipeline annotation of every block in the pipelined loop body. */
using PipelineInfo =
    std::unordered_map<Block, PipelineAnnotation, ObjectPtrHash,
                       ObjectPtrEqual>;
81
82

/*! \brief Stage range over which a buffer is accessed inside the pipeline. */
struct BufferAccessInfo {
  // Earliest stage that writes the buffer; -1 when no pipelined block writes it.
  int def{-1};
  // Latest stage that reads the buffer; -1 when no pipelined block reads it.
  int use{-1};
};

/*!
88
89
90
 * \brief Rewriter for the body of the software pipeline. This pass inserts
 * `floormod` to indices of the remapped buffer to select the version
 * corresponding to the pipeline stage.
91
92
 */
class PipelineBodyRewriter : public StmtExprMutator {
93
public:
94
95
96
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
97
98
99
100
101
102
   * \param buffer_remap The map from original buffer to the buffer with updated
   * shape for multi-versioning in the software pipeline. \param pipeline_loop
   * The original loop to be software pipelined. \param access_all_versions
   * Whether all versions the buffers in the software pipeline are accessed.
   * This will be used to update block access region. In the prologue and
   * epilogue of a two-stage software pipeline, only one version of these
103
   * buffers are accessed.
104
   */
105
106
107
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap,
                       For pipeline_loop, bool access_all_versions)
108
      : buffer_data_to_buffer_(buffer_data_to_buffer),
109
        buffer_remap_(buffer_remap), pipeline_loop_(std::move(pipeline_loop)),
110
        access_all_versions_(access_all_versions) {}
111

112
113
114
private:
  BufferRegion
  RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
115
116
117
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
118
119
120
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension to
      // full extent if access_all_versions == true
121
122
123
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
124
125
126
127
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
128
129
130
131
132
133
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

134
  PrimExpr RewriteBufferAccess(const Call &call,
135
                               const std::vector<int> &arg_indices) {
136
137
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
138
139
140
          [](PrimExpr a, PrimExpr b, Span span) {
            return mul(std::move(a), std::move(b), std::move(span));
          },
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
          make_const(DataType::Int(32), 1), input);
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      const Buffer &buffer =
          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
            old_index +
            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

166
167
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
168
169
170
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
171
172
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
173
174
      return RewritePipelineBufferRegion(buffer_region);
    });
175
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
176
177
      return RewritePipelineBufferRegion(buffer_region);
    });
178
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
179
180
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
181
    return block;
182
183
  }

184
  Stmt VisitStmt_(const BufferStoreNode *op) final {
185
186
187
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
188
      return store;
189
    }
190
191
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
192
    n->buffer = new_buffer;
193
194
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
195
    n->indices.insert(n->indices.begin(), version);
196
    return store;
197
198
  }

199
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
200
201
202
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
203
      return load;
204
    }
205
206
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
207
    n->buffer = new_buffer;
208
209
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
210
    n->indices.insert(n->indices.begin(), version);
211
    return load;
212
213
  }

214
  PrimExpr VisitExpr_(const CallNode *op) final {
215
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
216
217
218
219
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
220
221
222
223
224
225
226
227
228
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
};

/*!
229
230
 * \brief Rewriter for the software pipeline that rewrite a loop into a
 * pipelined one.
231
232
 */
class PipelineRewriter : public StmtExprMutator {
233
public:
234
235
236
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
                   const For &pipeline_loop, const PipelineInfo &pipeline_info)
237
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
238
239
        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
        pipeline_info_(pipeline_info) {}
240
241

  Stmt BuildPipeline() {
242
243
244
245
246
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
    // number of versions need to maintain for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos = GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
247
248
249
250
251
252
      int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
      }
    }
    ordered_stmts_.resize(pipeline_info_.size());
253
254
    for (const auto &[block, anno] : pipeline_info_) {
      ordered_stmts_.Set(anno.order, block);
255
256
    }

257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        state.producer_head = pipeline_loop_->min - 1;
        for (auto write_region : block->writes) {
          auto buffer = write_region->buffer;
          state.dst_buffers.insert(buffer.get());
          if (buffer_remap_.count(buffer))
            state.dst_buffers.insert(buffer_remap_[buffer].get());
        }
      }
    }
    std::unordered_set<int> consumed;
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        if (state.commit_groups.empty() || consumed.count(stage)) {
          state.commit_groups.push_back({});
        }
        state.commit_groups.back().push_back(pipeline_info_[block].order);
        consumed.erase(stage);
        for (auto write_region : block->writes) {
          auto buffer = buffer_remap_.count(write_region->buffer)
                            ? buffer_remap_[write_region->buffer]
                            : write_region->buffer;
          state.buffer_to_commit_group_[buffer.get()] =
              state.commit_groups.size() - 1;
        }
      }

      for (auto read_region : block->reads) {
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          if (producer_stage_id <= stage &&
              producer_state.writes(read_region->buffer)) {
            consumed.insert(producer_stage_id);
          }
        }
296
297
      }
    }
298
299
300
301
302
303
304
305
306
307

    // Step 2: Emit the pipeline prologue, body and epilogue.
    Stmt prologue = EmitImpl(pipeline_loop_->min,
                             pipeline_loop_->min + max_stage_, true, true);
    Stmt body =
        EmitImpl(pipeline_loop_->min + max_stage_,
                 pipeline_loop_->min + pipeline_loop_->extent, false, false);
    Stmt epilogue = EmitImpl(
        pipeline_loop_->min + pipeline_loop_->extent,
        pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);
308
309
310

    SeqStmt stmt = SeqStmt({prologue, body, epilogue});

311
312
    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
313
    Array<Buffer> alloc_buffers;
314
    for (const auto &alloc : pipeline_allocs_) {
315
316
317
318
319
320
321
322
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    return BlockRealize({}, Bool(true), block);
  }

323
private:
324
325
326
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
327
328
329
   * This method check the 'define' and 'use' stage of the buffers in the
   * software pipeline, which can be used to compute the number of versions
   * needed to maintain after rewriting.
330
331
332
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
333
334
335
336
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
337
338
339
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

340
      for (const BufferRegion &write : block->writes) {
341
342
343
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
344
        auto &info = infos.at(write->buffer);
345
346
347
348
349
350
351
        if (info.def == -1) {
          info.def = stage;
        } else {
          info.def = std::min(info.def, stage);
        }
      }

352
      for (const BufferRegion &read : block->reads) {
353
354
355
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
356
        auto &info = infos.at(read->buffer);
357
358
359
360
361
362
363
364
365
366
367
368
        info.use = std::max(info.use, stage);
      }
    }
    return infos;
  }

  /*!
   * \brief Check whether two regions have intersections.
   * \param region1 The first region.
   * \param region2 The second region.
   * \return Whether region1 and region2 have intersections.
   */
369
  bool MayConflict(const Region &region1, const Region &region2) {
370
371
372
373
374
375
376
377
378
379
380
381
382
383
    ICHECK(region1.size() == region2.size());
    for (size_t i = 0; i < region1.size(); i++) {
      Range dim1 = region1[i];
      Range dim2 = region2[i];
      auto int_set1 = arith::IntSet::FromRange(dim1);
      auto int_set2 = arith::IntSet::FromRange(dim2);
      if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
        return false;
      }
    }
    return true;
  }

  /*!
384
385
   * \brief Compute the number of versions need to maintain for buffer accessed
   * in the software pipeline.
386
   *
387
388
389
390
391
392
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions need to maintain during the software pipeline.
   * Annotation `attr::double_buffer_scope` is handled here which provides a way
   * to override the result of the analysis. Additional double buffering in the
   * software pipeline can be useful to eliminate synchronizations in GPU
   * devices.
393
394
395
396
397
   *
   * \param buffer The target buffer
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
398
399
  int ComputeBufferVersions(const Buffer &buffer,
                            const BufferAccessInfo &buffer_info) {
400
    if (buffer_info.def == -1) {
401
402
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
403
404
405
406
      return 1;
    }

    // `use - def + 1` is a upper bound of the needed versions
407
408
    // We optimize a few case where the number of versions can be smaller than
    // the upper bound
409
    int num_versions = buffer_info.use - buffer_info.def + 1;
410
    if (num_versions >= 2) {
411
412
413
414
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when these exists a reader block_i and a writer
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
415
      bool need_multi_version = false;
416
417
418
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;
419

420
421
422
        auto it1 = std::find_if(writer_block->writes.begin(),
                                writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
423
424
425
426
427
428
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

429
430
431
432
433
434
435
436
        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(
              reader_block->reads.begin(), reader_block->reads.end(),
              [&](const BufferRegion &buffer_region) {
                return buffer_region->buffer.same_as(buffer);
              });
437
438
439
          if (it2 == reader_block->reads.end()) {
            continue;
          }
440
441
          if (writer_info.order < reader_info.order &&
              writer_info.stage < reader_info.stage &&
442
443
444
445
446
447
448
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
        }
      }
      if (!need_multi_version) {
449
        num_versions--;
450
451
452
453
454
455
      }
    }
    return num_versions;
  }

  /*!
456
457
   * \brief Rewrite buffer allocation to keep multiple versions of original
   * buffer for pipelined accesses. \param buffer The buffer to be resized.
458
459
460
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
461
  Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
462
463
    ObjectPtr<BufferNode> new_buffer =
        tvm::ffi::make_object<BufferNode>(*(buffer.get()));
464
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
465
    if (!new_buffer->strides.empty()) {
466
467
468
469
470
471
472
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

473
474
  // Per-stage states that need to be tracked across pipeline prologue, body,
  // and epilogue.
475
476
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
477
478
479
480
481
482
483
    std::unordered_set<const BufferNode *> dst_buffers;
    // An imaginary index that the latest async operation associated with this
    // stage has written into. Only valid if all associated predicates are true,
    // so that we can count the number of async invocations exactly. When it is
    // valid, it is the "sum of extents of loops that have been executed" - 1,
    // e.g. for epilogue it is prologue extent + body extent - 1. This is only
    // needed to compute wait count for epilogue without async producers.
484
485
486
    PrimExpr producer_head;
    std::vector<std::vector<int>> commit_groups;
    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
487
488
489
    bool writes(const Buffer &buf) const {
      return dst_buffers.count(buf.get()) > 0;
    }
490
491
  };

492
493
  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
494
  struct AsyncStateLocal {
495
    struct PendingWait {
496
497
      // The index into a list of blocks, where async_wait_queue should be
      // attached at the beginning.
498
      int insert_before;
499
500
      // in_flight_count would be a more precise name, but the implementation
      // uses wait_count for brevity.
501
502
503
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
504
505
506
    };

    std::vector<PendingWait> pending_waits;
507

508
509
    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
510
511
512
513
514
515
    Optional<PrimExpr> producer_head;
  };

  /*! Structure holding intermediate information for pipeline loop rewriting. */
  struct RewrittenBlockInfo {
    int stage;
516
    int order;
517
518
519
520
521
522
    PrimExpr predicate;
    Block block;
    PrimExpr access_index;
    bool is_async;
  };

523
524
  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
                          std::map<int, AsyncStateLocal> *async_states_local) {
525
526
527
    for (size_t i = 0; i < new_blocks.size(); ++i) {
      int producer_stage_idx = -1;
      for (auto read_region : new_blocks[i].block->reads) {
528
529
530
        for (const auto &[stage, state] : async_states) {
          if (stage <= new_blocks[i].stage &&
              state.writes(read_region->buffer)) {
531
532
            // Found an earlier stage where read_region->buffer was
            // asynchronously written
533
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
534
                << "A dependency on multiple async stages is not supported";
535
            producer_stage_idx = stage;
536
537
538
          }
        }
      }
539
540
      if (producer_stage_idx == -1)
        continue;
541
      const auto &state = async_states[producer_stage_idx];
542
      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
543
544
545
546
547
548
549
550
551
552
553
      PrimExpr in_flight_cnt = 0;
      for (const auto &group : state.commit_groups) {
        PrimExpr consumer_head = new_blocks[i].access_index;
        PrimExpr producer_head;
        if (dep_local_state.producer_head.defined()) {
          producer_head = dep_local_state.producer_head.value();
          // if the group is after the wait point, minus by 1
          if (group.front() > new_blocks[i].order)
            producer_head -= 1;
        } else {
          producer_head = state.producer_head;
554
        }
555
        in_flight_cnt += producer_head - consumer_head;
556
557
      }

558
559
560
561
562
563
564
565
566
567
568
569
      // We can relax the in-flight-count by the number of independent commit.
      std::unordered_set<int> dependent_groups;
      for (const auto &read_region : new_blocks[i].block->reads) {
        if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
          dependent_groups.insert(
              state.buffer_to_commit_group_.at(read_region->buffer.get()));
      }
      for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
        if (dependent_groups.count(i) == 0)
          in_flight_cnt += 1;
        else
          break; // stop relaxing
570
      }
571
572
573
      in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
      dep_local_state.pending_waits.push_back(
          {static_cast<int>(i), in_flight_cnt});
574
575
576
    }
  }

577
578
  // Given pipelined blocks and async-related information, generate final loop
  // statements with async scopes (if any).
579
  Array<Stmt> CompletePipelineLoopStatements(
580
      const std::vector<RewrittenBlockInfo> &blocks,
581
      const std::map<int, AsyncStateLocal> &async_states_local) const {
582
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
583
    for (const auto &[stage_id, state] : async_states_local) {
584
585
586
587
588
589
590
      for (const auto &pw : state.pending_waits) {
        auto &block = new_blocks[pw.insert_before].block;
        BlockNode *n = block.CopyOnWrite();
        auto zero = make_zero(DataType::Int(32));
        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                    pw.wait_count, n->body));
591
      }
592
    }
593

594
595
596
597
598
    // mark the last async stmt as commit
    std::unordered_set<int> commit_group_indices;
    for (const auto &[stage_id, state] : async_states) {
      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
        commit_group_indices.insert(state.commit_groups[i].back());
599
600
601
602
603
      }
    }

    Array<Stmt> stmts;

604
605
606
607
608
609
610
    for (size_t i = 0; i < new_blocks.size(); i++) {
      Block block = new_blocks[i].block;
      if (commit_group_indices.count(new_blocks[i].order)) {
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope,
                                           new_blocks[i].stage, block->body);
        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
611
      }
612
      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
613
614
615
616
617
618
619
620
621
622
623
624
    }

    return stmts;
  }

  /*!
   * \brief Emit the pipeline loop in the given range.
   * \param start The start of the range
   * \param end The end of the range
   * \param unroll_loop Whether the loop should be unrolled.
   * \return The result loop.
   */
625
  Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop,
626
                bool need_bound_check) {
627
628
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;
629
630
631
    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };
632
633
634

    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
635
      new_loop_var = start; // use constants as the loop var for unit loops
636
637
638
639
640
641
642
643
644
645
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
    }

    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;

646
    for (const Block &block : ordered_stmts_) {
647
      int stage = pipeline_info_.at(block).stage;
648
649
      int order = pipeline_info_.at(block).order;
      PrimExpr inbound = Bool(true);
650
      PrimExpr skewed_loop_var = new_loop_var - stage;
651
652
653
654
      if (need_bound_check)
        inbound =
            analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
655
656
657
      if (analyzer_.CanProve(!inbound)) {
        continue;
      }
658
659
660
      Block new_block = Downcast<Block>(
          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                               pipeline_loop_, max_stage_ != 1)(block));
661
662
663
664

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
665
666
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
667
      PrimExpr normalized_access_index =
668
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
669

670
671
      // Adjust the block predicate and the body according to the final loop
      // bound
672
673
674
675
676
      //  [pipeline_loop_->min, extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }
677
678
      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
679

680
      if (pipeline_info_[block].async) {
681
        auto &local_state = async_states_local[stage];
682
        local_state.producer_head = normalized_access_index;
683
684
685
        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                           1, n->body);
686
687
      }

688
689
      new_blocks.push_back({stage, order, inbound, new_block,
                            normalized_access_index,
690
                            pipeline_info_[block].async});
691
    }
692

693
694
695
    PopulateWaitCounts(new_blocks, &async_states_local);

    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
696

697
698
    Stmt new_loop{nullptr};

699
    if (stmts.empty()) {
700
701
      return make_nop();
    }
702

703
704
    if (stmts.size() == 1) {
      new_loop = stmts[0];
705
    } else {
706
      new_loop = SeqStmt(stmts);
707
708
709
    }

    if (!is_unit_loop) {
710
711
712
713
714
715
716
717
718
      Map<String, Any> preserved_annotations;
      for (const auto &kv : pipeline_loop_->annotations) {
        const String &key = kv.first;
        if (kv.first != tir::attr::software_pipeline_stage &&
            kv.first != tir::attr::software_pipeline_order &&
            kv.first != tir::attr::software_pipeline_async_stages) {
          preserved_annotations.Set(key, kv.second);
        }
      }
719
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
720
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
721
                     std::move(new_loop), std::nullopt, preserved_annotations);
722
723
    }
    // Update producer heads in the global async states.
724
725
    for (const auto &[stage_id, state] : async_states_local) {
      async_states[stage_id].producer_head += extent;
726
727
    }

728
    return BlockRealize({}, Bool(true),
729
                        MakeBlock(new_loop, buffer_data_to_buffer_));
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
  }

  arith::Analyzer analyzer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Array<Buffer> pipeline_allocs_;
  For pipeline_loop_;
  PipelineInfo pipeline_info_;
  int max_stage_ = -1;
  Map<Buffer, Buffer> buffer_remap_;
  Array<Block> ordered_stmts_;
  std::map<int, AsyncStateGlobal> async_states;
};

/*!
 * \brief Build the dependency graph among a array of blocks.
 * \param[in] blocks The array of blocks.
746
747
748
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination. \param[out] dep_dst2src Optional, a map to store
 * dependency edges from the destination to the source.
749
 */
750
751
752
753
754
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
755
756
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;
757
758
759

  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
760
761
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
762
        for (const Block &writer : it->second) {
763
764
765
766
767
768
769
770
771
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
          if (dep_dst2src != nullptr) {
            (*dep_dst2src)[block].push_back(writer);
          }
        }
      }
    }
772
    for (const BufferRegion &write : block->writes) {
773
774
775
776
777
778
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
779
780
public:
  static Stmt Inject(const PrimFunc &func) {
781
782
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
783
784
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
785
786
787
788
789
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }

790
791
private:
  explicit PipelineInjector(Optional<String> global_symbol)
792
      : global_symbol_(std::move(global_symbol)) {}
793
794
795
796

  /*!
   * \brief Check the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
797
798
799
800
   * 2. Reordering of statements doesn't break buffer access dependencies.
   * Specifically, for dependency (e.g. read-after-write) from statement A to
   * statement B, it requires: case 1: stage(A) < stage(B) case 2: stage(A) ==
   * stage(B) and order(A) < order(B)
801
   */
802
803
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
804
805
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
806
807
808
809
    std::unordered_map<int, const Block *> order_to_block;
    std::unordered_map<const Block *, int> block_to_stage;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
810
811
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
812
813
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
814
815
816
      used_orders.insert(order);
    }

817
818
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
819
820
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

821
822
823
824
825
826
    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
827
828
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
829
830
            << " cannot depends on statement " << src << " in a later stage "
            << src_info.stage;
831
        if (src_info.stage == dst_info.stage) {
832
833
834
835
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
836
837
838
839
840
        }
      }
    }
  }

841
  Stmt VisitStmt_(const ForNode *op) final {
842
843
844
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
845
      return for_node;
846
    }
847
848
849
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be direct child of the for-loop. If the for-loop has BlockRealize as
    // its child, the pipeline body will be the child of the block.
850
851
    Stmt pipeline_body_root{nullptr};
    bool pipeline_body_from_block = false;
852
    Array<Buffer> pipeline_allocs;
853
854
855
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
856
857
858
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
859
      pipeline_body_root = block->body;
860
      pipeline_allocs = block->alloc_buffers;
861
      pipeline_body_from_block = true;
862
    } else {
863
864
865
866
867
868
      pipeline_body_root = for_node->body;
    }

    const SeqStmtNode *pipeline_body_seq = nullptr;
    std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
    auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
869
      Any node = attr->node;
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
      String attr_key = attr->attr_key;
      PrimExpr value = attr->value;
      Span span = attr->span;
      rewrap_fns.emplace_back(
          [node = std::move(node), attr_key = std::move(attr_key),
           value = std::move(value), span](Stmt body) -> Stmt {
            return AttrStmt(node, attr_key, value, body, span);
          });
    };
    {
      Stmt current = pipeline_body_root;
      while (true) {
        if (const auto *seq_stmt = current.as<SeqStmtNode>()) {
          pipeline_body_seq = seq_stmt;
          break;
        }
        if (const auto *if_then_else = current.as<IfThenElseNode>()) {
          ICHECK(!if_then_else->else_case.defined())
              << "InjectSoftwarePipeline: Can't handle the body of the loop "
                 "because the IfThenElse node has an else branch";
          PrimExpr condition = if_then_else->condition;
          Span span = if_then_else->span;
          rewrap_fns.emplace_back(
              [condition = std::move(condition), span](Stmt body) -> Stmt {
                return IfThenElse(condition, body, Stmt(), span);
              });
          current = if_then_else->then_case;
          continue;
        }
        if (const auto *let_stmt = current.as<LetStmtNode>()) {
          Var var = let_stmt->var;
          PrimExpr value = let_stmt->value;
          Span span = let_stmt->span;
          rewrap_fns.emplace_back([var = std::move(var),
                                   value = std::move(value),
                                   span](Stmt body) -> Stmt {
            return LetStmt(var, value, body, span);
          });
          current = let_stmt->body;
          continue;
        }
        if (const auto *attr = current.as<AttrStmtNode>()) {
          append_attr_wrapper(attr);
          current = attr->body;
          continue;
        }
        LOG(FATAL) << "ValueError: The body of the software pipeline should be "
                   << "SeqStmt, got " << current->GetTypeKey();
      }
919
    }
920
    ICHECK(pipeline_body_seq != nullptr);
921

922
923
    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
924
    PipelineInfo pipeline_info;
925
    Array<Block> original_order; // pipeline body blocks in the original order
926

927
    auto f_add_child = [&](const Stmt &child) {
928
929
930
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
931
932
      const Stmt &child = pipeline_body_seq->seq[i];
      const auto *nested_block_realize = child.as<BlockRealizeNode>();
933
934
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
935
936
937
938
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
939
940
941
942
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
      }
943
      f_add_child(child);
944
945
    }

946
947
948
949
    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
950
951
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
952
953
954
955
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
956
957
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
958
959
960
961
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";
962
963

    std::unordered_set<int> pipeline_async_stages;
964
965
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
966
      for (auto s : Downcast<Array<Integer>>(annot.value())) {
967
968
969
970
971
972
        pipeline_async_stages.insert(s->value);
      }
    }

    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
973
974
975
976
977
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{
          stage,
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
978
979
980
981
982
983
      pipeline_info.emplace(original_order[i], stage_order);
    }

    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
984
    Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
985
                                     tvm::ffi::GetRef<For>(op), pipeline_info)
986
                        .BuildPipeline();
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
    auto apply_wrappers = [&](Stmt stmt) {
      for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
        stmt = (*it)(stmt);
      }
      return stmt;
    };
    if (!rewrap_fns.empty()) {
      if (pipeline_body_from_block) {
        BlockRealize pipeline_realize = Downcast<BlockRealize>(pipeline);
        Block pipeline_block = pipeline_realize->block;
        {
          BlockNode *block_node = pipeline_block.CopyOnWrite();
          block_node->body = apply_wrappers(block_node->body);
        }
        pipeline = BlockRealize(pipeline_realize->iter_values,
                                pipeline_realize->predicate, pipeline_block,
                                pipeline_realize->span);
      } else {
        pipeline = apply_wrappers(pipeline);
      }
    }
1008

1009
1010
1011
    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
1012
1013
1014
1015
1016
1017
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }

1018
1019
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
1020
1021
1022
1023
1024
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

1025
1026
1027
1028
1029
1030
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    BlockNode *n = block.CopyOnWrite();
    n->reads = access[0];
    n->writes = access[1];

1031
    for (const auto &buffer : op->alloc_buffers) {
1032
1033
      buffer_data_to_buffer_.erase(buffer->data);
    }
1034
    return block;
1035
1036
  }

1037
  bool HasPipelineAnnotation(const ForNode *op) const {
1038
1039
1040
1041
1042
1043
1044
1045
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
    bool has_order = it2 != op->annotations.end();
    if (has_stage && has_order) {
      return true;
    }
    if (has_stage) {
1046
      LOG(FATAL)
1047
          << "ValueError: Stage of the software pipeline is not defined.";
1048
1049
    }
    if (has_order) {
1050
      LOG(FATAL)
1051
          << "ValueError: Order of the software pipeline is not defined.";
1052
1053
1054
1055
1056
1057
1058
    }
    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Optional<String> global_symbol_;
};
1059
1060
} // namespace software_pipeline

1061
/*!
 * \brief Create the pass that rewrites software-pipeline-annotated loops.
 *
 * Each PrimFunc body is rewritten by software_pipeline::PipelineInjector and
 * then converted back to SSA form with ConvertSSA.
 *
 * \return The IR transform pass (registered as "tl.InjectSoftwarePipeline").
 */
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    auto *func_node = f.CopyOnWrite();
    Stmt injected = software_pipeline::PipelineInjector::Inject(f);
    func_node->body = ConvertSSA(std::move(injected));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
}

1076
TVM_FFI_STATIC_INIT_BLOCK() {
1077
1078
1079
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
                        InjectSoftwarePipeline);
1080
}
1081

1082
1083
} // namespace tl
} // namespace tvm