inject_pipeline.cc 37.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_pipeline.cc
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 */
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

#include <unordered_set>

#include "support/utils.h"
#include "tir/schedule/utils.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
using namespace tir;

/*!
 * \brief Create a block and infer the access region with the given body.
 *
42
43
44
 * The result is a opaque block that doesn't contain any block iter vars. In
 * case the body is a block realize without predicate, it is unnecessary to
 * create a new block, the block of the block realize will be returned.
45
46
47
48
49
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
50
51
52
Block MakeBlock(const Stmt &body,
                const Map<Var, Buffer> &buffer_data_to_buffer) {
  if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
53
54
55
56
57
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
58
59
60
61
62
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ body);
  Array<Array<BufferRegion>> access =
      GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode *n = block.CopyOnWrite();
63
64
65
66
67
68
69
70
71
72
73
74
  n->reads = access[0];
  n->writes = access[1];
  return block;
}

/*! Structure that represents the provided annotation per block or loop. */
struct PipelineAnnotation {
  // Pipeline stage the block is assigned to; the rewriter skews the loop
  // variable of the block by this amount.
  int stage;
  // Position of the block in the reordered statement list; used as an index
  // when the rewriter fills its ordered statement array.
  int order;
  // Whether the block is treated as an asynchronous producer.
  bool async;
};

75
76
// Maps every block of the pipelined loop body to its pipeline annotation.
// Keys are hashed and compared by object pointer identity.
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
                                        ObjectPtrHash, ObjectPtrEqual>;
77
78

/*!
 * \brief Stage interval over which a buffer is accessed inside the pipeline.
 *
 * Both fields keep the sentinel -1 until a matching access has been seen.
 */
struct BufferAccessInfo {
  /*! \brief The defining stage of the buffer (earliest stage writing it). */
  int def{-1};
  /*! \brief The last using stage of the buffer (latest stage reading it). */
  int use{-1};
};

/*!
84
85
86
 * \brief Rewriter for the body of the software pipeline. This pass inserts
 * `floormod` to indices of the remapped buffer to select the version
 * corresponding to the pipeline stage.
87
88
 */
class PipelineBodyRewriter : public StmtExprMutator {
89
public:
90
91
92
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
93
94
95
96
97
98
99
   * \param buffer_remap The map from original buffer to the buffer with updated
   * shape for multi-versioning in the software pipeline. \param pipeline_loop
   * The original loop to be software pipelined. \param access_all_versions
   * Whether all versions the buffers in the software pipeline are accessed.
   * This will be used to update block access region. In the prologue and
   * epilogue of a two-stage software pipeline, only one version of these
   * buffers are accessed.
100
   */
101
102
103
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap,
                       For pipeline_loop, bool access_all_versions)
104
      : buffer_data_to_buffer_(buffer_data_to_buffer),
105
        buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop),
106
107
        access_all_versions_(access_all_versions) {}

108
109
110
private:
  BufferRegion
  RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
111
112
113
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
114
115
116
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension to
      // full extent if access_all_versions == true
117
118
119
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
120
121
122
123
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
124
125
126
127
128
129
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

130
131
132
133
134
135
  PrimExpr RewriteBufferAccess(const Call &call,
                               const std::vector<int> arg_indices) {
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
          [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
          make_const(DataType::Int(32), 1), input);
136
137
138
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
139
140
      const Buffer &buffer =
          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
141
142
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
143
144
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
145
146
147
148
149
150
151
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
152
153
            old_index +
            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
154
155
156
157
158
159
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

160
161
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
162
163
164
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
165
166
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
167
168
      return RewritePipelineBufferRegion(buffer_region);
    });
169
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
170
171
      return RewritePipelineBufferRegion(buffer_region);
    });
172
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
173
174
175
176
177
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
    return std::move(block);
  }

178
  Stmt VisitStmt_(const BufferStoreNode *op) final {
179
180
181
182
183
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(store);
    }
