inject_pipeline.cc 38.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_pipeline.cc
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 */
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

#include <unordered_set>
30
#include <utility>
31
32
33
34
35
36
37
38
39

#include "support/utils.h"
#include "tir/schedule/utils.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
using namespace tir;

40
41
namespace software_pipeline {

42
43
44
/*!
 * \brief Create a block and infer the access region with the given body.
 *
45
46
47
 * The result is a opaque block that doesn't contain any block iter vars. In
 * case the body is a block realize without predicate, it is unnecessary to
 * create a new block, the block of the block realize will be returned.
48
49
50
51
52
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
53
54
55
Block MakeBlock(const Stmt &body,
                const Map<Var, Buffer> &buffer_data_to_buffer) {
  if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
56
57
58
59
60
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
61
62
63
64
65
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ body);
  Array<Array<BufferRegion>> access =
      GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode *n = block.CopyOnWrite();
66
67
68
69
70
71
72
73
74
75
76
77
  n->reads = access[0];
  n->writes = access[1];
  return block;
}

/*! Structure that represents the provided annotation per block or loop. */
struct PipelineAnnotation {
  // Pipeline stage of the block; the loop variable seen by the block is
  // skewed by this amount when the pipeline is emitted.
  int stage;
  // Position of the block in the reordered statement list of the pipeline.
  int order;
  // Whether the block is emitted inside an async scope.
  bool async;
};

78
79
// Per-block pipeline annotations, keyed by block object identity
// (ObjectPtrHash / ObjectPtrEqual).
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
                                        ObjectPtrHash, ObjectPtrEqual>;
80
81

/*! Define/use stages of one buffer inside the pipeline; -1 means "unset". */
struct BufferAccessInfo {
  // The defining (earliest writing) stage of the buffer.
  int def{-1};
  // The last stage that uses (reads) the buffer.
  int use{-1};
};

/*!
87
88
89
 * \brief Rewriter for the body of the software pipeline. This pass inserts
 * `floormod` to indices of the remapped buffer to select the version
 * corresponding to the pipeline stage.
90
91
 */
class PipelineBodyRewriter : public StmtExprMutator {
92
public:
93
94
95
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
96
97
98
99
100
101
   * \param buffer_remap The map from original buffer to the buffer with updated
   * shape for multi-versioning in the software pipeline. \param pipeline_loop
   * The original loop to be software pipelined. \param access_all_versions
   * Whether all versions the buffers in the software pipeline are accessed.
   * This will be used to update block access region. In the prologue and
   * epilogue of a two-stage software pipeline, only one version of these
102
   * buffers are accessed.
103
   */
104
105
106
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap,
                       For pipeline_loop, bool access_all_versions)
107
      : buffer_data_to_buffer_(buffer_data_to_buffer),
108
        buffer_remap_(buffer_remap), pipeline_loop_(std::move(pipeline_loop)),
109
        access_all_versions_(access_all_versions) {}
110

111
112
113
private:
  BufferRegion
  RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
114
115
116
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
117
118
119
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension to
      // full extent if access_all_versions == true
120
121
122
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
123
124
125
126
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
127
128
129
130
131
132
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

133
  PrimExpr RewriteBufferAccess(const Call &call,
134
                               const std::vector<int> &arg_indices) {
135
136
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
137
138
139
          [](PrimExpr a, PrimExpr b, Span span) {
            return mul(std::move(a), std::move(b), std::move(span));
          },
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
          make_const(DataType::Int(32), 1), input);
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      const Buffer &buffer =
          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
            old_index +
            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

165
166
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
167
168
169
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
170
171
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
172
173
      return RewritePipelineBufferRegion(buffer_region);
    });
174
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
175
176
      return RewritePipelineBufferRegion(buffer_region);
    });
177
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
178
179
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
180
    return block;
181
182
  }

