/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_pipeline.cc
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 */
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

#include <unordered_set>

#include "support/utils.h"
#include "tir/schedule/utils.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
using namespace tir;

/*!
 * \brief Create a block and infer the access region with the given body.
 *
42
43
44
 * The result is a opaque block that doesn't contain any block iter vars. In
 * case the body is a block realize without predicate, it is unnecessary to
 * create a new block, the block of the block realize will be returned.
45
46
47
48
49
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
50
51
52
Block MakeBlock(const Stmt &body,
                const Map<Var, Buffer> &buffer_data_to_buffer) {
  if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
53
54
55
56
57
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
58
59
60
61
62
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ body);
  Array<Array<BufferRegion>> access =
      GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode *n = block.CopyOnWrite();
63
64
65
66
67
68
69
70
71
72
73
74
  n->reads = access[0];
  n->writes = access[1];
  return block;
}

/*!
 * \brief Pipeline annotation attached to a single block: the stage it runs in,
 * its position in the rewritten loop body, and whether it is emitted inside an
 * async scope.
 */
struct PipelineAnnotation {
  int stage;  // Pipeline stage index of the block.
  int order;  // Position of the block in the rewritten loop body.
  bool async; // Whether the block is wrapped in an async scope when emitted.
};

75
76
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
                                        ObjectPtrHash, ObjectPtrEqual>;
77
78

/*!
 * \brief Pipeline stages at which a buffer is written and read.
 *
 * A value of -1 means no such access has been recorded.
 */
struct BufferAccessInfo {
  int def = -1; // the earliest stage that writes (defines) the buffer
  int use = -1; // the last stage that reads (uses) the buffer
};

/*!
84
85
86
 * \brief Rewriter for the body of the software pipeline. This pass inserts
 * `floormod` to indices of the remapped buffer to select the version
 * corresponding to the pipeline stage.
87
88
 */
class PipelineBodyRewriter : public StmtExprMutator {
89
public:
90
91
92
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
93
94
95
96
97
98
99
   * \param buffer_remap The map from original buffer to the buffer with updated
   * shape for multi-versioning in the software pipeline. \param pipeline_loop
   * The original loop to be software pipelined. \param access_all_versions
   * Whether all versions the buffers in the software pipeline are accessed.
   * This will be used to update block access region. In the prologue and
   * epilogue of a two-stage software pipeline, only one version of these
   * buffers are accessed.
100
   */
101
102
103
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap,
                       For pipeline_loop, bool access_all_versions)
104
      : buffer_data_to_buffer_(buffer_data_to_buffer),
105
        buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop),
106
107
        access_all_versions_(access_all_versions) {}

108
109
110
private:
  BufferRegion
  RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
111
112
113
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
114
115
116
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension to
      // full extent if access_all_versions == true
117
118
119
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
120
121
122
123
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
124
125
126
127
128
129
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

130
131
132
133
134
135
  PrimExpr RewriteBufferAccess(const Call &call,
                               const std::vector<int> arg_indices) {
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
          [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
          make_const(DataType::Int(32), 1), input);
136
137
138
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
139
140
      const Buffer &buffer =
          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
141
142
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
143
144
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
145
146
147
148
149
150
151
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
152
153
            old_index +
            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
154
155
156
157
158
159
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

160
161
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
162
163
164
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
165
166
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
167
168
      return RewritePipelineBufferRegion(buffer_region);
    });
169
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
170
171
      return RewritePipelineBufferRegion(buffer_region);
    });
172
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
173
174
175
176
177
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
    return std::move(block);
  }

178
  Stmt VisitStmt_(const BufferStoreNode *op) final {
179
180
181
182
183
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(store);
    }
184
185
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
186
    n->buffer = new_buffer;
187
188
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
189
190
191
192
    n->indices.insert(n->indices.begin(), version);
    return std::move(store);
  }

193
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
194
195
196
197
198
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(load);
    }
199
200
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
201
    n->buffer = new_buffer;
202
203
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
204
205
206
207
    n->indices.insert(n->indices.begin(), version);
    return std::move(load);
  }

208
  PrimExpr VisitExpr_(const CallNode *op) final {
209
210
211
212
213
214
215
216
217
218
219
220
221
222
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
};