184
185
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
186
    n->buffer = new_buffer;
187
188
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
189
190
191
192
    n->indices.insert(n->indices.begin(), version);
    return std::move(store);
  }

193
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
194
195
196
197
198
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(load);
    }
199
200
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
201
    n->buffer = new_buffer;
202
203
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
204
205
206
207
    n->indices.insert(n->indices.begin(), version);
    return std::move(load);
  }

208
  PrimExpr VisitExpr_(const CallNode *op) final {
209
210
211
212
213
214
215
216
217
218
219
220
221
222
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
};

/*!
 * \brief Rewriter for the software pipeline that rewrites a loop into a
 * pipelined one.
 */
class PipelineRewriter : public StmtExprMutator {
227
228
229
230
public:
  /*!
   * \brief Constructor of PipelineRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer; taken by
   * value and moved into the rewriter.
   * \param pipeline_allocs The buffers allocated for the pipelined loop body.
   * \param pipeline_loop The original loop to be software pipelined.
   * \param pipeline_info The stage/order/async annotation of every block.
   */
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
                   const For &pipeline_loop, const PipelineInfo &pipeline_info)
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
        pipeline_allocs_(pipeline_allocs),
        pipeline_loop_(pipeline_loop),
        pipeline_info_(pipeline_info) {
  }

  /*!
   * \brief Rewrite the annotated loop into prologue, steady-state body and
   * epilogue, wrapped in a block that owns the (re-)allocated buffers.
   * \return A BlockRealize with a constant-true predicate holding the result.
   */
  Stmt BuildPipeline() {
    // Step 1: Analyze accesses to the buffers in the pipeline and decide how
    // many versions each pipeline-allocated buffer has to keep.
    const auto access_infos = GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
      const int num_versions =
          ComputeBufferVersions(buffer, access_infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
      }
    }

    // Lay the blocks out in their annotated order.
    ordered_stmts_.resize(pipeline_info_.size());
    for (const auto &entry : pipeline_info_) {
      ordered_stmts_.Set(entry.second.order, entry.first);
    }

    // First pass: register every async stage together with the buffers it
    // writes (both the original and the multi-versioned buffer).
    for (const Block &block : ordered_stmts_) {
      const PipelineAnnotation &anno = pipeline_info_[block];
      if (!anno.async) {
        continue;
      }
      AsyncStateGlobal &state = async_states[anno.stage];
      state.producer_head = pipeline_loop_->min - 1;
      for (const BufferRegion &write_region : block->writes) {
        const Buffer &buffer = write_region->buffer;
        state.dst_buffers.insert(buffer.get());
        if (buffer_remap_.count(buffer)) {
          state.dst_buffers.insert(buffer_remap_[buffer].get());
        }
      }
    }

    // Second pass: partition the async blocks of each stage into commit
    // groups. A new group is opened once a later block has consumed what the
    // stage produced so far.
    std::unordered_set<int> consumed;
    for (const Block &block : ordered_stmts_) {
      const PipelineAnnotation &anno = pipeline_info_[block];
      const int stage = anno.stage;
      if (anno.async) {
        AsyncStateGlobal &state = async_states[stage];
        const bool open_new_group =
            state.commit_groups.empty() || consumed.count(stage) > 0;
        if (open_new_group) {
          state.commit_groups.emplace_back();
        }
        state.commit_groups.back().push_back(anno.order);
        consumed.erase(stage);
        for (const BufferRegion &write_region : block->writes) {
          const Buffer target = buffer_remap_.Get(write_region->buffer)
                                    .value_or(write_region->buffer);
          state.buffer_to_commit_group_[target.get()] =
              state.commit_groups.size() - 1;
        }
      }

      for (const BufferRegion &read_region : block->reads) {
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          const bool reads_async_result =
              producer_stage_id <= stage &&
              producer_state.writes(read_region->buffer);
          if (reads_async_result) {
            consumed.insert(producer_stage_id);
          }
        }
      }
    }

    // Step 2: Emit the pipeline prologue, body and epilogue. The calls must
    // stay in this order because each one advances the global producer heads.
    const PrimExpr loop_min = pipeline_loop_->min;
    const PrimExpr loop_end = pipeline_loop_->min + pipeline_loop_->extent;
    Stmt prologue = EmitImpl(loop_min, loop_min + max_stage_, true, true);
    Stmt body = EmitImpl(loop_min + max_stage_, loop_end, false, false);
    Stmt epilogue = EmitImpl(loop_end, loop_end + max_stage_, true, true);

    SeqStmt pipeline = SeqStmt({prologue, body, epilogue});

    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
    Array<Buffer> alloc_buffers;
    for (const Buffer &alloc : pipeline_allocs_) {
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
    Block result = MakeBlock(pipeline, buffer_data_to_buffer_);
    result.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    return BlockRealize({}, Bool(true), result);
  }