183
  Stmt VisitStmt_(const BufferStoreNode *op) final {
184
185
186
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
187
      return store;
188
    }
189
190
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
191
    n->buffer = new_buffer;
192
193
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
194
    n->indices.insert(n->indices.begin(), version);
195
    return store;
196
197
  }

198
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
199
200
201
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
202
      return load;
203
    }
204
205
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
206
    n->buffer = new_buffer;
207
208
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
209
    n->indices.insert(n->indices.begin(), version);
210
    return load;
211
212
  }

213
  PrimExpr VisitExpr_(const CallNode *op) final {
214
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
215
216
217
218
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
219
220
221
222
223
224
225
226
227
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
};

/*!
228
229
 * \brief Rewriter for the software pipeline that rewrite a loop into a
 * pipelined one.
230
231
 */
class PipelineRewriter : public StmtExprMutator {
232
public:
233
234
235
  /*!
   * \brief Constructor of PipelineRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
   * \param pipeline_allocs The buffers allocated for the pipeline; these are
   * the candidates for multi-versioning.
   * \param pipeline_loop The original loop to be software pipelined.
   * \param pipeline_info The stage/order/async annotation of every block in
   * the pipeline body.
   */
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
                   const For &pipeline_loop, const PipelineInfo &pipeline_info)
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
        pipeline_allocs_(pipeline_allocs),
        pipeline_loop_(pipeline_loop),
        pipeline_info_(pipeline_info) {}
239
240

  /*!
   * \brief Rewrite the annotated loop into prologue, steady-state body and
   * epilogue.
   * \return A BlockRealize whose block allocates the (possibly
   * multi-versioned) pipeline buffers and contains the three emitted parts.
   */
  Stmt BuildPipeline() {
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
    // number of versions need to maintain for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos = GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
      // A buffer that is allocated for the pipeline but not touched by any
      // annotated block has no entry in `infos`. Treat it as "never defined"
      // (one version) instead of failing with std::out_of_range.
      auto info_it = infos.find(buffer);
      const BufferAccessInfo access =
          info_it != infos.end() ? info_it->second : BufferAccessInfo{};
      int num_versions = ComputeBufferVersions(buffer, access);
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
      }
    }
    // Lay the blocks out by their annotated order.
    ordered_stmts_.resize(pipeline_info_.size());
    for (const auto &[block, anno] : pipeline_info_) {
      ordered_stmts_.Set(anno.order, block);
    }

    // Record, per async stage, the buffers it writes (both the original and
    // the remapped buffer) and initialize its producer head.
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        state.producer_head = pipeline_loop_->min - 1;
        for (auto write_region : block->writes) {
          auto buffer = write_region->buffer;
          state.dst_buffers.insert(buffer.get());
          if (buffer_remap_.count(buffer))
            state.dst_buffers.insert(buffer_remap_[buffer].get());
        }
      }
    }
    // Partition the async blocks of each stage into commit groups: a new
    // group starts whenever the output of the stage has been read since the
    // previous async block of that stage.
    std::unordered_set<int> consumed;
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        if (state.commit_groups.empty() || consumed.count(stage)) {
          state.commit_groups.push_back({});
        }
        state.commit_groups.back().push_back(pipeline_info_[block].order);
        consumed.erase(stage);
        for (auto write_region : block->writes) {
          auto buffer = buffer_remap_.count(write_region->buffer)
                            ? buffer_remap_[write_region->buffer]
                            : write_region->buffer;
          state.buffer_to_commit_group_[buffer.get()] =
              state.commit_groups.size() - 1;
        }
      }

      for (auto read_region : block->reads) {
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          if (producer_stage_id <= stage &&
              producer_state.writes(read_region->buffer)) {
            consumed.insert(producer_stage_id);
          }
        }
      }
    }

    // Step 2: Emit the pipeline prologue, body and epilogue.
    Stmt prologue = EmitImpl(pipeline_loop_->min,
                             pipeline_loop_->min + max_stage_, true, true);
    Stmt body =
        EmitImpl(pipeline_loop_->min + max_stage_,
                 pipeline_loop_->min + pipeline_loop_->extent, false, false);
    Stmt epilogue = EmitImpl(
        pipeline_loop_->min + pipeline_loop_->extent,
        pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);

    SeqStmt stmt = SeqStmt({prologue, body, epilogue});

    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
    Array<Buffer> alloc_buffers;
    for (const auto &alloc : pipeline_allocs_) {
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    return BlockRealize({}, Bool(true), block);
  }