/*!
223
224
 * \brief Rewriter for the software pipeline that rewrite a loop into a
 * pipelined one.
225
226
 */
class PipelineRewriter : public StmtExprMutator {
227
228
229
public:
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
230
231
                   const For &pipeline_loop, const PipelineInfo &pipeline_info,
                   PrimExpr predicate_condition = PrimExpr())
232
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
233
        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
234
235
        pipeline_info_(pipeline_info),
        predicate_condition_(predicate_condition) {}
236
237

  Stmt BuildPipeline() {
238
239
240
241
242
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
    // number of versions need to maintain for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos = GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
243
244
245
246
247
248
249
      int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
      }
    }

    ordered_stmts_.resize(pipeline_info_.size());
250
    for (const auto &[block, anno] : pipeline_info_) {
251
252
253
      ordered_stmts_.Set(anno.order, block);
    }

254
    for (const Block &block : ordered_stmts_) {
255
256
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
257
        auto &state = async_states[stage];
258
259
260
261
        state.producer_head = pipeline_loop_->min - 1;
        for (auto write_region : block->writes) {
          auto buffer = write_region->buffer;
          state.dst_buffers.insert(buffer.get());
262
263
          if (buffer_remap_.count(buffer))
            state.dst_buffers.insert(buffer_remap_[buffer].get());
264
265
266
267
        }
      }
    }
    std::unordered_set<int> consumed;
268
    for (const Block &block : ordered_stmts_) {
269
270
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
271
        auto &state = async_states[stage];
272
273
274
275
276
277
278
279
280
        if (state.commit_groups.empty() || consumed.count(stage)) {
          state.commit_groups.push_back({});
        }
        state.commit_groups.back().push_back(pipeline_info_[block].order);
        consumed.erase(stage);
        for (auto write_region : block->writes) {
          auto buffer = buffer_remap_.count(write_region->buffer)
                            ? buffer_remap_[write_region->buffer]
                            : write_region->buffer;
281
282
          state.buffer_to_commit_group_[buffer.get()] =
              state.commit_groups.size() - 1;
283
284
285
286
        }
      }

      for (auto read_region : block->reads) {
287
288
289
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          if (producer_stage_id <= stage &&
              producer_state.writes(read_region->buffer)) {
290
291
292
293
294
295
296
            consumed.insert(producer_stage_id);
          }
        }
      }
    }

    // Step 2: Emit the pipeline prologue, body and epilogue.
297
298
299
300
301
302
303
304
    Stmt prologue = EmitImpl(pipeline_loop_->min,
                             pipeline_loop_->min + max_stage_, true, true);
    Stmt body =
        EmitImpl(pipeline_loop_->min + max_stage_,
                 pipeline_loop_->min + pipeline_loop_->extent, false, false);
    Stmt epilogue = EmitImpl(
        pipeline_loop_->min + pipeline_loop_->extent,
        pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);
305
306
307

    SeqStmt stmt = SeqStmt({prologue, body, epilogue});

308
309
    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
310
    Array<Buffer> alloc_buffers;
311
    for (const auto &alloc : pipeline_allocs_) {
312
313
314
315
316
317
318
319
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    return BlockRealize({}, Bool(true), block);
  }

320
private:
321
322
323
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
324
325
326
   * This method check the 'define' and 'use' stage of the buffers in the
   * software pipeline, which can be used to compute the number of versions
   * needed to maintain after rewriting.
327
328
329
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
330
331
332
333
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
334
335
336
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

337
      for (const BufferRegion &write : block->writes) {
338
339
340
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
341
        auto &info = infos.at(write->buffer);
342
343
344
345
346
347
348
        if (info.def == -1) {
          info.def = stage;
        } else {
          info.def = std::min(info.def, stage);
        }
      }