319
private:
320
321
322
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
323
324
325
   * This method check the 'define' and 'use' stage of the buffers in the
   * software pipeline, which can be used to compute the number of versions
   * needed to maintain after rewriting.
326
327
328
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
329
330
331
332
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
333
334
335
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

336
      for (const BufferRegion &write : block->writes) {
337
338
339
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
340
        auto &info = infos.at(write->buffer);
341
342
343
344
345
346
347
        if (info.def == -1) {
          info.def = stage;
        } else {
          info.def = std::min(info.def, stage);
        }
      }

348
      for (const BufferRegion &read : block->reads) {
349
350
351
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
352
        auto &info = infos.at(read->buffer);
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
        info.use = std::max(info.use, stage);
      }
    }
    return infos;
  }

  /*!
   * \brief Conservatively test whether two regions may overlap.
   *
   * The regions are treated as disjoint as soon as the integer sets of one
   * dimension have an empty intersection.
   *
   * \param region1 The first region.
   * \param region2 The second region; must have the same rank as region1.
   * \return false if some dimension is provably disjoint, true otherwise.
   */
  bool MayConflict(Region region1, Region region2) {
    ICHECK(region1.size() == region2.size());
    const size_t ndim = region1.size();
    for (size_t dim = 0; dim < ndim; ++dim) {
      const auto lhs = arith::IntSet::FromRange(region1[dim]);
      const auto rhs = arith::IntSet::FromRange(region2[dim]);
      const bool disjoint = arith::Intersect({lhs, rhs}).IsNothing();
      if (disjoint) {
        return false;
      }
    }
    return true;
  }

  /*!
380
381
   * \brief Compute the number of versions need to maintain for buffer accessed
   * in the software pipeline.
382
   *
383
384
385
386
387
388
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions need to maintain during the software pipeline.
   * Annotation `attr::double_buffer_scope` is handled here which provides a way
   * to override the result of the analysis. Additional double buffering in the
   * software pipeline can be useful to eliminate synchronizations in GPU
   * devices.
389
390
391
392
393
   *
   * \param buffer The target buffer
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
394
395
  int ComputeBufferVersions(const Buffer &buffer,
                            const BufferAccessInfo &buffer_info) {
396
    if (buffer_info.def == -1) {
397
398
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
399
400
401
402
      return 1;
    }

    // `use - def + 1` is a upper bound of the needed versions
403
404
    // We optimize a few case where the number of versions can be smaller than
    // the upper bound
405
406
    int num_versions = buffer_info.use - buffer_info.def + 1;
    if (num_versions >= 2) {
407
408
409
410
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when these exists a reader block_i and a writer
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
411
      bool need_multi_version = false;
412
413
414
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;
415

416
417
418
        auto it1 = std::find_if(writer_block->writes.begin(),
                                writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
419
420
421
422
423
424
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

425
426
427
428
429
430
431
432
        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(
              reader_block->reads.begin(), reader_block->reads.end(),
              [&](const BufferRegion &buffer_region) {
                return buffer_region->buffer.same_as(buffer);
              });
433
434
435
          if (it2 == reader_block->reads.end()) {
            continue;
          }
436
437
          if (writer_info.order < reader_info.order &&
              writer_info.stage < reader_info.stage &&
438
439
440
441
442
443
444
445
446
447
448
449
450
451
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
        }
      }
      if (!need_multi_version) {
        num_versions--;
      }
    }
    return num_versions;
  }

  /*!
452
453
   * \brief Rewrite buffer allocation to keep multiple versions of original
   * buffer for pipelined accesses. \param buffer The buffer to be resized.
454
455
456
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
457
  Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
458
459
460
461
462
463
464
465
466
467
    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
    if (new_buffer->strides.size()) {
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

468
469
  // Per-stage states that need to be tracked across pipeline prologue, body,
  // and epilogue.
470
471
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
472
473
474
475
476
477
478
    std::unordered_set<const BufferNode *> dst_buffers;
    // An imaginary index that the latest async operation associated with this
    // stage has written into. Only valid if all associated predicates are true,
    // so that we can count the number of async invocations exactly. When it is
    // valid, it is the "sum of extents of loops that have been executed" - 1,
    // e.g. for epilogue it is prologue extent + body extent - 1. This is only
    // needed to compute wait count for epilogue without async producers.
479
480
    PrimExpr producer_head;
    std::vector<std::vector<int>> commit_groups;
481
    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
482
483
484
    bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
  };

485
486
  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
487
488
  struct AsyncStateLocal {
    struct PendingWait {
489
490
      // The index into a list of blocks, where async_wait_queue should be
      // attached at the beginning.
491
      int insert_before;
492
493
      // in_flight_count would be a more precise name, but the implementation
      // uses wait_count for brevity.
494
495
496
497
498
499
500
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
    };

    std::vector<PendingWait> pending_waits;

501
502
    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
503
504
505
506
507
508
509
510
511
512
513
514
515
    Optional<PrimExpr> producer_head;
  };

  /*!
   * Structure holding intermediate information for pipeline loop rewriting.
   * Instances are aggregate-initialized in field order, so the order of the
   * members must not change.
   */
  struct RewrittenBlockInfo {
    // Pipeline stage of the block, copied from its annotation.
    int stage;
    // Position of the block in the annotated order.
    int order;
    // Predicate guarding the block in the emitted loop.
    PrimExpr predicate;
    // The rewritten block (buffers remapped, loop variable substituted).
    Block block;
    // Normalized index of the original loop iteration this block executes.
    PrimExpr access_index;
    // Whether the block is annotated as asynchronous.
    bool is_async;
  };