322
private:
323
324
325
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
326
327
328
   * This method check the 'define' and 'use' stage of the buffers in the
   * software pipeline, which can be used to compute the number of versions
   * needed to maintain after rewriting.
329
330
331
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
332
333
334
335
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
336
337
338
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

339
      for (const BufferRegion &write : block->writes) {
340
341
342
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
343
        auto &info = infos.at(write->buffer);
344
345
346
347
348
349
350
        if (info.def == -1) {
          info.def = stage;
        } else {
          info.def = std::min(info.def, stage);
        }
      }

351
      for (const BufferRegion &read : block->reads) {
352
353
354
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
355
        auto &info = infos.at(read->buffer);
356
357
358
359
360
361
362
363
364
365
366
367
        info.use = std::max(info.use, stage);
      }
    }
    return infos;
  }

  /*!
   * \brief Check whether two regions have intersections.
   * \param region1 The first region.
   * \param region2 The second region.
   * \return Whether region1 and region2 have intersections.
   */
368
  bool MayConflict(const Region &region1, const Region &region2) {
369
370
371
372
373
374
375
376
377
378
379
380
381
382
    ICHECK(region1.size() == region2.size());
    for (size_t i = 0; i < region1.size(); i++) {
      Range dim1 = region1[i];
      Range dim2 = region2[i];
      auto int_set1 = arith::IntSet::FromRange(dim1);
      auto int_set2 = arith::IntSet::FromRange(dim2);
      if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
        return false;
      }
    }
    return true;
  }

  /*!
383
384
   * \brief Compute the number of versions need to maintain for buffer accessed
   * in the software pipeline.
385
   *
386
387
388
389
390
391
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions need to maintain during the software pipeline.
   * Annotation `attr::double_buffer_scope` is handled here which provides a way
   * to override the result of the analysis. Additional double buffering in the
   * software pipeline can be useful to eliminate synchronizations in GPU
   * devices.
392
393
394
395
396
   *
   * \param buffer The target buffer
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
397
398
  int ComputeBufferVersions(const Buffer &buffer,
                            const BufferAccessInfo &buffer_info) {
399
    if (buffer_info.def == -1) {
400
401
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
402
403
404
405
      return 1;
    }

    // `use - def + 1` is a upper bound of the needed versions
406
407
    // We optimize a few case where the number of versions can be smaller than
    // the upper bound
408
    int num_versions = buffer_info.use - buffer_info.def + 1;
409
    if (num_versions >= 2) {
410
411
412
413
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when these exists a reader block_i and a writer
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
414
      bool need_multi_version = false;
415
416
417
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;
418

419
420
421
        auto it1 = std::find_if(writer_block->writes.begin(),
                                writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
422
423
424
425
426
427
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

428
429
430
431
432
433
434
435
        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(
              reader_block->reads.begin(), reader_block->reads.end(),
              [&](const BufferRegion &buffer_region) {
                return buffer_region->buffer.same_as(buffer);
              });
436
437
438
          if (it2 == reader_block->reads.end()) {
            continue;
          }
439
440
          if (writer_info.order < reader_info.order &&
              writer_info.stage < reader_info.stage &&
441
442
443
444
445
446
447
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
        }
      }
      if (!need_multi_version) {
448
        num_versions--;
449
450
451
452
453
454
      }
    }
    return num_versions;
  }

  /*!
455
456
   * \brief Rewrite buffer allocation to keep multiple versions of original
   * buffer for pipelined accesses. \param buffer The buffer to be resized.
457
458
459
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
460
  Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
461
462
    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
463
    if (!new_buffer->strides.empty()) {
464
465
466
467
468
469
470
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

471
472
  // Per-stage states that need to be tracked across pipeline prologue, body,
  // and epilogue.
473
474
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
475
476
477
478
479
480
481
    std::unordered_set<const BufferNode *> dst_buffers;
    // An imaginary index that the latest async operation associated with this
    // stage has written into. Only valid if all associated predicates are true,
    // so that we can count the number of async invocations exactly. When it is
    // valid, it is the "sum of extents of loops that have been executed" - 1,
    // e.g. for epilogue it is prologue extent + body extent - 1. This is only
    // needed to compute wait count for epilogue without async producers.
482
483
484
    PrimExpr producer_head;
    std::vector<std::vector<int>> commit_groups;
    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
485
486
487
    bool writes(const Buffer &buf) const {
      return dst_buffers.count(buf.get()) > 0;
    }
488
489
  };

490
491
  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
492
  struct AsyncStateLocal {
493
    struct PendingWait {
494
495
      // The index into a list of blocks, where async_wait_queue should be
      // attached at the beginning.
496
      int insert_before;
497
498
      // in_flight_count would be a more precise name, but the implementation
      // uses wait_count for brevity.
499
500
501
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
502
503
504
    };

    std::vector<PendingWait> pending_waits;
505

506
507
    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
508
509
510
511
512
513
    Optional<PrimExpr> producer_head;
  };

  /*! Structure holding intermediate information for pipeline loop rewriting. */
  struct RewrittenBlockInfo {
    // Pipeline stage of the block.
    int stage;
514
    // Annotated order of the block within the pipeline body.
    int order;
515
516
517
518
519
520
    // Predicate guarding the block (the in-bound check of the skewed loop
    // variable).
    PrimExpr predicate;
    // The block after buffer remapping and loop variable substitution.
    Block block;
    // Normalized access index of the block: the producer head for async
    // producers, the consumer head for readers of async-written buffers.
    PrimExpr access_index;
    // Whether the block was annotated as async.
    bool is_async;
  };

