"docs/source/feature_engineering/index.rst" did not exist on "f7cf3ea5d9cfb4a0e87096ff81bc5c9ea595b96b"
inject_pipeline.cc 42.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_pipeline.cc
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 */
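
/*
 * An illustrative sketch of the transformation (not the exact TIR emitted by
 * this pass). Given a loop annotated with software_pipeline_stage = [0, 1]
 * and software_pipeline_order = [0, 1]:
 *
 *   for i in range(n):
 *     B[0] = A[i] * 2      # stage 0: producer
 *     C[i] = B[0] + 1      # stage 1: consumer
 *
 * the pass skews the stages against each other and double-buffers B so that
 * the consumer of iteration i can run in parallel with the producer of
 * iteration i + 1:
 *
 *   B[0, 0] = A[0] * 2                  # prologue: warm up stage 0
 *   for i in range(n - 1):              # body: all stages overlap
 *     B[(i + 1) % 2, 0] = A[i + 1] * 2
 *     C[i] = B[i % 2, 0] + 1
 *   C[n - 1] = B[(n - 1) % 2, 0] + 1    # epilogue: drain stage 1
 */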
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

#include <functional>
#include <unordered_set>
#include <utility>

#include "support/utils.h"
#include "tir/schedule/utils.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
namespace software_pipeline {

struct LetWrapper {
  Var var;
  PrimExpr value;
};

/*!
 * \brief Create a block and infer the access region with the given body.
 *
 * The result is an opaque block that doesn't contain any block iter vars. In
 * case the body is a block realize without a predicate, it is unnecessary to
 * create a new block; the block of the block realize is returned instead.
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
Block MakeBlock(const Stmt &body,
                const Map<Var, Buffer> &buffer_data_to_buffer) {
  if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ body);
  Array<Array<BufferRegion>> access =
      GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode *n = block.CopyOnWrite();
  n->reads = access[0];
  n->writes = access[1];
  return block;
}

/*! Structure that represents the provided annotation per block or loop. */
struct PipelineAnnotation {
  int stage;
  int order;
  bool async;
};

using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
                                        ObjectPtrHash, ObjectPtrEqual>;

struct BufferAccessInfo {
  int def = -1; // the defining stage of the buffer
  int use = -1; // the last using stage of the buffer
};

/*!
 * \brief Rewriter for the body of the software pipeline. This pass inserts
 * `floormod` into the indices of each remapped buffer to select the version
 * corresponding to the current pipeline stage.
 */
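// For example (illustrative, assuming a buffer B remapped to keep two
// versions): inside the pipelined loop over `i`, a store B[j] = ... becomes
//   B[floormod(i - min, 2), j] = ...
// and a load of B[j] becomes B[floormod(i - min, 2), j].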
class PipelineBodyRewriter : public StmtExprMutator {
public:
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
   * \param buffer_remap The map from original buffer to the buffer with
   * updated shape for multi-versioning in the software pipeline.
   * \param pipeline_loop The original loop to be software pipelined.
   * \param access_all_versions Whether all versions of the buffers in the
   * software pipeline are accessed. This will be used to update the block
   * access region. In the prologue and epilogue of a two-stage software
   * pipeline, only one version of these buffers is accessed.
   */
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap,
                       For pipeline_loop, bool access_all_versions)
      : buffer_data_to_buffer_(buffer_data_to_buffer),
        buffer_remap_(buffer_remap), pipeline_loop_(std::move(pipeline_loop)),
        access_all_versions_(access_all_versions) {}

private:
  BufferRegion
  RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension to
      // full extent if access_all_versions == true
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

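  /*!
   * \brief Rewrite the offset argument of an opaque buffer access (e.g.
   * `tvm_access_ptr`) so that it addresses the version of the buffer selected
   * by the current pipeline iteration. For example (illustrative): an access
   * at `offset` into a buffer remapped to `k` versions becomes an access at
   * `offset + floormod(loop_var, k) * size_of_one_version`.
   */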
  PrimExpr RewriteBufferAccess(const Call &call,
                               const std::vector<int> &arg_indices) {
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
          [](PrimExpr a, PrimExpr b, Span span) {
            return mul(std::move(a), std::move(b), std::move(span));
          },
          make_const(DataType::Int(32), 1), input);
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      const Buffer &buffer =
          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
            old_index +
            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
      return RewritePipelineBufferRegion(buffer_region);
    });
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
      return RewritePipelineBufferRegion(buffer_region);
    });
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
    return block;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
      return store;
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
    n->buffer = new_buffer;
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
    return store;
  }

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
      return load;
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
    n->buffer = new_buffer;
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
    return load;
  }

  PrimExpr VisitExpr_(const CallNode *op) final {
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
};

/*!
 * \brief Rewriter for the software pipeline that rewrites a loop into a
 * pipelined one.
 */
class PipelineRewriter : public StmtExprMutator {
public:
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
                   const For &pipeline_loop, const PipelineInfo &pipeline_info,
                   const std::vector<LetWrapper> &loop_var_let_wrappers)
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
        pipeline_info_(pipeline_info),
        loop_var_let_wrappers_(loop_var_let_wrappers) {}

  Stmt BuildPipeline() {
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
    // number of versions that need to be maintained for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos = GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
      int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
      }
    }
    ordered_stmts_.resize(pipeline_info_.size());
    for (const auto &[block, anno] : pipeline_info_) {
      ordered_stmts_.Set(anno.order, block);
    }

    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        state.producer_head = pipeline_loop_->min - 1;
        for (auto write_region : block->writes) {
          auto buffer = write_region->buffer;
          state.dst_buffers.insert(buffer.get());
          if (buffer_remap_.count(buffer))
            state.dst_buffers.insert(buffer_remap_[buffer].get());
        }
      }
    }
    std::unordered_set<int> consumed;
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        if (state.commit_groups.empty() || consumed.count(stage)) {
          state.commit_groups.push_back({});
        }
        state.commit_groups.back().push_back(pipeline_info_[block].order);
        consumed.erase(stage);
        for (auto write_region : block->writes) {
          auto buffer = buffer_remap_.count(write_region->buffer)
                            ? buffer_remap_[write_region->buffer]
                            : write_region->buffer;
          state.buffer_to_commit_group_[buffer.get()] =
              state.commit_groups.size() - 1;
        }
      }

      for (auto read_region : block->reads) {
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          if (producer_stage_id <= stage &&
              producer_state.writes(read_region->buffer)) {
            consumed.insert(producer_stage_id);
          }
        }
      }
    }

    // Step 2: Emit the pipeline prologue, body and epilogue.
    Stmt prologue = EmitImpl(pipeline_loop_->min,
                             pipeline_loop_->min + max_stage_, true, true);
    Stmt body =
        EmitImpl(pipeline_loop_->min + max_stage_,
                 pipeline_loop_->min + pipeline_loop_->extent, false, false);
    Stmt epilogue = EmitImpl(
        pipeline_loop_->min + pipeline_loop_->extent,
        pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);

    SeqStmt stmt = SeqStmt({prologue, body, epilogue});

    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
    Array<Buffer> alloc_buffers;
    for (const auto &alloc : pipeline_allocs_) {
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    return BlockRealize({}, Bool(true), block);
  }

private:
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
   * This method records the 'define' and 'use' stage of each buffer in the
   * software pipeline, which is later used to compute the number of versions
   * needed after rewriting.
   */
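  // For example (illustrative): a buffer written by a block in stage 0 and
  // last read by a block in stage 2 gets {def = 0, use = 2}, so up to
  // use - def + 1 = 3 versions of it may be live at the same time.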
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

      for (const BufferRegion &write : block->writes) {
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
        auto &info = infos.at(write->buffer);
        if (info.def == -1) {
          info.def = stage;
        } else {
          info.def = std::min(info.def, stage);
        }
      }

      for (const BufferRegion &read : block->reads) {
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
        auto &info = infos.at(read->buffer);
        info.use = std::max(info.use, stage);
      }
    }
    return infos;
  }

  /*!
   * \brief Check whether two regions have intersections.
   * \param region1 The first region.
   * \param region2 The second region.
   * \return Whether region1 and region2 have intersections.
   */
  bool MayConflict(const Region &region1, const Region &region2) {
    ICHECK(region1.size() == region2.size());
    for (size_t i = 0; i < region1.size(); i++) {
      Range dim1 = region1[i];
      Range dim2 = region2[i];
      auto int_set1 = arith::IntSet::FromRange(dim1);
      auto int_set2 = arith::IntSet::FromRange(dim2);
      if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
        return false;
      }
    }
    return true;
  }

  /*!
   * \brief Compute the number of versions that need to be maintained for a
   * buffer accessed in the software pipeline.
   *
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions needed during the software pipeline. The annotation
   * `attr::double_buffer_scope` is handled here, which provides a way to
   * override the result of the analysis. Additional double buffering in the
   * software pipeline can be useful to eliminate synchronizations on GPU
   * devices.
   *
   * \param buffer The target buffer.
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
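  // For example (illustrative): a buffer with {def = 0, use = 1} normally
  // gets num_versions = 2 (double buffering); if no overlapping writer/reader
  // pair with increasing order and stage exists, one version suffices.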
  int ComputeBufferVersions(const Buffer &buffer,
                            const BufferAccessInfo &buffer_info) {
    if (buffer_info.def == -1) {
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
      return 1;
    }

    // `use - def + 1` is an upper bound on the number of versions needed.
    // We optimize a few cases where the number of versions can be smaller
    // than the upper bound.
    int num_versions = buffer_info.use - buffer_info.def + 1;
    if (num_versions >= 2) {
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when there exists a writer block_i and a reader
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
      bool need_multi_version = false;
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;

        auto it1 = std::find_if(writer_block->writes.begin(),
                                writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(
              reader_block->reads.begin(), reader_block->reads.end(),
              [&](const BufferRegion &buffer_region) {
                return buffer_region->buffer.same_as(buffer);
              });
          if (it2 == reader_block->reads.end()) {
            continue;
          }
          if (writer_info.order < reader_info.order &&
              writer_info.stage < reader_info.stage &&
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
        }
      }
      if (!need_multi_version) {
        num_versions--;
      }
    }
    return num_versions;
  }

  /*!
   * \brief Rewrite a buffer allocation to keep multiple versions of the
   * original buffer for pipelined accesses.
   * \param buffer The buffer to be resized.
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
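  // For example (illustrative): a buffer of shape [128, 64] rewritten with
  // num_versions = 2 gets shape [2, 128, 64]; if explicit strides are present,
  // a leading stride covering one full version is prepended as well.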
  Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
    ObjectPtr<BufferNode> new_buffer =
        tvm::ffi::make_object<BufferNode>(*(buffer.get()));
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
    if (!new_buffer->strides.empty()) {
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

  // Per-stage states that need to be tracked across pipeline prologue, body,
  // and epilogue.
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
    std::unordered_set<const BufferNode *> dst_buffers;
    // An imaginary index that the latest async operation associated with this
    // stage has written into. Only valid if all associated predicates are true,
    // so that we can count the number of async invocations exactly. When it is
    // valid, it is the "sum of extents of loops that have been executed" - 1,
    // e.g. for epilogue it is prologue extent + body extent - 1. This is only
    // needed to compute wait count for epilogue without async producers.
    PrimExpr producer_head;
    std::vector<std::vector<int>> commit_groups;
    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
    bool writes(const Buffer &buf) const {
      return dst_buffers.count(buf.get()) > 0;
    }
  };

  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
  struct AsyncStateLocal {
    struct PendingWait {
      // The index into a list of blocks, where async_wait_queue should be
      // attached at the beginning.
      int insert_before;
      // in_flight_count would be a more precise name, but the implementation
      // uses wait_count for brevity.
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
    };

    std::vector<PendingWait> pending_waits;

    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
    Optional<PrimExpr> producer_head;
  };

  /*! Structure holding intermediate information for pipeline loop rewriting. */
  struct RewrittenBlockInfo {
    int stage;
    int order;
    PrimExpr predicate;
    Block block;
    PrimExpr access_index;
    bool is_async;
  };

  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
                          std::map<int, AsyncStateLocal> *async_states_local) {
    for (size_t i = 0; i < new_blocks.size(); ++i) {
      int producer_stage_idx = -1;
      for (auto read_region : new_blocks[i].block->reads) {
        for (const auto &[stage, state] : async_states) {
          if (stage <= new_blocks[i].stage &&
              state.writes(read_region->buffer)) {
            // Found an earlier stage where read_region->buffer was
            // asynchronously written
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
                << "A dependency on multiple async stages is not supported";
            producer_stage_idx = stage;
          }
        }
      }
      if (producer_stage_idx == -1)
        continue;
      const auto &state = async_states[producer_stage_idx];
      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
      PrimExpr in_flight_cnt = 0;
      for (const auto &group : state.commit_groups) {
        PrimExpr consumer_head = new_blocks[i].access_index;
        PrimExpr producer_head;
        if (dep_local_state.producer_head.defined()) {
          producer_head = dep_local_state.producer_head.value();
          // If the group is after the wait point, subtract 1.
          if (group.front() > new_blocks[i].order)
            producer_head -= 1;
        } else {
          producer_head = state.producer_head;
        }
        in_flight_cnt += producer_head - consumer_head;
      }

      // We can relax the in-flight count by the number of trailing commit
      // groups that this consumer does not depend on.
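      // For example (illustrative): with commit groups [g0, g1, g2] where the
      // consumer only reads buffers committed by g0, the commits of g1 and g2
      // may stay in flight, so in_flight_cnt is increased by 2.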
      std::unordered_set<int> dependent_groups;
      for (const auto &read_region : new_blocks[i].block->reads) {
        if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
          dependent_groups.insert(
              state.buffer_to_commit_group_.at(read_region->buffer.get()));
      }
      for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
        if (dependent_groups.count(i) == 0)
          in_flight_cnt += 1;
        else
          break; // stop relaxing
      }
      in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
      dep_local_state.pending_waits.push_back(
          {static_cast<int>(i), in_flight_cnt});
    }
  }

  // Given pipelined blocks and async-related information, generate final loop
  // statements with async scopes (if any).
  Array<Stmt> CompletePipelineLoopStatements(
      const std::vector<RewrittenBlockInfo> &blocks,
      const std::map<int, AsyncStateLocal> &async_states_local) const {
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
    for (const auto &[stage_id, state] : async_states_local) {
      for (const auto &pw : state.pending_waits) {
        auto &block = new_blocks[pw.insert_before].block;
        BlockNode *n = block.CopyOnWrite();
        auto zero = make_zero(DataType::Int(32));
        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                    pw.wait_count, n->body));
      }
    }

    // Mark the last statement of each commit group as the commit point.
    std::unordered_set<int> commit_group_indices;
    for (const auto &[stage_id, state] : async_states) {
      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
        commit_group_indices.insert(state.commit_groups[i].back());
      }
    }

    Array<Stmt> stmts;

    for (size_t i = 0; i < new_blocks.size(); i++) {
      Block block = new_blocks[i].block;
      if (commit_group_indices.count(new_blocks[i].order)) {
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope,
                                           new_blocks[i].stage, block->body);
        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
      }
      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
    }

    return stmts;
  }

  /*!
   * \brief Emit the pipeline loop in the given range.
   * \param start The start of the range.
   * \param end The end of the range.
   * \param unroll_loop Whether the loop should be unrolled.
   * \param need_bound_check Whether each block needs a predicate that keeps
   * its skewed iteration inside the original loop bounds.
   * \return The result loop.
   */
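  // BuildPipeline (Step 2 above) calls this three times: the prologue covers
  // [min, min + max_stage), the body covers [min + max_stage, min + extent),
  // and the epilogue covers [min + extent, min + extent + max_stage); the
  // prologue and epilogue are unrolled and bound-checked.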
  Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop,
                bool need_bound_check) {
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;
    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };

    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
      new_loop_var = start; // for unit loops, use the constant start value as the loop var
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
    }

    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;

    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_.at(block).stage;
      int order = pipeline_info_.at(block).order;
      PrimExpr inbound = Bool(true);
      PrimExpr skewed_loop_var = new_loop_var - stage;
      if (need_bound_check)
        inbound =
            analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
      if (analyzer_.CanProve(!inbound)) {
        continue;
      }
      Block new_block = Downcast<Block>(
          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                               pipeline_loop_, max_stage_ != 1)(block));

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
      PrimExpr normalized_access_index =
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;

      // Adjust the block predicate and the body according to the final loop
      // bound [pipeline_loop_->min, extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }
      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));

      // If there were Let-wrappers outside the original pipeline body that
      // depended on the pipeline loop var, push them into each rewritten
      // block with the correct per-block substitution.
      if (!loop_var_let_wrappers_.empty()) {
        BlockNode *n = new_block.CopyOnWrite();
        Stmt inner = n->body;
        for (const auto &lw : loop_var_let_wrappers_) {
          PrimExpr substituted = Substitute(
              lw.value, {{pipeline_loop_->loop_var, normalized_access_index}});
          inner = LetStmt(lw.var, substituted, inner);
        }
        n->body = inner;
      }

      if (pipeline_info_[block].async) {
        auto &local_state = async_states_local[stage];
        local_state.producer_head = normalized_access_index;
        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                           1, n->body);
      }

      new_blocks.push_back({stage, order, inbound, new_block,
                            normalized_access_index,
                            pipeline_info_[block].async});
    }

    PopulateWaitCounts(new_blocks, &async_states_local);

    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);

    Stmt new_loop{nullptr};

    if (stmts.empty()) {
      return make_nop();
    }

    if (stmts.size() == 1) {
      new_loop = stmts[0];
    } else {
      new_loop = SeqStmt(stmts);
    }

    if (!is_unit_loop) {
      Map<String, Any> preserved_annotations;
      for (const auto &kv : pipeline_loop_->annotations) {
        const String &key = kv.first;
        if (kv.first != tir::attr::software_pipeline_stage &&
            kv.first != tir::attr::software_pipeline_order &&
            kv.first != tir::attr::software_pipeline_async_stages) {
          preserved_annotations.Set(key, kv.second);
        }
      }
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
                     std::move(new_loop), std::nullopt, preserved_annotations);
    }
    // Update producer heads in the global async states.
    for (const auto &[stage_id, state] : async_states_local) {
      async_states[stage_id].producer_head += extent;
    }

    return BlockRealize({}, Bool(true),
                        MakeBlock(new_loop, buffer_data_to_buffer_));
  }

  arith::Analyzer analyzer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Array<Buffer> pipeline_allocs_;
  For pipeline_loop_;
  PipelineInfo pipeline_info_;
  int max_stage_ = -1;
  Map<Buffer, Buffer> buffer_remap_;
  Array<Block> ordered_stmts_;
  std::map<int, AsyncStateGlobal> async_states;
  std::vector<LetWrapper> loop_var_let_wrappers_;
};

/*!
 * \brief Build the dependency graph among an array of blocks.
 * \param[in] blocks The array of blocks.
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination.
 * \param[out] dep_dst2src Optional, a map to store dependency edges from the
 * destination to the source.
 */
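// For example (illustrative): given blocks [W, R] where W writes buffer B and
// R reads B afterwards, this records the edge W -> R in dep_src2dst and
// R -> W in dep_dst2src; only read-after-write dependencies are tracked.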
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;

  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
        for (const Block &writer : it->second) {
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
          if (dep_dst2src != nullptr) {
            (*dep_dst2src)[block].push_back(writer);
          }
        }
      }
    }
    for (const BufferRegion &write : block->writes) {
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
public:
  static Stmt Inject(const PrimFunc &func) {
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }

private:
  explicit PipelineInjector(Optional<String> global_symbol)
      : global_symbol_(std::move(global_symbol)) {}

  /*!
   * \brief Check that the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
   * 2. Reordering of statements doesn't break buffer access dependencies.
   *    Specifically, for each dependency (e.g. read-after-write) from
   *    statement A to statement B, it requires either
   *      case 1: stage(A) < stage(B), or
   *      case 2: stage(A) == stage(B) and order(A) < order(B).
   */
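  // For example (illustrative): with a producer P and a consumer C where C
  // reads what P writes, stage/order = {P: (0, 1), C: (1, 0)} is accepted
  // (case 1), while {P: (0, 1), C: (0, 0)} is rejected because C is reordered
  // before P within the same stage.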
824
825
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
    std::unordered_map<int, const Block *> order_to_block;
    std::unordered_map<const Block *, int> block_to_stage;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
      used_orders.insert(order);
    }

    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
            << " cannot depend on statement " << src << " in a later stage "
            << src_info.stage;
        if (src_info.stage == dst_info.stage) {
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
        }
      }
    }
  }

  Stmt VisitStmt_(const ForNode *op) final {
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
      return for_node;
    }
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be a direct child of the for-loop. If the for-loop's child is a
    // BlockRealize, the pipeline body is the body of that block.
    Stmt pipeline_body_root{nullptr};
    bool pipeline_body_from_block = false;
    Array<Buffer> pipeline_allocs;
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
      pipeline_body_root = block->body;
      pipeline_allocs = block->alloc_buffers;
      pipeline_body_from_block = true;
    } else {
      pipeline_body_root = for_node->body;
    }

    const SeqStmtNode *pipeline_body_seq = nullptr;
    std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
    std::vector<LetWrapper> loop_var_let_wrappers;
    auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
      Any node = attr->node;
      String attr_key = attr->attr_key;
      PrimExpr value = attr->value;
      Span span = attr->span;
      rewrap_fns.emplace_back(
          [node = std::move(node), attr_key = std::move(attr_key),
           value = std::move(value), span](Stmt body) -> Stmt {
            return AttrStmt(node, attr_key, value, body, span);
          });
    };
    {
      Stmt current = pipeline_body_root;
      while (true) {
        if (const auto *seq_stmt = current.as<SeqStmtNode>()) {
          pipeline_body_seq = seq_stmt;
          break;
        }
        if (const auto *if_then_else = current.as<IfThenElseNode>()) {
          ICHECK(!if_then_else->else_case.defined())
              << "InjectSoftwarePipeline: Can't handle the body of the loop "
                 "because the IfThenElse node has an else branch";
          PrimExpr condition = if_then_else->condition;
          Span span = if_then_else->span;
          rewrap_fns.emplace_back(
              [condition = std::move(condition), span](Stmt body) -> Stmt {
                return IfThenElse(condition, body, Stmt(), span);
              });
          current = if_then_else->then_case;
          continue;
        }
        if (const auto *let_stmt = current.as<LetStmtNode>()) {
          // If this Let value uses the pipeline loop var, record it and push
          // inside each rewritten block later so the loop var can be
          // substituted with the correct per-iteration index. Otherwise, keep
          // it as a normal wrapper.
          bool uses_loop_var = UsesVar(
              let_stmt->value,
              [v = op->loop_var.get()](const VarNode *vn) { return vn == v; });
          if (uses_loop_var) {
            loop_var_let_wrappers.push_back({let_stmt->var, let_stmt->value});
          } else {
            Var var = let_stmt->var;
            PrimExpr value = let_stmt->value;
            Span span = let_stmt->span;
            rewrap_fns.emplace_back([var = std::move(var),
                                     value = std::move(value),
                                     span](Stmt body) -> Stmt {
              return LetStmt(var, value, body, span);
            });
          }
          current = let_stmt->body;
          continue;
        }
        if (const auto *attr = current.as<AttrStmtNode>()) {
          append_attr_wrapper(attr);
          current = attr->body;
          continue;
        }
        LOG(FATAL) << "ValueError: The body of the software pipeline should be "
                   << "SeqStmt, got " << current->GetTypeKey();
      }
    }
    ICHECK(pipeline_body_seq != nullptr);

    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
    PipelineInfo pipeline_info;
    Array<Block> original_order; // pipeline body blocks in the original order

    auto f_add_child = [&](const Stmt &child) {
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
      const Stmt &child = pipeline_body_seq->seq[i];
      const auto *nested_block_realize = child.as<BlockRealizeNode>();
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
      }
      f_add_child(child);
    }

    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";

    std::unordered_set<int> pipeline_async_stages;
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
      for (auto s : Downcast<Array<Integer>>(annot.value())) {
        pipeline_async_stages.insert(s->value);
      }
    }

    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{
          stage,
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
      pipeline_info.emplace(original_order[i], stage_order);
    }

    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
    Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
                                     tvm::ffi::GetRef<For>(op), pipeline_info,
                                     loop_var_let_wrappers)
                        .BuildPipeline();
    auto apply_wrappers = [&](Stmt stmt) {
      for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
        stmt = (*it)(stmt);
      }
      return stmt;
    };
    if (!rewrap_fns.empty()) {
      if (pipeline_body_from_block) {
        BlockRealize pipeline_realize = Downcast<BlockRealize>(pipeline);
        Block pipeline_block = pipeline_realize->block;
        {
          BlockNode *block_node = pipeline_block.CopyOnWrite();
          block_node->body = apply_wrappers(block_node->body);
        }
        pipeline = BlockRealize(pipeline_realize->iter_values,
                                pipeline_realize->predicate, pipeline_block,
                                pipeline_realize->span);
      } else {
        pipeline = apply_wrappers(pipeline);
      }
    }

    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }

  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    BlockNode *n = block.CopyOnWrite();
    n->reads = access[0];
    n->writes = access[1];

    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return block;
  }

  bool HasPipelineAnnotation(const ForNode *op) const {
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
    bool has_order = it2 != op->annotations.end();
    if (has_stage && has_order) {
      return true;
    }
    if (has_stage) {
      LOG(FATAL)
          << "ValueError: Order of the software pipeline is not defined.";
    }
    if (has_order) {
      LOG(FATAL)
          << "ValueError: Stage of the software pipeline is not defined.";
    }
    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Optional<String> global_symbol_;
};
} // namespace software_pipeline

/*!
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 * \return The IR transform pass.
 */
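// A minimal usage sketch (assuming `mod` is an IRModule whose loops carry
// software_pipeline_stage / software_pipeline_order annotations):
//
//   tvm::IRModule mod = ...;
//   mod = tvm::tl::InjectSoftwarePipeline()(mod);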
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    auto *fptr = f.CopyOnWrite();
    fptr->body = software_pipeline::PipelineInjector::Inject(f);
    fptr->body = ConvertSSA(std::move(fptr->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
                        InjectSoftwarePipeline);
}

} // namespace tl
} // namespace tvm