/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_pipeline.cc
 * \brief Transform annotated loops into software-pipelined loops that overlap
 * producers and consumers.
 */
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

#include <unordered_set>

#include "support/utils.h"
#include "tir/schedule/utils.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
using namespace tir;

namespace software_pipeline {

/*!
 * \brief Create a block and infer the access region with the given body.
 *
44
45
46
 * The result is a opaque block that doesn't contain any block iter vars. In
 * case the body is a block realize without predicate, it is unnecessary to
 * create a new block, the block of the block realize will be returned.
47
48
49
50
51
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
52
53
54
Block MakeBlock(const Stmt &body,
                const Map<Var, Buffer> &buffer_data_to_buffer) {
  if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
55
56
57
58
59
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
60
61
62
63
64
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ body);
  Array<Array<BufferRegion>> access =
      GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode *n = block.CopyOnWrite();
65
66
67
68
69
70
71
72
73
74
75
76
  n->reads = access[0];
  n->writes = access[1];
  return block;
}

/*!
 * \brief Structure that represents the provided annotation per block or loop.
 *
 * `stage` is the pipeline stage the block runs in, `order` is its position
 * in the emitted (reordered) body, and `async` marks blocks that are emitted
 * as asynchronous producers.
 */
struct PipelineAnnotation {
  int stage = 0;
  int order = 0;
  bool async = false;
};

/*! \brief Pipeline annotation of each block in the loop body, keyed by block. */
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
                                        ObjectPtrHash, ObjectPtrEqual>;

/*!
 * \brief Pipeline stages at which a buffer is defined and used.
 *
 * A value of -1 means the buffer is not written (resp. read) by any block of
 * the pipeline.
 */
struct BufferAccessInfo {
  int def = -1; // the defining (earliest writing) stage of the buffer
  int use = -1; // the last using (latest reading) stage of the buffer
};

/*!
86
87
88
 * \brief Rewriter for the body of the software pipeline. This pass inserts
 * `floormod` to indices of the remapped buffer to select the version
 * corresponding to the pipeline stage.
89
90
 */
class PipelineBodyRewriter : public StmtExprMutator {
91
public:
92
93
94
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
95
96
97
98
99
100
   * \param buffer_remap The map from original buffer to the buffer with updated
   * shape for multi-versioning in the software pipeline. \param pipeline_loop
   * The original loop to be software pipelined. \param access_all_versions
   * Whether all versions the buffers in the software pipeline are accessed.
   * This will be used to update block access region. In the prologue and
   * epilogue of a two-stage software pipeline, only one version of these
101
   * buffers are accessed.
102
   */
103
104
105
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap,
                       For pipeline_loop, bool access_all_versions)
106
      : buffer_data_to_buffer_(buffer_data_to_buffer),
107
        buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop),
108
        access_all_versions_(access_all_versions) {}
109

110
111
112
private:
  BufferRegion
  RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
113
114
115
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
116
117
118
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension to
      // full extent if access_all_versions == true
119
120
121
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
122
123
124
125
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
126
127
128
129
130
131
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
  PrimExpr RewriteBufferAccess(const Call &call,
                               const std::vector<int> arg_indices) {
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
          [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
          make_const(DataType::Int(32), 1), input);
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      const Buffer &buffer =
          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
            old_index +
            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

162
163
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
164
165
166
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
167
168
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
169
170
      return RewritePipelineBufferRegion(buffer_region);
    });
171
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
172
173
      return RewritePipelineBufferRegion(buffer_region);
    });
174
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
175
176
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
177
    return block;
178
179
  }

180
  Stmt VisitStmt_(const BufferStoreNode *op) final {
181
182
183
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
184
      return store;
185
    }
186
187
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
188
    n->buffer = new_buffer;
189
190
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
191
    n->indices.insert(n->indices.begin(), version);
192
    return store;
193
194
  }

195
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
196
197
198
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
199
      return load;
200
    }
201
202
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
203
    n->buffer = new_buffer;
204
205
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
206
    n->indices.insert(n->indices.begin(), version);