349
      for (const BufferRegion &read : block->reads) {
350
351
352
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
353
        auto &info = infos.at(read->buffer);
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
        info.use = std::max(info.use, stage);
      }
    }
    return infos;
  }

  /*!
   * \brief Check whether two regions have intersections.
   * \param region1 The first region.
   * \param region2 The second region.
   * \return Whether region1 and region2 have intersections.
   */
  bool MayConflict(Region region1, Region region2) {
    ICHECK(region1.size() == region2.size());
    for (size_t i = 0; i < region1.size(); i++) {
      Range dim1 = region1[i];
      Range dim2 = region2[i];
      auto int_set1 = arith::IntSet::FromRange(dim1);
      auto int_set2 = arith::IntSet::FromRange(dim2);
      if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
        return false;
      }
    }
    return true;
  }

  /*!
381
382
   * \brief Compute the number of versions need to maintain for buffer accessed
   * in the software pipeline.
383
   *
384
385
386
387
388
389
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions need to maintain during the software pipeline.
   * Annotation `attr::double_buffer_scope` is handled here which provides a way
   * to override the result of the analysis. Additional double buffering in the
   * software pipeline can be useful to eliminate synchronizations in GPU
   * devices.
390
391
392
393
394
   *
   * \param buffer The target buffer
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
395
396
  int ComputeBufferVersions(const Buffer &buffer,
                            const BufferAccessInfo &buffer_info) {
397
    if (buffer_info.def == -1) {
398
399
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
400
401
402
403
      return 1;
    }

    // `use - def + 1` is a upper bound of the needed versions
404
405
    // We optimize a few case where the number of versions can be smaller than
    // the upper bound
406
407
    int num_versions = buffer_info.use - buffer_info.def + 1;
    if (num_versions >= 2) {
408
409
410
411
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when these exists a reader block_i and a writer
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
412
      bool need_multi_version = false;
413
414
415
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;
416

417
418
419
        auto it1 = std::find_if(writer_block->writes.begin(),
                                writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
420
421
422
423
424
425
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

426
427
428
429
430
431
432
433
        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(
              reader_block->reads.begin(), reader_block->reads.end(),
              [&](const BufferRegion &buffer_region) {
                return buffer_region->buffer.same_as(buffer);
              });
434
435
436
          if (it2 == reader_block->reads.end()) {
            continue;
          }
437
438
          if (writer_info.order < reader_info.order &&
              writer_info.stage < reader_info.stage &&
439
440
441
442
443
444
445
446
447
448
449
450
451
452
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
        }
      }
      if (!need_multi_version) {
        num_versions--;
      }
    }
    return num_versions;
  }

  /*!
453
454
   * \brief Rewrite buffer allocation to keep multiple versions of original
   * buffer for pipelined accesses. \param buffer The buffer to be resized.
455
456
457
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
458
  Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
459
460
461
462
463
464
465
466
467
468
    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
    if (new_buffer->strides.size()) {
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

469
470
  // Per-stage states that need to be tracked across pipeline prologue, body,
  // and epilogue.
471
472
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
473
474
475
476
477
478
479
    std::unordered_set<const BufferNode *> dst_buffers;
    // An imaginary index that the latest async operation associated with this
    // stage has written into. Only valid if all associated predicates are true,
    // so that we can count the number of async invocations exactly. When it is
    // valid, it is the "sum of extents of loops that have been executed" - 1,
    // e.g. for epilogue it is prologue extent + body extent - 1. This is only
    // needed to compute wait count for epilogue without async producers.
480
481
    PrimExpr producer_head;
    std::vector<std::vector<int>> commit_groups;
482
    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
483
484
485
    bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
  };

486
487
  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
488
489
  struct AsyncStateLocal {
    struct PendingWait {
490
491
      // The index into a list of blocks, where async_wait_queue should be
      // attached at the beginning.
492
      int insert_before;
493
494
      // in_flight_count would be a more precise name, but the implementation
      // uses wait_count for brevity.
495
496
497
498
499
500
501
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
    };

    std::vector<PendingWait> pending_waits;

502
503
    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
504
505
506
507
508
509
510
511
512
513
514
515
516
    Optional<PrimExpr> producer_head;
  };

  /*! Structure holding intermediate information for pipeline loop rewriting. */
  struct RewrittenBlockInfo {
    int stage;
    int order;
    PrimExpr predicate;
    Block block;
    PrimExpr access_index;
    bool is_async;
  };