521
522
  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
                          std::map<int, AsyncStateLocal> *async_states_local) {
523
524
525
    for (size_t i = 0; i < new_blocks.size(); ++i) {
      int producer_stage_idx = -1;
      for (auto read_region : new_blocks[i].block->reads) {
526
527
528
        for (const auto &[stage, state] : async_states) {
          if (stage <= new_blocks[i].stage &&
              state.writes(read_region->buffer)) {
529
530
            // Found an earlier stage where read_region->buffer was
            // asynchronously written
531
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
532
                << "A dependency on multiple async stages is not supported";
533
            producer_stage_idx = stage;
534
535
536
          }
        }
      }
537
538
      if (producer_stage_idx == -1)
        continue;
539
      const auto &state = async_states[producer_stage_idx];
540
      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
541
542
543
544
545
546
547
548
549
550
551
      PrimExpr in_flight_cnt = 0;
      for (const auto &group : state.commit_groups) {
        PrimExpr consumer_head = new_blocks[i].access_index;
        PrimExpr producer_head;
        if (dep_local_state.producer_head.defined()) {
          producer_head = dep_local_state.producer_head.value();
          // if the group is after the wait point, minus by 1
          if (group.front() > new_blocks[i].order)
            producer_head -= 1;
        } else {
          producer_head = state.producer_head;
552
        }
553
        in_flight_cnt += producer_head - consumer_head;
554
555
      }

556
557
558
559
560
561
562
563
564
565
566
567
      // We can relax the in-flight-count by the number of independent commit.
      std::unordered_set<int> dependent_groups;
      for (const auto &read_region : new_blocks[i].block->reads) {
        if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
          dependent_groups.insert(
              state.buffer_to_commit_group_.at(read_region->buffer.get()));
      }
      for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
        if (dependent_groups.count(i) == 0)
          in_flight_cnt += 1;
        else
          break; // stop relaxing
568
      }
569
570
571
      in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
      dep_local_state.pending_waits.push_back(
          {static_cast<int>(i), in_flight_cnt});
572
573
574
    }
  }