207
    return load;
208
209
  }

210
  PrimExpr VisitExpr_(const CallNode *op) final {
211
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
212
213
214
215
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
216
217
218
219
220
221
222
223
224
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
};

/*!
225
226
 * \brief Rewriter for the software pipeline that rewrite a loop into a
 * pipelined one.
227
228
 */
class PipelineRewriter : public StmtExprMutator {
229
public:
230
231
232
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
                   const For &pipeline_loop, const PipelineInfo &pipeline_info)
233
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
234
235
        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
        pipeline_info_(pipeline_info) {}
236
237

  Stmt BuildPipeline() {
238
239
240
241
242
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
    // number of versions need to maintain for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos = GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
243
244
245
246
247
248
249
      int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
      }
    }

    ordered_stmts_.resize(pipeline_info_.size());
250
251
    for (const auto &[block, anno] : pipeline_info_) {
      ordered_stmts_.Set(anno.order, block);
252
253
    }

254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        state.producer_head = pipeline_loop_->min - 1;
        for (auto write_region : block->writes) {
          auto buffer = write_region->buffer;
          state.dst_buffers.insert(buffer.get());
          if (buffer_remap_.count(buffer))
            state.dst_buffers.insert(buffer_remap_[buffer].get());
        }
      }
    }
    std::unordered_set<int> consumed;
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        if (state.commit_groups.empty() || consumed.count(stage)) {
          state.commit_groups.push_back({});
        }
        state.commit_groups.back().push_back(pipeline_info_[block].order);
        consumed.erase(stage);
        for (auto write_region : block->writes) {
          auto buffer = buffer_remap_.count(write_region->buffer)
                            ? buffer_remap_[write_region->buffer]
                            : write_region->buffer;
          state.buffer_to_commit_group_[buffer.get()] =
              state.commit_groups.size() - 1;
        }
      }

      for (auto read_region : block->reads) {
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          if (producer_stage_id <= stage &&
              producer_state.writes(read_region->buffer)) {
            consumed.insert(producer_stage_id);
          }
        }
293
294
      }
    }
295
296
297
298
299
300
301
302
303
304

    // Step 2: Emit the pipeline prologue, body and epilogue.
    Stmt prologue = EmitImpl(pipeline_loop_->min,
                             pipeline_loop_->min + max_stage_, true, true);
    Stmt body =
        EmitImpl(pipeline_loop_->min + max_stage_,
                 pipeline_loop_->min + pipeline_loop_->extent, false, false);
    Stmt epilogue = EmitImpl(
        pipeline_loop_->min + pipeline_loop_->extent,
        pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);
305
306
307

    SeqStmt stmt = SeqStmt({prologue, body, epilogue});

308
309
    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
310
    Array<Buffer> alloc_buffers;
311
    for (const auto &alloc : pipeline_allocs_) {
312
313
314
315
316
317
318
319
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    return BlockRealize({}, Bool(true), block);
  }

320
private:
321
322
323
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
324
325
326
   * This method check the 'define' and 'use' stage of the buffers in the
   * software pipeline, which can be used to compute the number of versions
   * needed to maintain after rewriting.
327
328
329
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
330
331
332
333
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
334
335
336
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

337
      for (const BufferRegion &write : block->writes) {
338
339
340
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
341
        auto &info = infos.at(write->buffer);
342
343
344
345
346
347
348
        if (info.def == -1) {
          info.def = stage;
        } else {
          info.def = std::min(info.def, stage);
        }
      }