516
517
  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
                          std::map<int, AsyncStateLocal> *async_states_local) {
518
519
520
    for (size_t i = 0; i < new_blocks.size(); ++i) {
      int producer_stage_idx = -1;
      for (auto read_region : new_blocks[i].block->reads) {
521
522
523
524
525
        for (const auto &[stage, state] : async_states) {
          if (stage <= new_blocks[i].stage &&
              state.writes(read_region->buffer)) {
            // Found an earlier stage where read_region->buffer was
            // asynchronously written
526
527
528
529
530
531
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
                << "A dependency on multiple async stages is not supported";
            producer_stage_idx = stage;
          }
        }
      }
532
533
534
535
      if (producer_stage_idx == -1)
        continue;
      const auto &state = async_states[producer_stage_idx];
      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
536
      PrimExpr in_flight_cnt = 0;
537
      for (const auto &group : state.commit_groups) {
538
539
540
541
542
        PrimExpr consumer_head = new_blocks[i].access_index;
        PrimExpr producer_head;
        if (dep_local_state.producer_head.defined()) {
          producer_head = dep_local_state.producer_head.value();
          // if the group is after the wait point, minus by 1
543
544
          if (group.front() > new_blocks[i].order)
            producer_head -= 1;
545
546
547
548
549
550
551
552
        } else {
          producer_head = state.producer_head;
        }
        in_flight_cnt += producer_head - consumer_head;
      }

      // We can relax the in-flight-count by the number of independent commit.
      std::unordered_set<int> dependent_groups;
553
      for (const auto &read_region : new_blocks[i].block->reads) {
554
        if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
555
556
          dependent_groups.insert(
              state.buffer_to_commit_group_.at(read_region->buffer.get()));
557
558
559
560
561
      }
      for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
        if (dependent_groups.count(i) == 0)
          in_flight_cnt += 1;
        else
562
          break; // stop relaxing
563
564
      }
      in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
565
566
      dep_local_state.pending_waits.push_back(
          {static_cast<int>(i), in_flight_cnt});
567
568
569
    }
  }