575
576
  // Given pipelined blocks and async-related information, generate final loop
  // statements with async scopes (if any).
577
  Array<Stmt> CompletePipelineLoopStatements(
578
      const std::vector<RewrittenBlockInfo> &blocks,
579
      const std::map<int, AsyncStateLocal> &async_states_local) const {
580
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
581
    for (const auto &[stage_id, state] : async_states_local) {
582
583
584
585
586
587
588
      for (const auto &pw : state.pending_waits) {
        auto &block = new_blocks[pw.insert_before].block;
        BlockNode *n = block.CopyOnWrite();
        auto zero = make_zero(DataType::Int(32));
        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                    pw.wait_count, n->body));
589
      }
590
    }
591

592
593
594
595
596
    // mark the last async stmt as commit
    std::unordered_set<int> commit_group_indices;
    for (const auto &[stage_id, state] : async_states) {
      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
        commit_group_indices.insert(state.commit_groups[i].back());
597
598
599
600
601
      }
    }

    Array<Stmt> stmts;

602
603
604
605
606
607
608
    for (size_t i = 0; i < new_blocks.size(); i++) {
      Block block = new_blocks[i].block;
      if (commit_group_indices.count(new_blocks[i].order)) {
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope,
                                           new_blocks[i].stage, block->body);
        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
609
      }
610
      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
611
612
613
614
615
616
617
618
619
620
621
622
    }

    return stmts;
  }

  /*!
   * \brief Emit the pipeline loop in the given range.
   * \param start The start of the range
   * \param end The end of the range
   * \param unroll_loop Whether the loop should be unrolled.
   * \return The result loop.
   */
623
  Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop,
624
                bool need_bound_check) {
625
626
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;
627
628
629
    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };
630
631
632

    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
633
      new_loop_var = start; // use constants as the loop var for unit loops
634
635
636
637
638
639
640
641
642
643
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
    }

    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;

644
    for (const Block &block : ordered_stmts_) {
645
      int stage = pipeline_info_.at(block).stage;
646
647
      int order = pipeline_info_.at(block).order;
      PrimExpr inbound = Bool(true);
648
      PrimExpr skewed_loop_var = new_loop_var - stage;
649
650
651
652
      if (need_bound_check)
        inbound =
            analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
653
654
655
      if (analyzer_.CanProve(!inbound)) {
        continue;
      }
656
657
658
      Block new_block = Downcast<Block>(
          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                               pipeline_loop_, max_stage_ != 1)(block));
659
660
661
662

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
663
664
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
665
      PrimExpr normalized_access_index =
666
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
667

668
669
      // Adjust the block predicate and the body according to the final loop
      // bound
670
671
672
673
674
      //  [pipeline_loop_->min, extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }
675
676
      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
677

678
      if (pipeline_info_[block].async) {
679
        auto &local_state = async_states_local[stage];
680
        local_state.producer_head = normalized_access_index;
681
682
683
        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                           1, n->body);
684
685
      }

686
687
      new_blocks.push_back({stage, order, inbound, new_block,
                            normalized_access_index,
688
                            pipeline_info_[block].async});
689
    }
690

691
692
693
    PopulateWaitCounts(new_blocks, &async_states_local);

    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
694

695
696
    Stmt new_loop{nullptr};

697
    if (stmts.empty()) {
698
699
      return make_nop();
    }
700