349
      for (const BufferRegion &read : block->reads) {
350
351
352
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
353
        auto &info = infos.at(read->buffer);
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
        info.use = std::max(info.use, stage);
      }
    }
    return infos;
  }

  /*!
   * \brief Check whether two regions have intersections.
   * \param region1 The first region.
   * \param region2 The second region.
   * \return Whether region1 and region2 have intersections.
   */
  bool MayConflict(Region region1, Region region2) {
    ICHECK(region1.size() == region2.size());
    for (size_t i = 0; i < region1.size(); i++) {
      Range dim1 = region1[i];
      Range dim2 = region2[i];
      auto int_set1 = arith::IntSet::FromRange(dim1);
      auto int_set2 = arith::IntSet::FromRange(dim2);
      if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
        return false;
      }
    }
    return true;
  }

  /*!
381
382
   * \brief Compute the number of versions need to maintain for buffer accessed
   * in the software pipeline.
383
   *
384
385
386
387
388
389
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions need to maintain during the software pipeline.
   * Annotation `attr::double_buffer_scope` is handled here which provides a way
   * to override the result of the analysis. Additional double buffering in the
   * software pipeline can be useful to eliminate synchronizations in GPU
   * devices.
390
391
392
393
394
   *
   * \param buffer The target buffer
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
395
396
  int ComputeBufferVersions(const Buffer &buffer,
                            const BufferAccessInfo &buffer_info) {
397
    if (buffer_info.def == -1) {
398
399
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
400
401
402
403
      return 1;
    }

    // `use - def + 1` is a upper bound of the needed versions
404
405
    // We optimize a few case where the number of versions can be smaller than
    // the upper bound
406
    int num_versions = buffer_info.use - buffer_info.def + 1;
407
    if (num_versions >= 2) {
408
409
410
411
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when these exists a reader block_i and a writer
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
412
      bool need_multi_version = false;
413
414
415
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;
416

417
418
419
        auto it1 = std::find_if(writer_block->writes.begin(),
                                writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
420
421
422
423
424
425
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

426
427
428
429
430
431
432
433
        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(
              reader_block->reads.begin(), reader_block->reads.end(),
              [&](const BufferRegion &buffer_region) {
                return buffer_region->buffer.same_as(buffer);
              });
434
435
436
          if (it2 == reader_block->reads.end()) {
            continue;
          }
437
438
          if (writer_info.order < reader_info.order &&
              writer_info.stage < reader_info.stage &&
439
440
441
442
443
444
445
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
        }
      }
      if (!need_multi_version) {
446
        num_versions--;
447
448
449
450
451
452
      }
    }
    return num_versions;
  }

  /*!
453
454
   * \brief Rewrite buffer allocation to keep multiple versions of original
   * buffer for pipelined accesses. \param buffer The buffer to be resized.
455
456
457
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
458
  Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
459
460
461
462
463
464
465
466
467
468
    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
    if (new_buffer->strides.size()) {
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

469
470
  // Per-stage states that need to be tracked across pipeline prologue, body,
  // and epilogue.
471
472
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
473
474
475
476
477
478
479
    std::unordered_set<const BufferNode *> dst_buffers;
    // An imaginary index that the latest async operation associated with this
    // stage has written into. Only valid if all associated predicates are true,
    // so that we can count the number of async invocations exactly. When it is
    // valid, it is the "sum of extents of loops that have been executed" - 1,
    // e.g. for epilogue it is prologue extent + body extent - 1. This is only
    // needed to compute wait count for epilogue without async producers.
480
481
482
    PrimExpr producer_head;
    std::vector<std::vector<int>> commit_groups;
    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
483
484
485
    bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
  };

486
487
  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
488
  struct AsyncStateLocal {
489
    struct PendingWait {
490
491
      // The index into a list of blocks, where async_wait_queue should be
      // attached at the beginning.
492
      int insert_before;
493
494
      // in_flight_count would be a more precise name, but the implementation
      // uses wait_count for brevity.
495
496
497
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
498
499
500
    };

    std::vector<PendingWait> pending_waits;
501

502
503
    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
504
505
506
507
508
509
    Optional<PrimExpr> producer_head;
  };

  /*! Structure holding intermediate information for pipeline loop rewriting. */
  struct RewrittenBlockInfo {
    int stage;
510
    int order;
511
512
513
514
515
516
    PrimExpr predicate;
    Block block;
    PrimExpr access_index;
    bool is_async;
  };