570
571
  // Given pipelined blocks and async-related information, generate final loop
  // statements with async scopes (if any).
572
  Array<Stmt> CompletePipelineLoopStatements(
573
574
      const std::vector<RewrittenBlockInfo> &blocks,
      const std::map<int, AsyncStateLocal> &async_states_local) const {
575
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
576
577
578
579
    for (const auto &[stage_id, state] : async_states_local) {
      for (const auto &pw : state.pending_waits) {
        auto &block = new_blocks[pw.insert_before].block;
        BlockNode *n = block.CopyOnWrite();
580
        auto zero = make_zero(DataType::Int(32));
581
582
583
        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                    pw.wait_count, n->body));
584
585
586
587
588
      }
    }

    // mark the last async stmt as commit
    std::unordered_set<int> commit_group_indices;
589
    for (const auto &[stage_id, state] : async_states) {
590
591
592
593
594
595
596
597
598
599
      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
        commit_group_indices.insert(state.commit_groups[i].back());
      }
    }

    Array<Stmt> stmts;

    for (size_t i = 0; i < new_blocks.size(); i++) {
      Block block = new_blocks[i].block;
      if (commit_group_indices.count(new_blocks[i].order)) {
600
601
602
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope,
                                           new_blocks[i].stage, block->body);
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
      }
      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
    }

    return stmts;
  }

  /*!
   * \brief Emit the pipeline loop in the given range.
   * \param start The start of the range
   * \param end The end of the range
   * \param unroll_loop Whether the loop should be unrolled.
   * \return The result loop.
   */
618
619
  Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop,
                bool need_bound_check) {
620
621
622
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;

623
624
625
    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };
626
627
628

    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
629
      new_loop_var = start; // use constants as the loop var for unit loops
630
631
632
633
634
635
636
637
638
639
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
    }

    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;

640
    for (const Block &block : ordered_stmts_) {
641
642
643
644
645
      int stage = pipeline_info_.at(block).stage;
      int order = pipeline_info_.at(block).order;
      PrimExpr inbound = Bool(true);
      PrimExpr skewed_loop_var = new_loop_var - stage;
      if (need_bound_check)
646
647
648
        inbound =
            analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
649
650
651
      if (analyzer_.CanProve(!inbound)) {
        continue;
      }
652
653
654
      Block new_block = Downcast<Block>(
          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                               pipeline_loop_, max_stage_ != 1)(block));
655
656
657
658

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
659
660
661
662
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
      PrimExpr normalized_access_index =
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
663

664
665
      // Adjust the block predicate and the body according to the final loop
      // bound
666
667
668
669
670
671
      //  [pipeline_loop_->min, extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }

672
673
      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
674
675

      if (pipeline_info_[block].async) {
676
        auto &local_state = async_states_local[stage];
677
        local_state.producer_head = normalized_access_index;
678
679
680
        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                           1, n->body);
681
682
      }