701
702
    if (stmts.size() == 1) {
      new_loop = stmts[0];
703
    } else {
704
      new_loop = SeqStmt(stmts);
705
706
707
    }

    if (!is_unit_loop) {
708
709
710
711
712
713
714
715
716
      Map<String, Any> preserved_annotations;
      for (const auto &kv : pipeline_loop_->annotations) {
        const String &key = kv.first;
        if (kv.first != tir::attr::software_pipeline_stage &&
            kv.first != tir::attr::software_pipeline_order &&
            kv.first != tir::attr::software_pipeline_async_stages) {
          preserved_annotations.Set(key, kv.second);
        }
      }
717
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
718
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
719
                     std::move(new_loop), std::nullopt, preserved_annotations);
720
721
    }
    // Update producer heads in the global async states.
722
723
    for (const auto &[stage_id, state] : async_states_local) {
      async_states[stage_id].producer_head += extent;
724
725
    }

726
    return BlockRealize({}, Bool(true),
727
                        MakeBlock(new_loop, buffer_data_to_buffer_));
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
  }

  // Analyzer used to simplify and prove bounds; loop variables of the emitted
  // loops are bound to their ranges in it.
  arith::Analyzer analyzer_;
  // Map from buffer data var to buffer.
  Map<Var, Buffer> buffer_data_to_buffer_;
  // Buffers allocated for the pipeline (candidates for multi-versioning).
  Array<Buffer> pipeline_allocs_;
  // The original loop being software pipelined.
  For pipeline_loop_;
  // Stage/order/async annotation of every block in the pipeline body.
  PipelineInfo pipeline_info_;
  // Largest stage among the pipeline blocks; set by GetBufferAccessInfo.
  int max_stage_ = -1;
  // Map from original buffer to its multi-versioned replacement.
  Map<Buffer, Buffer> buffer_remap_;
  // Pipeline blocks laid out by their annotated order.
  Array<Block> ordered_stmts_;
  // Global async state of each async stage, keyed by stage index.
  std::map<int, AsyncStateGlobal> async_states;
};

/*!
 * \brief Build the dependency graph among a array of blocks.
 * \param[in] blocks The array of blocks.
744
745
746
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination. \param[out] dep_dst2src Optional, a map to store
 * dependency edges from the destination to the source.
747
 */
748
749
750
751
752
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
753
754
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;
755
756
757

  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
758
759
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
760
        for (const Block &writer : it->second) {
761
762
763
764
765
766
767
768
769
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
          if (dep_dst2src != nullptr) {
            (*dep_dst2src)[block].push_back(writer);
          }
        }
      }
    }
770
    for (const BufferRegion &write : block->writes) {
771
772
773
774
775
776
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
777
778
public:
  static Stmt Inject(const PrimFunc &func) {
779
780
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
781
782
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
783
784
785
786
787
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }

788
789
private:
  explicit PipelineInjector(Optional<String> global_symbol)
790
      : global_symbol_(std::move(global_symbol)) {}
791
792
793
794

  /*!
   * \brief Check the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
795
796
797
798
   * 2. Reordering of statements doesn't break buffer access dependencies.
   * Specifically, for dependency (e.g. read-after-write) from statement A to
   * statement B, it requires: case 1: stage(A) < stage(B) case 2: stage(A) ==
   * stage(B) and order(A) < order(B)
799
   */
800
801
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
802
803
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
804
805
806
807
    std::unordered_map<int, const Block *> order_to_block;
    std::unordered_map<const Block *, int> block_to_stage;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
808
809
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
810
811
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
812
813
814
      used_orders.insert(order);
    }

815
816
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
817
818
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

819
820
821
822
823
824
    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
825
826
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
827
828
            << " cannot depends on statement " << src << " in a later stage "
            << src_info.stage;
829
        if (src_info.stage == dst_info.stage) {
830
831
832
833
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
834
835
836
837
838
        }
      }
    }
  }