517
518
  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
                          std::map<int, AsyncStateLocal> *async_states_local) {
519
520
521
    for (size_t i = 0; i < new_blocks.size(); ++i) {
      int producer_stage_idx = -1;
      for (auto read_region : new_blocks[i].block->reads) {
522
523
524
        for (const auto &[stage, state] : async_states) {
          if (stage <= new_blocks[i].stage &&
              state.writes(read_region->buffer)) {
525
526
            // Found an earlier stage where read_region->buffer was
            // asynchronously written
527
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
528
                << "A dependency on multiple async stages is not supported";
529
            producer_stage_idx = stage;
530
531
532
          }
        }
      }
533
534
      if (producer_stage_idx == -1)
        continue;
535
      const auto &state = async_states[producer_stage_idx];
536
      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
537
538
539
540
541
542
543
544
545
546
547
      PrimExpr in_flight_cnt = 0;
      for (const auto &group : state.commit_groups) {
        PrimExpr consumer_head = new_blocks[i].access_index;
        PrimExpr producer_head;
        if (dep_local_state.producer_head.defined()) {
          producer_head = dep_local_state.producer_head.value();
          // if the group is after the wait point, minus by 1
          if (group.front() > new_blocks[i].order)
            producer_head -= 1;
        } else {
          producer_head = state.producer_head;
548
        }
549
        in_flight_cnt += producer_head - consumer_head;
550
551
      }

552
553
554
555
556
557
558
559
560
561
562
563
      // We can relax the in-flight-count by the number of independent commit.
      std::unordered_set<int> dependent_groups;
      for (const auto &read_region : new_blocks[i].block->reads) {
        if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
          dependent_groups.insert(
              state.buffer_to_commit_group_.at(read_region->buffer.get()));
      }
      for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
        if (dependent_groups.count(i) == 0)
          in_flight_cnt += 1;
        else
          break; // stop relaxing
564
      }
565
566
567
      in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
      dep_local_state.pending_waits.push_back(
          {static_cast<int>(i), in_flight_cnt});
568
569
570
    }
  }

571
572
  // Given pipelined blocks and async-related information, generate final loop
  // statements with async scopes (if any).
573
  Array<Stmt> CompletePipelineLoopStatements(
574
      const std::vector<RewrittenBlockInfo> &blocks,
575
      const std::map<int, AsyncStateLocal> &async_states_local) const {
576
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
577
    for (const auto &[stage_id, state] : async_states_local) {
578
579
580
581
582
583
584
      for (const auto &pw : state.pending_waits) {
        auto &block = new_blocks[pw.insert_before].block;
        BlockNode *n = block.CopyOnWrite();
        auto zero = make_zero(DataType::Int(32));
        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                    pw.wait_count, n->body));
585
      }
586
    }
587

588
589
590
591
592
    // mark the last async stmt as commit
    std::unordered_set<int> commit_group_indices;
    for (const auto &[stage_id, state] : async_states) {
      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
        commit_group_indices.insert(state.commit_groups[i].back());
593
594
595
596
597
      }
    }

    Array<Stmt> stmts;

598
599
600
601
602
603
604
    for (size_t i = 0; i < new_blocks.size(); i++) {
      Block block = new_blocks[i].block;
      if (commit_group_indices.count(new_blocks[i].order)) {
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope,
                                           new_blocks[i].stage, block->body);
        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
605
      }
606
      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
607
608
609
610
611
612
613
614
615
616
617
618
    }

    return stmts;
  }

  /*!
   * \brief Emit the pipeline loop in the given range.
   * \param start The start of the range
   * \param end The end of the range
   * \param unroll_loop Whether the loop should be unrolled.
   * \return The result loop.
   */
619
  Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop,
620
                bool need_bound_check) {
621
622
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;
623
624
625
    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };
626
627
628

    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
629
      new_loop_var = start; // use constants as the loop var for unit loops
630
631
632
633
634
635
636
637
638
639
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
    }

    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;