517
518
  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
                          std::map<int, AsyncStateLocal> *async_states_local) {
519
520
521
    for (size_t i = 0; i < new_blocks.size(); ++i) {
      int producer_stage_idx = -1;
      for (auto read_region : new_blocks[i].block->reads) {
522
523
524
525
526
        for (const auto &[stage, state] : async_states) {
          if (stage <= new_blocks[i].stage &&
              state.writes(read_region->buffer)) {
            // Found an earlier stage where read_region->buffer was
            // asynchronously written
527
528
529
530
531
532
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
                << "A dependency on multiple async stages is not supported";
            producer_stage_idx = stage;
          }
        }
      }
533
534
535
536
      if (producer_stage_idx == -1)
        continue;
      const auto &state = async_states[producer_stage_idx];
      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
537
      PrimExpr in_flight_cnt = 0;
538
      for (const auto &group : state.commit_groups) {
539
540
541
542
543
        PrimExpr consumer_head = new_blocks[i].access_index;
        PrimExpr producer_head;
        if (dep_local_state.producer_head.defined()) {
          producer_head = dep_local_state.producer_head.value();
          // if the group is after the wait point, minus by 1
544
545
          if (group.front() > new_blocks[i].order)
            producer_head -= 1;
546
547
548
549
550
551
552
553
        } else {
          producer_head = state.producer_head;
        }
        in_flight_cnt += producer_head - consumer_head;
      }

      // We can relax the in-flight-count by the number of independent commit.
      std::unordered_set<int> dependent_groups;
554
      for (const auto &read_region : new_blocks[i].block->reads) {
555
        if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
556
557
          dependent_groups.insert(
              state.buffer_to_commit_group_.at(read_region->buffer.get()));
558
559
560
561
562
      }
      for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
        if (dependent_groups.count(i) == 0)
          in_flight_cnt += 1;
        else
563
          break; // stop relaxing
564
565
      }
      in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
566
567
      dep_local_state.pending_waits.push_back(
          {static_cast<int>(i), in_flight_cnt});
568
569
570
    }
  }

571
572
  // Given pipelined blocks and async-related information, generate final loop
  // statements with async scopes (if any).
573
  Array<Stmt> CompletePipelineLoopStatements(
574
575
      const std::vector<RewrittenBlockInfo> &blocks,
      const std::map<int, AsyncStateLocal> &async_states_local) const {
576
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
577
578
579
580
    for (const auto &[stage_id, state] : async_states_local) {
      for (const auto &pw : state.pending_waits) {
        auto &block = new_blocks[pw.insert_before].block;
        BlockNode *n = block.CopyOnWrite();
581
        auto zero = make_zero(DataType::Int(32));
582
583
584
        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                    pw.wait_count, n->body));
585
586
587
588
589
      }
    }

    // mark the last async stmt as commit
    std::unordered_set<int> commit_group_indices;
590
    for (const auto &[stage_id, state] : async_states) {
591
592
593
594
595
596
597
598
599
600
      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
        commit_group_indices.insert(state.commit_groups[i].back());
      }
    }

    Array<Stmt> stmts;

    for (size_t i = 0; i < new_blocks.size(); i++) {
      Block block = new_blocks[i].block;
      if (commit_group_indices.count(new_blocks[i].order)) {
601
602
603
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope,
                                           new_blocks[i].stage, block->body);
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
      }
      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
    }

    return stmts;
  }

  /*!
   * \brief Emit the pipeline loop in the given range.
   * \param start The start of the range
   * \param end The end of the range
   * \param unroll_loop Whether the loop should be unrolled.
   * \return The result loop.
   */
619
620
  Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop,
                bool need_bound_check) {
621
622
623
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;

624
625
626
    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };
627
628
629

    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
630
      new_loop_var = start; // use constants as the loop var for unit loops
631
632
633
634
635
636
637
638
639
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
    }

    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;
640
    PrimExpr normalized_access_index;
641

642
    for (const Block &block : ordered_stmts_) {
643
644
645
646
647
      int stage = pipeline_info_.at(block).stage;
      int order = pipeline_info_.at(block).order;
      PrimExpr inbound = Bool(true);
      PrimExpr skewed_loop_var = new_loop_var - stage;
      if (need_bound_check)
648
649
650
        inbound =
            analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
651
652
653
      if (analyzer_.CanProve(!inbound)) {
        continue;
      }
654
655
656
      Block new_block = Downcast<Block>(
          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                               pipeline_loop_, max_stage_ != 1)(block));