839
  Stmt VisitStmt_(const ForNode *op) final {
840
841
842
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
843
      return for_node;
844
    }
845
846
847
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be direct child of the for-loop. If the for-loop has BlockRealize as
    // its child, the pipeline body will be the child of the block.
848
849
    Stmt pipeline_body{nullptr};
    Array<Buffer> pipeline_allocs;
850
851
852
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
853
854
855
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
856
      pipeline_body = block->body;
857
858
859
860
861
      pipeline_allocs = block->alloc_buffers;
    } else {
      pipeline_body = for_node->body;
    }

862
863
864
865
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << pipeline_body->GetTypeKey();
866

867
868
    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
869
    PipelineInfo pipeline_info;
870
    Array<Block> original_order; // pipeline body blocks in the original order
871

872
    auto f_add_child = [&](const Stmt &child) {
873
874
875
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
876
877
      const auto *nested_block_realize =
          pipeline_body_seq->seq[i].as<BlockRealizeNode>();
878
879
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
880
881
882
883
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
884
885
886
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
887
        const auto *nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
888
889
890
891
892
893
894
895
        for (size_t j = 0; j < nested_seq->seq.size(); j++) {
          f_add_child(nested_seq->seq[j]);
        }
      } else {
        f_add_child(pipeline_body_seq->seq[i]);
      }
    }

896
897
898
899
    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
900
901
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
902
903
904
905
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
906
907
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
908
909
910
911
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";
912
913

    std::unordered_set<int> pipeline_async_stages;
914
915
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
916
      for (auto s : Downcast<Array<Integer>>(annot.value())) {
917
918
919
920
921
922
        pipeline_async_stages.insert(s->value);
      }
    }

    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
923
924
925
926
927
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{
          stage,
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
928
929
930
931
932
933
      pipeline_info.emplace(original_order[i], stage_order);
    }

    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
934
935
936
    Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
                                     GetRef<For>(op), pipeline_info)
                        .BuildPipeline();
937

938
939
940
    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
941
942
943
944
945
946
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }

947
948
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
949
950
951
952
953
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

954
955
956
957
958
959
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    BlockNode *n = block.CopyOnWrite();
    n->reads = access[0];
    n->writes = access[1];

960
    for (const auto &buffer : op->alloc_buffers) {
961
962
      buffer_data_to_buffer_.erase(buffer->data);
    }
963
    return block;
964
965
  }

966
  bool HasPipelineAnnotation(const ForNode *op) const {
967
968
969
970
971
972
973
974
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
    bool has_order = it2 != op->annotations.end();
    if (has_stage && has_order) {
      return true;
    }
    if (has_stage) {
975
      LOG(FATAL)
976
          << "ValueError: Stage of the software pipeline is not defined.";
977
978
    }
    if (has_order) {
979
      LOG(FATAL)
980
          << "ValueError: Order of the software pipeline is not defined.";
981
982
983
984
985
986
987
    }
    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Optional<String> global_symbol_;
};
988
989
} // namespace software_pipeline

990
/*!
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 * \return The IR transform pass.
 */
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto rewrite_func = [=](PrimFunc f, const IRModule &m,
                          const PassContext &ctx) {
    auto *node = f.CopyOnWrite();
    // Inject the pipelines first, then convert the resulting body to SSA form.
    Stmt injected = software_pipeline::PipelineInjector::Inject(f);
    node->body = ConvertSSA(std::move(injected));
    return f;
  };
  return CreatePrimFuncPass(rewrite_func, 0, "tl.InjectSoftwarePipeline", {});
}

1005
1006
1007
1008
1009
// Register the pass constructor as a global FFI definition under the name
// "tl.transform.InjectSoftwarePipeline".
// NOTE(review): presumably this is the name the frontend looks up — confirm
// against the binding side.
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
                        InjectSoftwarePipeline);
});
1010

1011
1012
} // namespace tl
} // namespace tvm