640
    for (const Block &block : ordered_stmts_) {
641
      int stage = pipeline_info_.at(block).stage;
642
643
      int order = pipeline_info_.at(block).order;
      PrimExpr inbound = Bool(true);
644
      PrimExpr skewed_loop_var = new_loop_var - stage;
645
646
647
648
      if (need_bound_check)
        inbound =
            analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
649
650
651
      if (analyzer_.CanProve(!inbound)) {
        continue;
      }
652
653
654
      Block new_block = Downcast<Block>(
          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                               pipeline_loop_, max_stage_ != 1)(block));
655
656
657
658

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
659
660
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
661
      PrimExpr normalized_access_index =
662
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
663

664
665
      // Adjust the block predicate and the body according to the final loop
      // bound
666
667
668
669
670
      //  [pipeline_loop_->min, extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }
671
672
      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
673
      if (pipeline_info_[block].async) {
674
        auto &local_state = async_states_local[stage];
675
        local_state.producer_head = normalized_access_index;
676
677
678
        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                           1, n->body);
679
680
      }

681
682
      new_blocks.push_back({stage, order, inbound, new_block,
                            normalized_access_index,
683
                            pipeline_info_[block].async});
684
    }
685

686
687
688
    PopulateWaitCounts(new_blocks, &async_states_local);

    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
689

690
691
    Stmt new_loop{nullptr};

692
    if (stmts.empty()) {
693
694
      return make_nop();
    }
695

696
697
    if (stmts.size() == 1) {
      new_loop = stmts[0];
698
    } else {
699
      new_loop = SeqStmt(stmts);
700
701
702
    }

    if (!is_unit_loop) {
703
704
705
706
707
708
709
710
711
      Map<String, Any> preserved_annotations;
      for (const auto &kv : pipeline_loop_->annotations) {
        const String &key = kv.first;
        if (kv.first != tir::attr::software_pipeline_stage &&
            kv.first != tir::attr::software_pipeline_order &&
            kv.first != tir::attr::software_pipeline_async_stages) {
          preserved_annotations.Set(key, kv.second);
        }
      }
712
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
713
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
714
                     std::move(new_loop), std::nullopt, preserved_annotations);
715
716
    }
    // Update producer heads in the global async states.
717
718
    for (const auto &[stage_id, state] : async_states_local) {
      async_states[stage_id].producer_head += extent;
719
720
    }

721
722
    return BlockRealize({}, Bool(true),
                        MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
  }

  arith::Analyzer analyzer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Array<Buffer> pipeline_allocs_;
  For pipeline_loop_;
  PipelineInfo pipeline_info_;
  int max_stage_ = -1;
  Map<Buffer, Buffer> buffer_remap_;
  Array<Block> ordered_stmts_;
  std::map<int, AsyncStateGlobal> async_states;
};

/*!
 * \brief Build the dependency graph among a array of blocks.
 * \param[in] blocks The array of blocks.
739
740
741
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination. \param[out] dep_dst2src Optional, a map to store
 * dependency edges from the destination to the source.
742
 */
743
744
745
746
747
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
748
749
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;
750
751
752

  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
753
754
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
755
        for (const Block &writer : it->second) {
756
757
758
759
760
761
762
763
764
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
          if (dep_dst2src != nullptr) {
            (*dep_dst2src)[block].push_back(writer);
          }
        }
      }
    }
765
    for (const BufferRegion &write : block->writes) {
766
767
768
769
770
771
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
772
773
public:
  static Stmt Inject(const PrimFunc &func) {
774
775
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
776
777
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
778
779
780
781
782
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }

783
784
785
private:
  explicit PipelineInjector(Optional<String> global_symbol)
      : global_symbol_(global_symbol) {}
786
787
788
789

  /*!
   * \brief Check the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
790
791
792
793
   * 2. Reordering of statements doesn't break buffer access dependencies.
   * Specifically, for dependency (e.g. read-after-write) from statement A to
   * statement B, it requires: case 1: stage(A) < stage(B) case 2: stage(A) ==
   * stage(B) and order(A) < order(B)
794
   */
795
796
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
797
798
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
799
800
801
802
    std::unordered_map<int, const Block *> order_to_block;
    std::unordered_map<const Block *, int> block_to_stage;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
803
804
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
805
806
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
807
808
809
      used_orders.insert(order);
    }