657
658
659
660

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
661
662
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
663
      normalized_access_index =
664
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
665

666
667
      // Adjust the block predicate and the body according to the final loop
      // bound
668
669
670
671
672
      //  [pipeline_loop_->min, extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }
673
674
      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
675
676
677
678
679
680
681
      if (predicate_condition_.defined()) {
        BlockNode *n = new_block.CopyOnWrite();
        n->body = IfThenElse(
            Substitute(predicate_condition_,
                       {{pipeline_loop_->loop_var, normalized_access_index}}),
            n->body);
      }
682
      if (pipeline_info_[block].async) {
683
        auto &local_state = async_states_local[stage];
684
        local_state.producer_head = normalized_access_index;
685
686
687
        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                           1, n->body);
688
689
      }

690
691
692
      new_blocks.push_back({stage, order, inbound, new_block,
                            normalized_access_index,
                            pipeline_info_[block].async});
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
    }

    PopulateWaitCounts(new_blocks, &async_states_local);
    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
    Stmt new_loop{nullptr};

    if (stmts.empty()) {
      return make_nop();
    }
    if (stmts.size() == 1) {
      new_loop = stmts[0];
    } else {
      new_loop = SeqStmt(stmts);
    }

    if (!is_unit_loop) {
      Map<String, ObjectRef> preserved_annotations;
710
711
      for (const auto &kv : pipeline_loop_->annotations) {
        const String &key = kv.first;
712
713
714
715
716
717
718
        if (kv.first != tir::attr::software_pipeline_stage &&
            kv.first != tir::attr::software_pipeline_order &&
            kv.first != tir::attr::software_pipeline_async_stages) {
          preserved_annotations.Set(key, kv.second);
        }
      }
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
719
720
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
                     std::move(new_loop), NullOpt, preserved_annotations);
721
722
    }
    // Update producer heads in the global async states.
723
    for (const auto &[stage_id, state] : async_states_local) {
724
725
726
      async_states[stage_id].producer_head += extent;
    }

727
728
    return BlockRealize({}, Bool(true),
                        MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
729
730
731
732
733
734
735
  }

  arith::Analyzer analyzer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Array<Buffer> pipeline_allocs_;
  For pipeline_loop_;
  PipelineInfo pipeline_info_;
736
  PrimExpr predicate_condition_;
737
738
739
740
741
742
743
744
745
  int max_stage_ = -1;
  Map<Buffer, Buffer> buffer_remap_;
  Array<Block> ordered_stmts_;
  std::map<int, AsyncStateGlobal> async_states;
};

/*!
 * \brief Build the dependency graph among a array of blocks.
 * \param[in] blocks The array of blocks.
746
747
748
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination. \param[out] dep_dst2src Optional, a map to store
 * dependency edges from the destination to the source.
749
 */
750
751
752
753
754
755
756
757
758
759
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;

  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
760
761
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
762
        for (const Block &writer : it->second) {
763
764
765
766
767
768
769
770
771
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
          if (dep_dst2src != nullptr) {
            (*dep_dst2src)[block].push_back(writer);
          }
        }
      }
    }
772
    for (const BufferRegion &write : block->writes) {
773
774
775
776
777
778
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
779
780
public:
  static Stmt Inject(const PrimFunc &func) {
781
782
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
783
784
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
785
786
787
788
789
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }

790
791
792
private:
  explicit PipelineInjector(Optional<String> global_symbol)
      : global_symbol_(global_symbol) {}
793
794
795
796

  /*!
   * \brief Check the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
797
798
799
800
   * 2. Reordering of statements doesn't break buffer access dependencies.
   * Specifically, for dependency (e.g. read-after-write) from statement A to
   * statement B, it requires: case 1: stage(A) < stage(B) case 2: stage(A) ==
   * stage(B) and order(A) < order(B)
801
   */
802
803
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
804
805
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
806
807
808
809
    std::unordered_map<int, const Block *> order_to_block;
    std::unordered_map<const Block *, int> block_to_stage;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
810
811
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
812
813
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
814
815
816
      used_orders.insert(order);
    }