683
684
685
      new_blocks.push_back({stage, order, inbound, new_block,
                            normalized_access_index,
                            pipeline_info_[block].async});
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
    }

    PopulateWaitCounts(new_blocks, &async_states_local);
    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);

    Stmt new_loop{nullptr};

    if (stmts.empty()) {
      return make_nop();
    }
    if (stmts.size() == 1) {
      new_loop = stmts[0];
    } else {
      new_loop = SeqStmt(stmts);
    }

    if (!is_unit_loop) {
      Map<String, ObjectRef> preserved_annotations;
704
705
      for (const auto &kv : pipeline_loop_->annotations) {
        const String &key = kv.first;
706
707
708
709
710
711
712
        if (kv.first != tir::attr::software_pipeline_stage &&
            kv.first != tir::attr::software_pipeline_order &&
            kv.first != tir::attr::software_pipeline_async_stages) {
          preserved_annotations.Set(key, kv.second);
        }
      }
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
713
714
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
                     std::move(new_loop), NullOpt, preserved_annotations);
715
716
717
    }

    // Update producer heads in the global async states.
718
    for (const auto &[stage_id, state] : async_states_local) {
719
720
721
      async_states[stage_id].producer_head += extent;
    }

722
723
    return BlockRealize({}, Bool(true),
                        MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
  }

  // Analyzer used to simplify and prove predicates, extents and wait counts.
  arith::Analyzer analyzer_;
  // Map from buffer data var to buffer; pipeline allocations are erased from
  // it once the pipeline has been built.
  Map<Var, Buffer> buffer_data_to_buffer_;
  // Buffers allocated for the pipelined loop body.
  Array<Buffer> pipeline_allocs_;
  // The original loop being software pipelined.
  For pipeline_loop_;
  // Stage/order/async annotation of every block in the loop body.
  PipelineInfo pipeline_info_;
  // Largest stage among all blocks; -1 until the buffer accesses are analyzed.
  int max_stage_ = -1;
  // Map from an original buffer to its multi-versioned replacement.
  Map<Buffer, Buffer> buffer_remap_;
  // The blocks of the loop body arranged by their annotated order.
  Array<Block> ordered_stmts_;
  // Global async state per stage, keyed by stage index.
  std::map<int, AsyncStateGlobal> async_states;
};

/*!
 * \brief Build the dependency graph among a array of blocks.
 * \param[in] blocks The array of blocks.
740
741
742
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination. \param[out] dep_dst2src Optional, a map to store
 * dependency edges from the destination to the source.
743
 */
744
745
746
747
748
749
750
751
752
753
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;

  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
754
755
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
756
        for (const Block &writer : it->second) {
757
758
759
760
761
762
763
764
765
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
          if (dep_dst2src != nullptr) {
            (*dep_dst2src)[block].push_back(writer);
          }
        }
      }
    }
766
    for (const BufferRegion &write : block->writes) {
767
768
769
770
771
772
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
773
774
public:
  static Stmt Inject(const PrimFunc &func) {
775
776
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
777
778
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
779
780
781
782
783
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }

784
785
786
private:
  explicit PipelineInjector(Optional<String> global_symbol)
      : global_symbol_(global_symbol) {}
787
788
789
790

  /*!
   * \brief Check the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
791
792
793
794
   * 2. Reordering of statements doesn't break buffer access dependencies.
   * Specifically, for dependency (e.g. read-after-write) from statement A to
   * statement B, it requires: case 1: stage(A) < stage(B) case 2: stage(A) ==
   * stage(B) and order(A) < order(B)
795
   */
796
797
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
798
799
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
800
801
802
803
    std::unordered_map<int, const Block *> order_to_block;
    std::unordered_map<const Block *, int> block_to_stage;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
804
805
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
806
807
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
808
809
810
      used_orders.insert(order);
    }

811
812
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
813
814
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

815
816
817
818
819
820
    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
821
822
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
823
824
            << " cannot depends on statement " << src << " in a later stage "
            << src_info.stage;