810
811
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
812
813
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

814
815
816
817
818
819
    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
820
821
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
822
823
            << " cannot depends on statement " << src << " in a later stage "
            << src_info.stage;
824
        if (src_info.stage == dst_info.stage) {
825
826
827
828
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
829
830
831
832
833
        }
      }
    }
  }

834
  Stmt VisitStmt_(const ForNode *op) final {
835
836
837
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
838
      return for_node;
839
    }
840
841
842
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be direct child of the for-loop. If the for-loop has BlockRealize as
    // its child, the pipeline body will be the child of the block.
843
844
    Stmt pipeline_body{nullptr};
    Array<Buffer> pipeline_allocs;
845
846
847
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
848
849
850
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
851
      pipeline_body = block->body;
852
853
854
855
856
      pipeline_allocs = block->alloc_buffers;
    } else {
      pipeline_body = for_node->body;
    }

857
858
859
860
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << pipeline_body->GetTypeKey();
861

862
863
    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
864
    PipelineInfo pipeline_info;
865
    Array<Block> original_order; // pipeline body blocks in the original order
866

867
    auto f_add_child = [&](const Stmt &child) {
868
869
870
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
871
872
      const auto *nested_block_realize =
          pipeline_body_seq->seq[i].as<BlockRealizeNode>();
873
874
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
875
876
877
878
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
879
880
881
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
882
        const auto *nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
883
884
885
886
887
888
889
890
        for (size_t j = 0; j < nested_seq->seq.size(); j++) {
          f_add_child(nested_seq->seq[j]);
        }
      } else {
        f_add_child(pipeline_body_seq->seq[i]);
      }
    }

891
892
893
894
    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
895
896
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
897
898
899
900
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
901
902
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
903
904
905
906
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";
907
908

    std::unordered_set<int> pipeline_async_stages;
909
910
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
911
      for (auto s : Downcast<Array<Integer>>(annot.value())) {
912
913
914
915
916
917
        pipeline_async_stages.insert(s->value);
      }
    }

    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
918
919
920
921
922
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{
          stage,
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
923
924
925
926
927
928
      pipeline_info.emplace(original_order[i], stage_order);
    }

    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
929
930
931
    Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
                                     GetRef<For>(op), pipeline_info)
                        .BuildPipeline();
932

933
934
935
    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
936
937
938
939
940
941
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }

942
943
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
944
945
946
947
948
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

949
    for (const auto &buffer : op->alloc_buffers) {
950
951
      buffer_data_to_buffer_.erase(buffer->data);
    }
952
    return block;
953
954
  }

955
  bool HasPipelineAnnotation(const ForNode *op) const {
956
957
958
959
960
961
962
963
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
    bool has_order = it2 != op->annotations.end();
    if (has_stage && has_order) {
      return true;
    }
    if (has_stage) {
964
      LOG(FATAL)
965
          << "ValueError: Stage of the software pipeline is not defined.";
966
967
    }
    if (has_order) {
968
      LOG(FATAL)
969
          << "ValueError: Order of the software pipeline is not defined.";
970
971
972
973
974
975
976
    }
    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Optional<String> global_symbol_;
};
977
978
} // namespace software_pipeline

979
/*!
980
981
 * \brief Transform annotated loops into pipelined one that parallelize
 * producers and consumers. \return The IR transform pass.
982
983
984
985
 */
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
986
    auto *fptr = f.CopyOnWrite();
987
    fptr->body = software_pipeline::PipelineInjector::Inject(f);
988
989
990
991
992
993
    fptr->body = ConvertSSA(std::move(fptr->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
}

994
995
996
997
998
TVM_FFI_STATIC_INIT_BLOCK({
  // Register the pass constructor as a global function in the FFI registry.
  tvm::ffi::reflection::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
                                        InjectSoftwarePipeline);
});
} // namespace tl
} // namespace tvm