/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_pipeline.cc
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 */
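/*
 * Illustrative sketch (not taken from any particular workload): a loop
 * annotated with software_pipeline_stage = [0, 1] and
 * software_pipeline_order = [0, 1],
 *
 *   for i in range(16):
 *     B[0] = A[i] + 1   # stage 0, order 0 (producer)
 *     C[i] = B[0] * 2   # stage 1, order 1 (consumer)
 *
 * is rewritten into a prologue, a steady-state body and an epilogue, with B
 * double-buffered so the producer and consumer of adjacent iterations can
 * overlap:
 *
 *   B[0, 0] = A[0] + 1                 # prologue
 *   for i in range(15):                # body
 *     B[(i + 1) % 2, 0] = A[i + 1] + 1
 *     C[i] = B[i % 2, 0] * 2
 *   C[15] = B[1, 0] * 2                # epilogue
 */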
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

#include <unordered_set>

#include "support/utils.h"
#include "tir/schedule/utils.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
using namespace tir;

/*!
 * \brief Create a block and infer the access region with the given body.
 *
 * The result is an opaque block that doesn't contain any block iter vars. If
 * the body is a block realize without a predicate, no new block is created;
 * the block of the block realize is returned instead.
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
Block MakeBlock(const Stmt &body,
                const Map<Var, Buffer> &buffer_data_to_buffer) {
  if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body=*/body);
  Array<Array<BufferRegion>> access =
      GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode *n = block.CopyOnWrite();
  n->reads = access[0];
  n->writes = access[1];
  return block;
}

/*! Structure that represents the provided annotation per block or loop. */
struct PipelineAnnotation {
  int stage;
  int order;
  bool async;
};
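// For example (sketch): a pipelined loop with three children annotated with
// software_pipeline_stage = [0, 0, 1] and software_pipeline_order = [0, 2, 1]
// yields PipelineAnnotation{stage=0, order=0}, {stage=0, order=2} and
// {stage=1, order=1} for the respective blocks.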

using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
                                        ObjectPtrHash, ObjectPtrEqual>;

struct BufferAccessInfo {
  int def = -1; // the defining stage of the buffer
  int use = -1; // the last using stage of the buffer
};

/*!
 * \brief Replace IfThenElse nodes carrying the given condition with their
 * then_case, preserving attribute and block wrappers along the way.
 * \param body The statement to process.
 * \param condition The condition to match in IfThenElse nodes.
 * \return The transformed statement.
 */
Stmt replace_if_then_else(Stmt body, PrimExpr condition) {
  if (const auto *if_node = body.as<IfThenElseNode>()) {
    // If this is an IfThenElse with the matching condition, replace it with its
    // then_case
    if (if_node->condition.same_as(condition)) {
      return if_node->then_case;
    }
  } else if (const auto *attr_node = body.as<AttrStmtNode>()) {
    // For attribute nodes, preserve the attribute but process its body
    AttrStmt attr_stmt = GetRef<AttrStmt>(attr_node);
    attr_stmt.CopyOnWrite()->body =
        replace_if_then_else(attr_node->body, condition);
    return attr_stmt;
  } else if (const auto *block_node = body.as<BlockNode>()) {
    // For block nodes, process the body
    Block block = GetRef<Block>(block_node);
    block.CopyOnWrite()->body =
        replace_if_then_else(block_node->body, condition);
    return block;
  }
  // For any other node type, return it unchanged
  return body;
}
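// Sketch of the effect: for a matching condition `c` (compared by reference
// via same_as), IfThenElse(c, S) becomes S, including when the IfThenElse is
// nested under AttrStmt or Block wrappers; any other statement is returned
// unchanged.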

/*!
 * \brief Rewriter for the body of the software pipeline. It inserts a
 * `floormod` version index into accesses to remapped buffers so that each
 * access selects the version corresponding to the current pipeline stage.
 */
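// For example (sketch): with buffer B remapped to two versions, a store
//   B[j] = ...
// inside the pipelined loop over `i` becomes
//   B[(i - min) % 2, j] = ...
// where the prepended index selects the version for the current stage.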
class PipelineBodyRewriter : public StmtExprMutator {
public:
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
   * \param buffer_remap The map from original buffers to buffers with updated
   * shapes for multi-versioning in the software pipeline.
   * \param pipeline_loop The original loop to be software pipelined.
   * \param access_all_versions Whether all versions of the buffers in the
   * software pipeline are accessed. This is used to update block access
   * regions. In the prologue and epilogue of a two-stage software pipeline,
   * only one version of these buffers is accessed.
   */
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap,
                       For pipeline_loop, bool access_all_versions)
      : buffer_data_to_buffer_(buffer_data_to_buffer),
        buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop),
        access_all_versions_(access_all_versions) {}

private:
  BufferRegion
  RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension
      // to full extent if access_all_versions == true
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

  PrimExpr RewriteBufferAccess(const Call &call,
                               const std::vector<int> &arg_indices) {
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
          [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
          make_const(DataType::Int(32), 1), input);
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      const Buffer &buffer =
          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
            old_index +
            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }
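  // Sketch of the rewrite performed above: for a call
  //   tvm_access_ptr(dtype, B_data, offset, extent, rw)
  // with B remapped to `versions` copies, the offset becomes
  //   offset + (i % versions) * size
  // where size is prod(B->shape) for compact buffers, or the new leading
  // stride otherwise.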

  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
      return RewritePipelineBufferRegion(buffer_region);
    });
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
      return RewritePipelineBufferRegion(buffer_region);
    });
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
    return std::move(block);
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(store);
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
    n->buffer = new_buffer;
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
    return std::move(store);
  }

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(load);
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
    n->buffer = new_buffer;
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
    return std::move(load);
  }

  PrimExpr VisitExpr_(const CallNode *op) final {
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
};

/*!
 * \brief Rewriter for the software pipeline that rewrites a loop into a
 * pipelined one.
 */
class PipelineRewriter : public StmtExprMutator {
public:
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
                   const For &pipeline_loop, const PipelineInfo &pipeline_info,
                   PrimExpr predicate_condition = PrimExpr())
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
        pipeline_info_(pipeline_info),
        predicate_condition_(predicate_condition) {}

  Stmt BuildPipeline() {
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
    // number of versions that need to be maintained for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos = GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
      int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
      }
    }

    ordered_stmts_.resize(pipeline_info_.size());
    for (const auto &[block, anno] : pipeline_info_) {
      ordered_stmts_.Set(anno.order, block);
    }

    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        state.producer_head = pipeline_loop_->min - 1;
        for (auto write_region : block->writes) {
          auto buffer = write_region->buffer;
          state.dst_buffers.insert(buffer.get());
          if (buffer_remap_.count(buffer))
            state.dst_buffers.insert(buffer_remap_[buffer].get());
        }
      }
    }
    std::unordered_set<int> consumed;
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        if (state.commit_groups.empty() || consumed.count(stage)) {
          state.commit_groups.push_back({});
        }
        state.commit_groups.back().push_back(pipeline_info_[block].order);
        consumed.erase(stage);
        for (auto write_region : block->writes) {
          auto buffer = buffer_remap_.count(write_region->buffer)
                            ? buffer_remap_[write_region->buffer]
                            : write_region->buffer;
          state.buffer_to_commit_group_[buffer.get()] =
              state.commit_groups.size() - 1;
        }
      }

      for (auto read_region : block->reads) {
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          if (producer_stage_id <= stage &&
              producer_state.writes(read_region->buffer)) {
            consumed.insert(producer_stage_id);
          }
        }
      }
    }

    // Step 2: Emit the pipeline prologue, body and epilogue.
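    // For a loop over [min, min + extent) with max_stage pipeline stages, the
    // emitted ranges are (sketch):
    //   prologue: [min, min + max_stage)                    unrolled, bound-checked
    //   body:     [min + max_stage, min + extent)
    //   epilogue: [min + extent, min + extent + max_stage)  unrolled, bound-checked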
    Stmt prologue = EmitImpl(pipeline_loop_->min,
                             pipeline_loop_->min + max_stage_, true, true);
    Stmt body =
        EmitImpl(pipeline_loop_->min + max_stage_,
                 pipeline_loop_->min + pipeline_loop_->extent, false, false);
    Stmt epilogue = EmitImpl(
        pipeline_loop_->min + pipeline_loop_->extent,
        pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);

    SeqStmt stmt = SeqStmt({prologue, body, epilogue});

    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
    Array<Buffer> alloc_buffers;
    for (const auto &alloc : pipeline_allocs_) {
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    return BlockRealize({}, Bool(true), block);
  }

private:
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
   * This method checks the 'define' and 'use' stage of each buffer in the
   * software pipeline, which is used to compute the number of versions that
   * need to be maintained after rewriting.
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

      for (const BufferRegion &write : block->writes) {
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
        auto &info = infos.at(write->buffer);
        if (info.def == -1) {
          info.def = stage;
        } else {
          info.def = std::min(info.def, stage);
        }
      }

      for (const BufferRegion &read : block->reads) {
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
        auto &info = infos.at(read->buffer);
        info.use = std::max(info.use, stage);
      }
    }
    return infos;
  }

  /*!
   * \brief Check whether two regions have intersections.
   * \param region1 The first region.
   * \param region2 The second region.
   * \return Whether region1 and region2 have intersections.
   */
  bool MayConflict(Region region1, Region region2) {
    ICHECK(region1.size() == region2.size());
    for (size_t i = 0; i < region1.size(); i++) {
      Range dim1 = region1[i];
      Range dim2 = region2[i];
      auto int_set1 = arith::IntSet::FromRange(dim1);
      auto int_set2 = arith::IntSet::FromRange(dim2);
      if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
        return false;
      }
    }
    return true;
  }
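  // E.g. regions [0, 4) x [0, 4) and [2, 6) x [1, 3) intersect in every
  // dimension, so they may conflict; [0, 4) and [4, 8) do not.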

  /*!
   * \brief Compute the number of versions that need to be maintained for a
   * buffer accessed in the software pipeline.
   *
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions that need to be maintained during the software
   * pipeline.
   *
   * \param buffer The target buffer.
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
  int ComputeBufferVersions(const Buffer &buffer,
                            const BufferAccessInfo &buffer_info) {
    if (buffer_info.def == -1) {
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
      return 1;
    }

    // `use - def + 1` is an upper bound of the number of versions needed.
    // We optimize a few cases where the number of versions can be smaller
    // than the upper bound.
    int num_versions = buffer_info.use - buffer_info.def + 1;
    if (num_versions >= 2) {
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when there exist a writer block_i and a reader
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
      bool need_multi_version = false;
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;

        auto it1 = std::find_if(writer_block->writes.begin(),
                                writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(
              reader_block->reads.begin(), reader_block->reads.end(),
              [&](const BufferRegion &buffer_region) {
                return buffer_region->buffer.same_as(buffer);
              });
          if (it2 == reader_block->reads.end()) {
            continue;
          }
          if (writer_info.order < reader_info.order &&
              writer_info.stage < reader_info.stage &&
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
        }
      }
      if (!need_multi_version) {
        num_versions--;
      }
    }
    return num_versions;
  }
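  // Worked example (sketch): a buffer written at stage 0 and last read at
  // stage 2 gets use - def + 1 = 3 versions. If no writer/reader pair with
  // order(writer) < order(reader), stage(writer) < stage(reader) and
  // overlapping access regions exists, one version is dropped.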

  /*!
   * \brief Rewrite a buffer allocation to keep multiple versions of the
   * original buffer for pipelined accesses.
   * \param buffer The buffer to be resized.
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
  Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
    if (new_buffer->strides.size()) {
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }
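  // E.g. RewriteAllocBuffer(B[16, 32], 2) yields a buffer of shape
  // [2, 16, 32]; with strides {32, 1}, the new strides become {512, 32, 1}
  // (32 * 16 = 512) so each version occupies a contiguous slice.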

  // Per-stage states that need to be tracked across pipeline prologue, body,
  // and epilogue.
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
    std::unordered_set<const BufferNode *> dst_buffers;
    // An imaginary index that the latest async operation associated with this
    // stage has written into. Only valid if all associated predicates are true,
    // so that we can count the number of async invocations exactly. When it is
    // valid, it is the "sum of extents of loops that have been executed" - 1,
    // e.g. for epilogue it is prologue extent + body extent - 1. This is only
    // needed to compute wait count for epilogue without async producers.
    PrimExpr producer_head;
    std::vector<std::vector<int>> commit_groups;
    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
    bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
  };
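  // E.g. for a pipeline over [0, 16) with max_stage = 2 (prologue extent 2,
  // body extent 14), producer_head starts at -1 and reaches -1 + 2 + 14 = 15
  // before the epilogue is emitted.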

  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
  struct AsyncStateLocal {
    struct PendingWait {
      // The index into a list of blocks, where async_wait_queue should be
      // attached at the beginning.
      int insert_before;
      // in_flight_count would be a more precise name, but the implementation
      // uses wait_count for brevity.
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
    };

    std::vector<PendingWait> pending_waits;

    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
    Optional<PrimExpr> producer_head;
  };

  /*! Structure holding intermediate information for pipeline loop rewriting. */
  struct RewrittenBlockInfo {
    int stage;
    int order;
    PrimExpr predicate;
    Block block;
    PrimExpr access_index;
    bool is_async;
  };

  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
                          std::map<int, AsyncStateLocal> *async_states_local) {
    for (size_t i = 0; i < new_blocks.size(); ++i) {
      int producer_stage_idx = -1;
      for (auto read_region : new_blocks[i].block->reads) {
        for (const auto &[stage, state] : async_states) {
          if (stage <= new_blocks[i].stage &&
              state.writes(read_region->buffer)) {
            // Found an earlier stage where read_region->buffer was
            // asynchronously written
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
                << "A dependency on multiple async stages is not supported";
            producer_stage_idx = stage;
          }
        }
      }
      if (producer_stage_idx == -1)
        continue;
      const auto &state = async_states[producer_stage_idx];
      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
      PrimExpr in_flight_cnt = 0;
      for (const auto &group : state.commit_groups) {
        PrimExpr consumer_head = new_blocks[i].access_index;
        PrimExpr producer_head;
        if (dep_local_state.producer_head.defined()) {
          producer_head = dep_local_state.producer_head.value();
          // If the group is after the wait point, subtract 1.
          if (group.front() > new_blocks[i].order)
            producer_head -= 1;
        } else {
          producer_head = state.producer_head;
        }
        in_flight_cnt += producer_head - consumer_head;
      }

      // We can relax the in-flight count by the number of independent commit
      // groups.
      std::unordered_set<int> dependent_groups;
      for (const auto &read_region : new_blocks[i].block->reads) {
        if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
          dependent_groups.insert(
              state.buffer_to_commit_group_.at(read_region->buffer.get()));
      }
      for (int g = static_cast<int>(state.commit_groups.size()) - 1; g >= 0;
           g--) {
        if (dependent_groups.count(g) == 0)
          in_flight_cnt += 1;
        else
          break; // stop relaxing
      }
      in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
      dep_local_state.pending_waits.push_back(
          {static_cast<int>(i), in_flight_cnt});
    }
  }
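  // Worked example (sketch): with a single commit group where the producer
  // runs one iteration ahead, producer_head = i + 1 and consumer_head = i,
  // so in_flight_cnt simplifies to 1: the consumer waits until at most one
  // async commit is still in flight.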

  // Given pipelined blocks and async-related information, generate final loop
  // statements with async scopes (if any).
  Array<Stmt> CompletePipelineLoopStatements(
      const std::vector<RewrittenBlockInfo> &blocks,
      const std::map<int, AsyncStateLocal> &async_states_local) const {
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
    for (const auto &[stage_id, state] : async_states_local) {
      for (const auto &pw : state.pending_waits) {
        auto &block = new_blocks[pw.insert_before].block;
        BlockNode *n = block.CopyOnWrite();
        auto zero = make_zero(DataType::Int(32));
        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                    pw.wait_count, n->body));
      }
    }

    // Mark the last statement of each commit group as the commit point.
    std::unordered_set<int> commit_group_indices;
    for (const auto &[stage_id, state] : async_states) {
      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
        commit_group_indices.insert(state.commit_groups[i].back());
      }
    }

    Array<Stmt> stmts;

    for (size_t i = 0; i < new_blocks.size(); i++) {
      Block block = new_blocks[i].block;
      if (commit_group_indices.count(new_blocks[i].order)) {
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope,
                                           new_blocks[i].stage, block->body);
        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
      }
      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
    }

    return stmts;
  }
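  // The emitted steady-state iteration then looks like (sketch):
  //   attr [0] "async_commit_queue_scope" = 0;   // producer block, stage 0
  //     attr [0] "async_scope" = 1;
  //     B[(i + 1) % 2, ...] = ...
  //   attr [0] "async_wait_queue_scope" = 0;     // consumer block
  //   attr [0] "async_wait_inflight_count" = 1;
  //     ... = B[i % 2, ...]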

  /*!
   * \brief Emit the pipeline loop in the given range.
   * \param start The start of the range.
   * \param end The end of the range.
   * \param unroll_loop Whether the loop should be unrolled.
   * \param need_bound_check Whether the emitted blocks need predicates that
   * guard against iterations outside the original loop range.
   * \return The result loop.
   */
  Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop,
                bool need_bound_check) {
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;
    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };

    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
      new_loop_var = start; // use constants as the loop var for unit loops
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
    }

    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;
    PrimExpr normalized_access_index;

    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_.at(block).stage;
      int order = pipeline_info_.at(block).order;
      PrimExpr inbound = Bool(true);
      PrimExpr skewed_loop_var = new_loop_var - stage;
      if (need_bound_check)
        inbound =
            analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
      if (analyzer_.CanProve(!inbound)) {
        continue;
      }
      Block new_block = Downcast<Block>(
          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                               pipeline_loop_, max_stage_ != 1)(block));

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
      normalized_access_index =
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;

      // Adjust the block predicate and the body according to the final loop
      // bound [pipeline_loop_->min, pipeline_loop_->min + extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }
      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
      if (predicate_condition_.defined()) {
        BlockNode *n = new_block.CopyOnWrite();
        n->body = IfThenElse(
            Substitute(predicate_condition_,
                       {{pipeline_loop_->loop_var, normalized_access_index}}),
            n->body);
      }
      if (pipeline_info_[block].async) {
        auto &local_state = async_states_local[stage];
        local_state.producer_head = normalized_access_index;
        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                           1, n->body);
      }

      new_blocks.push_back({stage, order, inbound, new_block,
                            normalized_access_index,
                            pipeline_info_[block].async});
    }

    PopulateWaitCounts(new_blocks, &async_states_local);

    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);

    Stmt new_loop{nullptr};

    if (stmts.empty()) {
      return make_nop();
    }

    if (stmts.size() == 1) {
      new_loop = stmts[0];
    } else {
      new_loop = SeqStmt(stmts);
    }

    if (!is_unit_loop) {
      Map<String, ObjectRef> preserved_annotations;
      for (const auto &kv : pipeline_loop_->annotations) {
        const String &key = kv.first;
        if (kv.first != tir::attr::software_pipeline_stage &&
            kv.first != tir::attr::software_pipeline_order &&
            kv.first != tir::attr::software_pipeline_async_stages) {
          preserved_annotations.Set(key, kv.second);
        }
      }
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
                     std::move(new_loop), NullOpt, preserved_annotations);
    }
    // Update producer heads in the global async states.
    for (const auto &[stage_id, state] : async_states_local) {
      async_states[stage_id].producer_head += extent;
    }

    return BlockRealize({}, Bool(true),
                        MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
  }

  arith::Analyzer analyzer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Array<Buffer> pipeline_allocs_;
  For pipeline_loop_;
  PipelineInfo pipeline_info_;
  PrimExpr predicate_condition_;
  int max_stage_ = -1;
  Map<Buffer, Buffer> buffer_remap_;
  Array<Block> ordered_stmts_;
  std::map<int, AsyncStateGlobal> async_states;
};

/*!
 * \brief Build the dependency graph among an array of blocks.
 * \param[in] blocks The array of blocks.
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination.
 * \param[out] dep_dst2src Optional, a map to store dependency edges from the
 * destination to the source.
 */
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;

  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
        for (const Block &writer : it->second) {
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
          if (dep_dst2src != nullptr) {
            (*dep_dst2src)[block].push_back(writer);
          }
        }
      }
    }
    for (const BufferRegion &write : block->writes) {
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}
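// E.g. for blocks W (writes B) followed by R (reads B), the edge W -> R is
// recorded in dep_src2dst and R -> W in dep_dst2src.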

class PipelineInjector : private StmtExprMutator {
public:
  static Stmt Inject(const PrimFunc &func) {
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }

private:
  explicit PipelineInjector(Optional<String> global_symbol)
      : global_symbol_(global_symbol) {}

  /*!
   * \brief Check that the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
   * 2. Reordering of statements doesn't break buffer access dependencies.
   * Specifically, for a dependency (e.g. read-after-write) from statement A
   * to statement B, it requires:
   *   case 1: stage(A) < stage(B), or
   *   case 2: stage(A) == stage(B) and order(A) < order(B)
   */
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
    std::unordered_set<int> used_orders;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
      used_orders.insert(order);
    }

    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
            << " cannot depend on statement " << src << " in a later stage "
            << src_info.stage;
        if (src_info.stage == dst_info.stage) {
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
        }
      }
    }
  }
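  // Example (sketch): if block A (order 0) writes a buffer read by block C
  // (order 1), stage = [0, 1] passes validation; equal stages with
  // order = [1, 0] would be rejected because the RAW dependency is reordered.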

  Stmt VisitStmt_(const ForNode *op) final {
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
      return std::move(for_node);
    }
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be a direct child of the for-loop. If the for-loop has BlockRealize
    // as its child, the pipeline body will be the child of the block.
    Stmt pipeline_body{nullptr};
    PrimExpr predicate_condition{nullptr};
    Array<Buffer> pipeline_allocs;
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
      if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
        ICHECK(!if_then_else->else_case.defined())
            << "InjectSoftwarePipeline: the condition guarding the pipeline "
               "body must not have an else branch";
        pipeline_body = if_then_else->then_case;
        predicate_condition = if_then_else->condition;
      } else {
        pipeline_body = block->body;
      }
      pipeline_allocs = block->alloc_buffers;
    } else {
      pipeline_body = for_node->body;
    }

    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << pipeline_body->GetTypeKey();

    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
    PipelineInfo pipeline_info;
    Array<Block> original_order; // pipeline body blocks in the original order

    auto f_add_child = [&](const Stmt &child) {
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
      const auto *nested_block_realize =
          pipeline_body_seq->seq[i].as<BlockRealizeNode>();
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
        const auto *nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
        for (size_t j = 0; j < nested_seq->seq.size(); j++) {
          f_add_child(nested_seq->seq[j]);
        }
      } else {
        f_add_child(pipeline_body_seq->seq[i]);
      }
    }

    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";

    std::unordered_set<int> pipeline_async_stages;
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
      for (auto s : Downcast<Array<Integer>>(annot)) {
        pipeline_async_stages.insert(s->value);
      }
    }

    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{
          stage,
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
      pipeline_info.emplace(original_order[i], stage_order);
    }

    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
    Stmt pipeline =
        PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
                         GetRef<For>(op), pipeline_info, predicate_condition)
            .BuildPipeline();

    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }

  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }

  bool HasPipelineAnnotation(const ForNode *op) const {
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
    bool has_order = it2 != op->annotations.end();
    if (has_stage && has_order) {
      return true;
    }
    if (has_stage) {
      LOG(FATAL)
          << "ValueError: Order of the software pipeline is not defined.";
    }
    if (has_order) {
      LOG(FATAL)
          << "ValueError: Stage of the software pipeline is not defined.";
    }
    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Optional<String> global_symbol_;
};

/*!
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 * \return The IR transform pass.
 */
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto *fptr = f.CopyOnWrite();
    fptr->body = PipelineInjector::Inject(f);
    fptr->body = ConvertSSA(std::move(fptr->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
}

TVM_REGISTER_GLOBAL("tl.transform.InjectSoftwarePipeline")
    .set_body_typed(InjectSoftwarePipeline);

} // namespace tl
} // namespace tvm