817
818
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
819
820
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

821
822
823
824
825
826
    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
827
828
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
829
830
            << " cannot depends on statement " << src << " in a later stage "
            << src_info.stage;
831
        if (src_info.stage == dst_info.stage) {
832
833
834
835
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
836
837
838
839
840
        }
      }
    }
  }

841
  Stmt VisitStmt_(const ForNode *op) final {
842
843
844
845
846
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
      return std::move(for_node);
    }
847
848
849
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be direct child of the for-loop. If the for-loop has BlockRealize as
    // its child, the pipeline body will be the child of the block.
850
    Stmt pipeline_body{nullptr};
851
    PrimExpr predicate_condition{nullptr};
852
    Array<Buffer> pipeline_allocs;
853
854
855
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
856
857
858
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
859
860
861
862
863
864
865
866
867
      if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
        ICHECK(!if_then_else->else_case.defined())
            << "Pipeline_Planning: Can't handle the body of the loop because "
               "it is not a SeqStmt";
        pipeline_body = if_then_else->then_case;
        predicate_condition = if_then_else->condition;
      } else {
        pipeline_body = block->body;
      }
868
869
870
871
872
      pipeline_allocs = block->alloc_buffers;
    } else {
      pipeline_body = for_node->body;
    }

873
874
875
876
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << pipeline_body->GetTypeKey();
877

878
879
    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
880
    PipelineInfo pipeline_info;
881
    Array<Block> original_order; // pipeline body blocks in the original order
882

883
    auto f_add_child = [&](const Stmt &child) {
884
885
886
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
887
888
      const auto *nested_block_realize =
          pipeline_body_seq->seq[i].as<BlockRealizeNode>();
889
890
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
891
892
893
894
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
895
896
897
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
898
        const auto *nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
899
900
901
902
903
904
905
906
        for (size_t j = 0; j < nested_seq->seq.size(); j++) {
          f_add_child(nested_seq->seq[j]);
        }
      } else {
        f_add_child(pipeline_body_seq->seq[i]);
      }
    }

907
908
909
910
    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
911
912
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
913
914
915
916
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
917
918
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
919
920
921
922
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";
923
924

    std::unordered_set<int> pipeline_async_stages;
925
926
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
927
928
929
930
931
932
933
      for (auto s : Downcast<Array<Integer>>(annot)) {
        pipeline_async_stages.insert(s->value);
      }
    }

    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
934
935
936
937
938
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{
          stage,
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
939
940
941
942
943
944
      pipeline_info.emplace(original_order[i], stage_order);
    }

    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
945
946
947
948
    Stmt pipeline =
        PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
                         GetRef<For>(op), pipeline_info, predicate_condition)
            .BuildPipeline();
949

950
951
952
    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
953
954
955
956
957
958
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }

959
960
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
961
962
963
964
965
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

966
    for (const auto &buffer : op->alloc_buffers) {
967
968
969
970
971
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }

972
  bool HasPipelineAnnotation(const ForNode *op) const {
973
974
975
976
977
978
979
980
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
    bool has_order = it2 != op->annotations.end();
    if (has_stage && has_order) {
      return true;
    }
    if (has_stage) {
981
982
      LOG(FATAL)
          << "ValueError: Order of the software pipeline is not defined.";
983
984
    }
    if (has_order) {
985
986
      LOG(FATAL)
          << "ValueError: Stage of the software pipeline is not defined.";
987
988
989
990
991
992
993
994
995
    }
    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Optional<String> global_symbol_;
};

/*!
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 *
 * The pass rewrites each PrimFunc body with PipelineInjector and then runs
 * ConvertSSA over the result. NOTE(review): presumably this is needed because
 * the rewrite can emit the same definitions more than once; confirm.
 *
 * \return The IR transform pass.
 */
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto *fptr = f.CopyOnWrite();
    fptr->body = PipelineInjector::Inject(f);
    fptr->body = ConvertSSA(std::move(fptr->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
}

// Register the pass as a global function so it can be looked up by name.
TVM_REGISTER_GLOBAL("tl.transform.InjectSoftwarePipeline")
    .set_body_typed(InjectSoftwarePipeline);

} // namespace tl
} // namespace tvm