825
        if (src_info.stage == dst_info.stage) {
826
827
828
829
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
830
831
832
833
834
        }
      }
    }
  }

835
  Stmt VisitStmt_(const ForNode *op) final {
836
837
838
839
840
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
      return std::move(for_node);
    }
841
842
843
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be direct child of the for-loop. If the for-loop has BlockRealize as
    // its child, the pipeline body will be the child of the block.
844
845
    Stmt pipeline_body{nullptr};
    Array<Buffer> pipeline_allocs;
846
847
848
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
849
850
851
852
853
854
855
856
857
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
      pipeline_body = block->body;
      pipeline_allocs = block->alloc_buffers;
    } else {
      pipeline_body = for_node->body;
    }

858
859
860
861
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << pipeline_body->GetTypeKey();
862

863
864
    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
865
    PipelineInfo pipeline_info;
866
    Array<Block> original_order; // pipeline body blocks in the original order
867

868
    auto f_add_child = [&](const Stmt &child) {
869
870
871
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
872
873
      const auto *nested_block_realize =
          pipeline_body_seq->seq[i].as<BlockRealizeNode>();
874
875
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
876
877
878
879
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
880
881
882
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
883
        const auto *nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
884
885
886
887
888
889
890
891
        for (size_t j = 0; j < nested_seq->seq.size(); j++) {
          f_add_child(nested_seq->seq[j]);
        }
      } else {
        f_add_child(pipeline_body_seq->seq[i]);
      }
    }

892
893
894
895
    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
896
897
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
898
899
900
901
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
902
903
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
904
905
906
907
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";
908
909

    std::unordered_set<int> pipeline_async_stages;
910
911
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
912
913
914
915
916
917
918
      for (auto s : Downcast<Array<Integer>>(annot)) {
        pipeline_async_stages.insert(s->value);
      }
    }

    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
919
920
921
922
923
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{
          stage,
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
924
925
926
927
928
929
      pipeline_info.emplace(original_order[i], stage_order);
    }

    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
930
931
932
    Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
                                     GetRef<For>(op), pipeline_info)
                        .BuildPipeline();
933

934
935
936
    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
937
938
939
940
941
942
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }

943
944
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
945
946
947
948
949
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

950
    for (const auto &buffer : op->alloc_buffers) {
951
952
953
954
955
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }

956
  bool HasPipelineAnnotation(const ForNode *op) const {
957
958
959
960
961
962
963
964
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
    bool has_order = it2 != op->annotations.end();
    if (has_stage && has_order) {
      return true;
    }
    if (has_stage) {
965
966
      LOG(FATAL)
          << "ValueError: Order of the software pipeline is not defined.";
967
968
    }
    if (has_order) {
969
970
      LOG(FATAL)
          << "ValueError: Stage of the software pipeline is not defined.";
971
972
973
974
975
976
977
978
979
    }
    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Optional<String> global_symbol_;
};

/*!
 * \brief Create the pass that transforms annotated loops into pipelined ones
 * that parallelize producers and consumers.
 *
 * For each PrimFunc the body is replaced by the result of
 * PipelineInjector::Inject followed by ConvertSSA.
 *
 * \return The IR transform pass.
 */
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto rewrite_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    // Obtain the writable node first, then run the injector on the function.
    auto *node = f.CopyOnWrite();
    Stmt pipelined = PipelineInjector::Inject(f);
    node->body = ConvertSSA(std::move(pipelined));
    return f;
  };
  return CreatePrimFuncPass(rewrite_func, 0, "tl.InjectSoftwarePipeline", {});
}

// Register the pass constructor in the global function registry under the
// name "tl.transform.InjectSoftwarePipeline".
TVM_REGISTER_GLOBAL("tl.transform.InjectSoftwarePipeline")
    .set_body_typed(InjectSoftwarePipeline);

997
998
} // namespace tl
} // namespace tvm