"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "25e8cf0b8d7a33ca1d026ca9b91c525a74e4f62b"
inject_pipeline.cc 43.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_software_pipeline.cc
 * \brief Transform annotated loops into pipelined one that parallelize
 * producers and consumers
 */
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

#include <unordered_set>

#include "support/utils.h"
#include "tir/schedule/utils.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
using namespace tir;

/*!
 * \brief Create a block and infer the access region with the given body.
 *
42
43
44
 * The result is a opaque block that doesn't contain any block iter vars. In
 * case the body is a block realize without predicate, it is unnecessary to
 * create a new block, the block of the block realize will be returned.
45
46
47
48
49
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
50
51
52
Block MakeBlock(const Stmt &body,
                const Map<Var, Buffer> &buffer_data_to_buffer) {
  if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
53
54
55
56
57
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
58
59
60
61
62
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
              /*body*/ body);
  Array<Array<BufferRegion>> access =
      GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode *n = block.CopyOnWrite();
63
64
65
66
67
68
69
70
71
72
73
74
  n->reads = access[0];
  n->writes = access[1];
  return block;
}

/*! Structure that represents the provided annotation per block or loop. */
struct PipelineAnnotation {
  // Pipeline stage this block is assigned to.
  int stage;
  // Position of this block in the statement ordering of the pipelined loop.
  int order;
  // Whether this block executes asynchronously (tracked via async_states).
  bool async;
};

75
76
// Map from each pipelined block to its (stage, order, async) annotation.
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
                                        ObjectPtrHash, ObjectPtrEqual>;
77
78

/*!
 * \brief Liveness record for a buffer accessed inside the software pipeline.
 * A value of -1 means the corresponding access was never observed.
 */
struct BufferAccessInfo {
  int def = -1; // the defining stage of the buffer
  int use = -1; // the last using stage of the buffer
};

83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/*!
 * \brief Unwrap IfThenElse nodes whose condition matches the given one,
 * recursing through AttrStmt and Block wrappers so attributes are preserved.
 * \param body The statement to process.
 * \param condition The condition to match in IfThenElse nodes.
 * \return The transformed statement.
 */
Stmt replace_if_then_else(Stmt body, PrimExpr condition) {
  if (const auto *ite = body.as<IfThenElseNode>()) {
    // A guard with the matching condition is replaced by its then-branch;
    // any other IfThenElse is returned untouched.
    return ite->condition.same_as(condition) ? ite->then_case : body;
  }
  if (const auto *attr = body.as<AttrStmtNode>()) {
    // Keep the attribute node, rewrite only what it wraps.
    AttrStmt rewritten = GetRef<AttrStmt>(attr);
    rewritten.CopyOnWrite()->body = replace_if_then_else(attr->body, condition);
    return rewritten;
  }
  if (const auto *blk = body.as<BlockNode>()) {
    // Descend into the block body.
    Block rewritten = GetRef<Block>(blk);
    rewritten.CopyOnWrite()->body = replace_if_then_else(blk->body, condition);
    return rewritten;
  }
  // Any other node type is left unchanged.
  return body;
}

112
/*!
113
114
115
 * \brief Rewriter for the body of the software pipeline. This pass inserts
 * `floormod` to indices of the remapped buffer to select the version
 * corresponding to the pipeline stage.
116
117
 */
class PipelineBodyRewriter : public StmtExprMutator {
118
public:
119
120
121
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
122
123
124
125
126
127
128
   * \param buffer_remap The map from original buffer to the buffer with updated
   * shape for multi-versioning in the software pipeline. \param pipeline_loop
   * The original loop to be software pipelined. \param access_all_versions
   * Whether all versions the buffers in the software pipeline are accessed.
   * This will be used to update block access region. In the prologue and
   * epilogue of a two-stage software pipeline, only one version of these
   * buffers are accessed.
129
   */
130
131
132
  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
                       const Map<Buffer, Buffer> &buffer_remap,
                       For pipeline_loop, bool access_all_versions)
133
      : buffer_data_to_buffer_(buffer_data_to_buffer),
134
        buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop),
135
136
        access_all_versions_(access_all_versions) {}

137
138
139
private:
  BufferRegion
  RewritePipelineBufferRegion(const BufferRegion &buffer_region) const {
140
141
142
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
143
144
145
      const Buffer &new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension to
      // full extent if access_all_versions == true
146
147
148
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
149
150
151
152
              : Range::FromMinExtent(
                    floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                             new_buffer->shape[0]),
                    Integer(1));
153
154
155
156
157
158
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

159
160
161
162
163
164
  PrimExpr RewriteBufferAccess(const Call &call,
                               const std::vector<int> arg_indices) {
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
          [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
          make_const(DataType::Int(32), 1), input);
165
166
167
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
168
169
      const Buffer &buffer =
          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
170
171
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
172
173
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
174
175
176
177
178
179
180
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
181
182
            old_index +
            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
183
184
185
186
187
188
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

189
190
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
191
192
193
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
194
195
    BlockNode *n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion &buffer_region) {
196
197
      return RewritePipelineBufferRegion(buffer_region);
    });
198
    n->writes.MutateByApply([this](const BufferRegion &buffer_region) {
199
200
      return RewritePipelineBufferRegion(buffer_region);
    });
201
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
202
203
204
205
206
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
    return std::move(block);
  }

207
  Stmt VisitStmt_(const BufferStoreNode *op) final {
208
209
210
211
212
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(store);
    }
213
214
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
215
    n->buffer = new_buffer;
216
217
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
218
219
220
221
    n->indices.insert(n->indices.begin(), version);
    return std::move(store);
  }

222
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
223
224
225
226
227
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(load);
    }
228
229
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
230
    n->buffer = new_buffer;
231
232
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
233
234
235
236
    n->indices.insert(n->indices.begin(), version);
    return std::move(load);
  }

237
  PrimExpr VisitExpr_(const CallNode *op) final {
238
239
240
241
242
243
244
245
246
247
248
249
250
251
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
};

/*!
 * \brief Rewriter for the software pipeline that rewrites a loop into a
 * pipelined one.
 */
class PipelineRewriter : public StmtExprMutator {
256
257
258
public:
  /*!
   * \brief Constructor of PipelineRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
   * \param pipeline_allocs Buffers allocated inside the pipelined scope that
   * may need multi-versioning.
   * \param pipeline_loop The loop to be software pipelined.
   * \param pipeline_info The stage/order/async annotation of each block.
   * \param predicate_condition Optional extra condition; when defined it is
   * substituted per iteration and wrapped around each block body as an
   * IfThenElse.
   */
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const Array<Buffer> &pipeline_allocs,
                   const For &pipeline_loop, const PipelineInfo &pipeline_info,
                   PrimExpr predicate_condition = PrimExpr())
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
        pipeline_info_(pipeline_info),
        predicate_condition_(predicate_condition) {}
265
266

  /*!
   * \brief Rewrite the annotated loop into prologue + steady-state body +
   * epilogue, with multi-versioned buffers and async bookkeeping.
   * \return A BlockRealize wrapping the rewritten pipeline.
   */
  Stmt BuildPipeline() {
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
    // number of versions need to maintain for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos = GetBufferAccessInfo();
    for (const Buffer &buffer : pipeline_allocs_) {
      int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
      }
    }

    // Order the blocks by their `order` annotation.
    ordered_stmts_.resize(pipeline_info_.size());
    for (const auto &[block, anno] : pipeline_info_) {
      ordered_stmts_.Set(anno.order, block);
    }

    // Record, per async stage, which buffers it asynchronously writes
    // (both the original and the remapped versions).
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        state.producer_head = pipeline_loop_->min - 1;
        for (auto write_region : block->writes) {
          auto buffer = write_region->buffer;
          state.dst_buffers.insert(buffer.get());
          if (buffer_remap_.count(buffer))
            state.dst_buffers.insert(buffer_remap_[buffer].get());
        }
      }
    }
    // Partition consecutive async writes of a stage into commit groups: a new
    // group starts after a consumer of that stage has been seen.
    std::unordered_set<int> consumed;
    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_[block].stage;
      if (pipeline_info_[block].async) {
        auto &state = async_states[stage];
        if (state.commit_groups.empty() || consumed.count(stage)) {
          state.commit_groups.push_back({});
        }
        state.commit_groups.back().push_back(pipeline_info_[block].order);
        consumed.erase(stage);
        for (auto write_region : block->writes) {
          auto buffer = buffer_remap_.count(write_region->buffer)
                            ? buffer_remap_[write_region->buffer]
                            : write_region->buffer;
          state.buffer_to_commit_group_[buffer.get()] =
              state.commit_groups.size() - 1;
        }
      }

      // A read from a buffer written by an earlier (or same) async stage marks
      // that stage as consumed.
      for (auto read_region : block->reads) {
        for (const auto &[producer_stage_id, producer_state] : async_states) {
          if (producer_stage_id <= stage &&
              producer_state.writes(read_region->buffer)) {
            consumed.insert(producer_stage_id);
          }
        }
      }
    }

    // Step 2: Emit the pipeline prologue, body and epilogue.
    Stmt prologue = EmitImpl(pipeline_loop_->min,
                             pipeline_loop_->min + max_stage_, true, true);
    Stmt body =
        EmitImpl(pipeline_loop_->min + max_stage_,
                 pipeline_loop_->min + pipeline_loop_->extent, false, false);
    Stmt epilogue = EmitImpl(
        pipeline_loop_->min + pipeline_loop_->extent,
        pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);

    SeqStmt stmt = SeqStmt({prologue, body, epilogue});

    // Step 3: Make a new block that contains new buffer allocations after
    // pipeline rewriting.
    Array<Buffer> alloc_buffers;
    for (const auto &alloc : pipeline_allocs_) {
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    return BlockRealize({}, Bool(true), block);
  }

349
private:
350
351
352
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
353
354
355
   * This method check the 'define' and 'use' stage of the buffers in the
   * software pipeline, which can be used to compute the number of versions
   * needed to maintain after rewriting.
356
357
358
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
359
360
361
362
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
        infos;
    for (const auto &pair : pipeline_info_) {
      const Block &block = pair.first;
363
364
365
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

366
      for (const BufferRegion &write : block->writes) {
367
368
369
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
370
        auto &info = infos.at(write->buffer);
371
372
373
374
375
376
377
        if (info.def == -1) {
          info.def = stage;
        } else {
          info.def = std::min(info.def, stage);
        }
      }

378
      for (const BufferRegion &read : block->reads) {
379
380
381
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
382
        auto &info = infos.at(read->buffer);
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
        info.use = std::max(info.use, stage);
      }
    }
    return infos;
  }

  /*!
   * \brief Check whether two regions have intersections.
   * \param region1 The first region.
   * \param region2 The second region.
   * \return Whether region1 and region2 have intersections.
   */
  bool MayConflict(Region region1, Region region2) {
    ICHECK(region1.size() == region2.size());
    for (size_t dim = 0; dim < region1.size(); ++dim) {
      auto set_a = arith::IntSet::FromRange(region1[dim]);
      auto set_b = arith::IntSet::FromRange(region2[dim]);
      // Regions can only overlap if every dimension overlaps; one provably
      // empty intersection is enough to rule out a conflict.
      if (arith::Intersect({set_a, set_b}).IsNothing()) {
        return false;
      }
    }
    return true;
  }

  /*!
410
411
   * \brief Compute the number of versions need to maintain for buffer accessed
   * in the software pipeline.
412
   *
413
414
415
416
417
418
   * This method applies liveness analysis to the target buffer to compute the
   * number of versions need to maintain during the software pipeline.
   * Annotation `attr::double_buffer_scope` is handled here which provides a way
   * to override the result of the analysis. Additional double buffering in the
   * software pipeline can be useful to eliminate synchronizations in GPU
   * devices.
419
420
421
422
423
   *
   * \param buffer The target buffer
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
424
425
  int ComputeBufferVersions(const Buffer &buffer,
                            const BufferAccessInfo &buffer_info) {
426
    if (buffer_info.def == -1) {
427
428
      // Keep the original number of versions as buffers defined outside the
      // software pipeline should not be mutated.
429
430
431
432
      return 1;
    }

    // `use - def + 1` is a upper bound of the needed versions
433
434
    // We optimize a few case where the number of versions can be smaller than
    // the upper bound
435
436
    int num_versions = buffer_info.use - buffer_info.def + 1;
    if (num_versions >= 2) {
437
438
439
440
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when these exists a reader block_i and a writer
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
      // stage(block_j) and the access regions of block_i and block_j overlap.
441
      bool need_multi_version = false;
442
443
444
      for (const auto &pair1 : pipeline_info_) {
        const Block &writer_block = pair1.first;
        const auto &writer_info = pair1.second;
445

446
447
448
        auto it1 = std::find_if(writer_block->writes.begin(),
                                writer_block->writes.end(),
                                [&](const BufferRegion &buffer_region) {
449
450
451
452
453
454
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

455
456
457
458
459
460
461
462
        for (const auto &pair2 : pipeline_info_) {
          const Block &reader_block = pair2.first;
          const auto &reader_info = pair2.second;
          auto it2 = std::find_if(
              reader_block->reads.begin(), reader_block->reads.end(),
              [&](const BufferRegion &buffer_region) {
                return buffer_region->buffer.same_as(buffer);
              });
463
464
465
          if (it2 == reader_block->reads.end()) {
            continue;
          }
466
467
          if (writer_info.order < reader_info.order &&
              writer_info.stage < reader_info.stage &&
468
469
470
471
472
473
474
475
476
477
478
479
480
481
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
        }
      }
      if (!need_multi_version) {
        num_versions--;
      }
    }
    return num_versions;
  }

  /*!
482
483
   * \brief Rewrite buffer allocation to keep multiple versions of original
   * buffer for pipelined accesses. \param buffer The buffer to be resized.
484
485
486
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
487
  Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
488
489
490
491
492
493
494
495
496
497
    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
    if (new_buffer->strides.size()) {
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

498
499
  // Per-stage states that need to be tracked across pipeline prologue, body,
  // and epilogue.
500
501
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
502
503
504
505
506
507
508
    std::unordered_set<const BufferNode *> dst_buffers;
    // An imaginary index that the latest async operation associated with this
    // stage has written into. Only valid if all associated predicates are true,
    // so that we can count the number of async invocations exactly. When it is
    // valid, it is the "sum of extents of loops that have been executed" - 1,
    // e.g. for epilogue it is prologue extent + body extent - 1. This is only
    // needed to compute wait count for epilogue without async producers.
509
510
    PrimExpr producer_head;
    std::vector<std::vector<int>> commit_groups;
511
    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
512
513
514
    bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
  };

515
516
  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
517
518
  struct AsyncStateLocal {
    struct PendingWait {
519
520
      // The index into a list of blocks, where async_wait_queue should be
      // attached at the beginning.
521
      int insert_before;
522
523
      // in_flight_count would be a more precise name, but the implementation
      // uses wait_count for brevity.
524
525
526
527
528
529
530
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
    };

    std::vector<PendingWait> pending_waits;

531
532
    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
533
534
535
536
537
538
539
540
541
542
543
544
545
    Optional<PrimExpr> producer_head;
  };

  /*! Structure holding intermediate information for pipeline loop rewriting. */
  struct RewrittenBlockInfo {
    // Pipeline stage of the block.
    int stage;
    // Position of the block in the statement order of the pipelined loop.
    int order;
    // Predicate guarding execution of the block (bound check after skewing).
    PrimExpr predicate;
    // The rewritten block.
    Block block;
    // Normalized access index: producer_head for async producers,
    // consumer_head for readers of async-written buffers.
    PrimExpr access_index;
    // Whether the block executes asynchronously.
    bool is_async;
  };

546
547
  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
                          std::map<int, AsyncStateLocal> *async_states_local) {
548
549
550
    for (size_t i = 0; i < new_blocks.size(); ++i) {
      int producer_stage_idx = -1;
      for (auto read_region : new_blocks[i].block->reads) {
551
552
553
554
555
        for (const auto &[stage, state] : async_states) {
          if (stage <= new_blocks[i].stage &&
              state.writes(read_region->buffer)) {
            // Found an earlier stage where read_region->buffer was
            // asynchronously written
556
557
558
559
560
561
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
                << "A dependency on multiple async stages is not supported";
            producer_stage_idx = stage;
          }
        }
      }
562
563
564
565
      if (producer_stage_idx == -1)
        continue;
      const auto &state = async_states[producer_stage_idx];
      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
566
      PrimExpr in_flight_cnt = 0;
567
      for (const auto &group : state.commit_groups) {
568
569
570
571
572
        PrimExpr consumer_head = new_blocks[i].access_index;
        PrimExpr producer_head;
        if (dep_local_state.producer_head.defined()) {
          producer_head = dep_local_state.producer_head.value();
          // if the group is after the wait point, minus by 1
573
574
          if (group.front() > new_blocks[i].order)
            producer_head -= 1;
575
576
577
578
579
580
581
582
        } else {
          producer_head = state.producer_head;
        }
        in_flight_cnt += producer_head - consumer_head;
      }

      // We can relax the in-flight-count by the number of independent commit.
      std::unordered_set<int> dependent_groups;
583
      for (const auto &read_region : new_blocks[i].block->reads) {
584
        if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
585
586
          dependent_groups.insert(
              state.buffer_to_commit_group_.at(read_region->buffer.get()));
587
588
589
590
591
      }
      for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
        if (dependent_groups.count(i) == 0)
          in_flight_cnt += 1;
        else
592
          break; // stop relaxing
593
594
      }
      in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
595
596
      dep_local_state.pending_waits.push_back(
          {static_cast<int>(i), in_flight_cnt});
597
598
599
    }
  }

600
601
  // Given pipelined blocks and async-related information, generate final loop
  // statements with async scopes (if any).
602
  Array<Stmt> CompletePipelineLoopStatements(
603
604
      const std::vector<RewrittenBlockInfo> &blocks,
      const std::map<int, AsyncStateLocal> &async_states_local) const {
605
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
606
607
608
609
    for (const auto &[stage_id, state] : async_states_local) {
      for (const auto &pw : state.pending_waits) {
        auto &block = new_blocks[pw.insert_before].block;
        BlockNode *n = block.CopyOnWrite();
610
        auto zero = make_zero(DataType::Int(32));
611
612
613
        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                    pw.wait_count, n->body));
614
615
616
617
618
      }
    }

    // mark the last async stmt as commit
    std::unordered_set<int> commit_group_indices;
619
    for (const auto &[stage_id, state] : async_states) {
620
621
622
623
624
625
626
627
628
629
      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
        commit_group_indices.insert(state.commit_groups[i].back());
      }
    }

    Array<Stmt> stmts;

    for (size_t i = 0; i < new_blocks.size(); i++) {
      Block block = new_blocks[i].block;
      if (commit_group_indices.count(new_blocks[i].order)) {
630
631
632
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope,
                                           new_blocks[i].stage, block->body);
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
      }
      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
    }

    return stmts;
  }

  /*!
   * \brief Emit the pipeline loop in the given range.
   * \param start The start of the range
   * \param end The end of the range
   * \param unroll_loop Whether the loop should be unrolled.
   * \return The result loop.
   */
648
649
  Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop,
                bool need_bound_check) {
650
651
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;
652
653
654
    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };
655
656
657

    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
658
      new_loop_var = start; // use constants as the loop var for unit loops
659
660
661
662
663
664
665
666
667
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
    }

    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;
668
    PrimExpr normalized_access_index;
669

670
    for (const Block &block : ordered_stmts_) {
671
672
673
674
675
      int stage = pipeline_info_.at(block).stage;
      int order = pipeline_info_.at(block).order;
      PrimExpr inbound = Bool(true);
      PrimExpr skewed_loop_var = new_loop_var - stage;
      if (need_bound_check)
676
677
678
        inbound =
            analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
679
680
681
      if (analyzer_.CanProve(!inbound)) {
        continue;
      }
682
683
684
      Block new_block = Downcast<Block>(
          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                               pipeline_loop_, max_stage_ != 1)(block));
685
686
687
688

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
689
690
      // - "consumer_head" if this stage reads from asynchronously written
      // buffers.
691
      normalized_access_index =
692
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
693

694
695
      // Adjust the block predicate and the body according to the final loop
      // bound
696
697
698
699
700
      //  [pipeline_loop_->min, extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }
701
702
      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
703
704
705
706
707
708
709
      if (predicate_condition_.defined()) {
        BlockNode *n = new_block.CopyOnWrite();
        n->body = IfThenElse(
            Substitute(predicate_condition_,
                       {{pipeline_loop_->loop_var, normalized_access_index}}),
            n->body);
      }
710
      if (pipeline_info_[block].async) {
711
        auto &local_state = async_states_local[stage];
712
        local_state.producer_head = normalized_access_index;
713
714
715
        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                           1, n->body);
716
717
      }

718
719
720
      new_blocks.push_back({stage, order, inbound, new_block,
                            normalized_access_index,
                            pipeline_info_[block].async});
721
722
723
    }

    PopulateWaitCounts(new_blocks, &async_states_local);
724

725
    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832

    // Group blocks by their predicate conditions
    PrimExpr current_condition = Bool(true);
    Array<Stmt> current_stmts;
    Array<PrimExpr> ordered_conditions;
    Array<Array<Stmt>> condition_to_stmts;

    for (const auto &stmt : stmts) {
      if (const auto *realize = stmt.as<BlockRealizeNode>()) {
        // Helper function to find IfThenElse through potential AttrStmt nodes
        auto find_if_then_else =
            [](Stmt body) -> std::pair<bool, const IfThenElseNode *> {
          while (true) {
            if (const auto *if_node = body.as<IfThenElseNode>()) {
              return {true, if_node};
            } else if (const auto *attr_node = body.as<AttrStmtNode>()) {
              // Continue traversing through attributes
              body = attr_node->body;
            } else {
              // No IfThenElse found
              return {false, nullptr};
            }
          }
        };

        auto [has_if, if_then_else] = find_if_then_else(realize->block->body);

        if (has_if) {
          if (if_then_else->else_case.defined()) {
            // IfThenElse nodes with else case are treated individually
            if (!current_stmts.empty()) {
              ordered_conditions.push_back(current_condition);
              condition_to_stmts.push_back(current_stmts);
              current_stmts = {};
            }
            current_condition = Bool(true);
            current_stmts.push_back(stmt);
          } else {
            // If we encounter a new condition
            if (!StructuralEqual()(if_then_else->condition,
                                   current_condition)) {
              // Store the current group if it's not empty
              if (!current_stmts.empty()) {
                ordered_conditions.push_back(current_condition);
                condition_to_stmts.push_back(current_stmts);
                current_stmts = {};
              }
              current_condition = if_then_else->condition;
            }
            BlockRealize new_realize = Downcast<BlockRealize>(stmt);
            new_realize.CopyOnWrite()->block.CopyOnWrite()->body =
                replace_if_then_else(new_realize->block->body,
                                     if_then_else->condition);
            current_stmts.push_back(new_realize);
          }
        } else {
          if (!current_stmts.empty()) {
            ordered_conditions.push_back(current_condition);
            condition_to_stmts.push_back(current_stmts);
            current_stmts = {};
          }
          current_condition = Bool(true);
          current_stmts.push_back(stmt);
        }
      } else {
        // Non-BlockRealize statements are treated individually
        if (!current_stmts.empty()) {
          ordered_conditions.push_back(current_condition);
          condition_to_stmts.push_back(current_stmts);
          current_stmts = {};
        }
        current_condition = Bool(true);
        current_stmts.push_back(stmt);
      }
    }

    // Add the last group if not empty
    if (!current_stmts.empty()) {
      ordered_conditions.push_back(current_condition);
      condition_to_stmts.push_back(current_stmts);
    }

    // Build the final statement sequence with proper conditionals
    Array<Stmt> final_stmts;
    for (auto i = 0; i < ordered_conditions.size(); i++) {
      Array<Stmt> condition_stmts = condition_to_stmts[i];
      if (condition_stmts.empty())
        continue;

      // Create a sequence from the statements with this condition
      Stmt stmt_block;
      if (condition_stmts.size() == 1) {
        stmt_block = condition_stmts[0];
      } else {
        stmt_block = SeqStmt(condition_stmts);
      }

      // If condition is not trivially true, wrap in if-then-else
      if (!is_one(ordered_conditions[i]) &&
          !analyzer_.CanProve(ordered_conditions[i] == true)) {
        stmt_block = IfThenElse(ordered_conditions[i], stmt_block);
      }

      final_stmts.push_back(stmt_block);
    }

    // Use final_stmts instead of the original stmts
833
834
    Stmt new_loop{nullptr};

835
    if (final_stmts.empty()) {
836
837
      return make_nop();
    }
838
839
840

    if (final_stmts.size() == 1) {
      new_loop = final_stmts[0];
841
    } else {
842
      new_loop = SeqStmt(final_stmts);
843
844
845
846
    }

    if (!is_unit_loop) {
      Map<String, ObjectRef> preserved_annotations;
847
848
      for (const auto &kv : pipeline_loop_->annotations) {
        const String &key = kv.first;
        if (kv.first != tir::attr::software_pipeline_stage &&
            kv.first != tir::attr::software_pipeline_order &&
            kv.first != tir::attr::software_pipeline_async_stages) {
          preserved_annotations.Set(key, kv.second);
        }
      }
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
856
857
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
                     std::move(new_loop), NullOpt, preserved_annotations);
858
859
    }
    // Update producer heads in the global async states.
860
    for (const auto &[stage_id, state] : async_states_local) {
861
862
863
      async_states[stage_id].producer_head += extent;
    }

864
865
    return BlockRealize({}, Bool(true),
                        MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
  }

  arith::Analyzer analyzer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Array<Buffer> pipeline_allocs_;
  For pipeline_loop_;
  PipelineInfo pipeline_info_;
873
  PrimExpr predicate_condition_;
  int max_stage_ = -1;
  Map<Buffer, Buffer> buffer_remap_;
  Array<Block> ordered_stmts_;
  std::map<int, AsyncStateGlobal> async_states;
};

/*!
 * \brief Build the dependency graph among a array of blocks.
 * \param[in] blocks The array of blocks.
883
884
885
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the
 * source to the destination. \param[out] dep_dst2src Optional, a map to store
 * dependency edges from the destination to the source.
886
 */
887
888
889
890
891
892
893
894
895
896
void BuildDependencyGraph(const Array<Block> &blocks,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
      buffer_writers;

  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
897
898
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
899
        for (const Block &writer : it->second) {
900
901
902
903
904
905
906
907
908
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
          if (dep_dst2src != nullptr) {
            (*dep_dst2src)[block].push_back(writer);
          }
        }
      }
    }
909
    for (const BufferRegion &write : block->writes) {
910
911
912
913
914
915
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
916
917
public:
  static Stmt Inject(const PrimFunc &func) {
918
919
    auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
    PipelineInjector injector(global_symbol);
920
921
    for (const auto &kv : func->buffer_map) {
      const Buffer &buffer = kv.second;
922
923
924
925
926
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    return injector(func->body);
  }

927
928
929
private:
  explicit PipelineInjector(Optional<String> global_symbol)
      : global_symbol_(global_symbol) {}
930
931
932
933

  /*!
   * \brief Check the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
934
935
936
937
   * 2. Reordering of statements doesn't break buffer access dependencies.
   * Specifically, for dependency (e.g. read-after-write) from statement A to
   * statement B, it requires: case 1: stage(A) < stage(B) case 2: stage(A) ==
   * stage(B) and order(A) < order(B)
938
   */
939
940
  void ValidatePipelineBody(const PipelineInfo &pipeline_info,
                            const Array<Block> &original_order) {
941
942
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
943
944
945
946
    std::unordered_map<int, const Block *> order_to_block;
    std::unordered_map<const Block *, int> block_to_stage;
    for (const Block &block : original_order) {
      const auto &stmt_info = pipeline_info.at(block);
947
948
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
949
950
          << "ValueError: Two statements in the software pipeline cannot have "
             "the same order";
951
952
953
      used_orders.insert(order);
    }

954
955
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
        dep_src2dst;
956
957
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

958
959
960
961
962
963
    for (const auto &pair : dep_src2dst) {
      const Block &src = pair.first;
      const auto &src_info = pipeline_info.at(src);
      const Array<Block> &dsts = pair.second;
      for (const Block &dst : dsts) {
        const auto &dst_info = pipeline_info.at(dst);
964
965
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
966
967
            << " cannot depends on statement " << src << " in a later stage "
            << src_info.stage;
968
        if (src_info.stage == dst_info.stage) {
969
970
971
972
          CHECK_LT(src_info.order, dst_info.order)
              << "ValueError: two statements with buffer "
                 "access dependency in the same stage of the "
                 "software pipeline cannot be reordered";
973
974
975
976
977
        }
      }
    }
  }

978
  Stmt VisitStmt_(const ForNode *op) final {
979
980
981
982
983
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
      return std::move(for_node);
    }
984
985
986
    // Step 2: Find the body and buffer allocations of the pipeline. The body
    // can be direct child of the for-loop. If the for-loop has BlockRealize as
    // its child, the pipeline body will be the child of the block.
987
    Stmt pipeline_body{nullptr};
988
    PrimExpr predicate_condition{nullptr};
989
    Array<Buffer> pipeline_allocs;
990
991
992
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
993
994
995
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
996
997
998
999
1000
1001
1002
1003
1004
      if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
        ICHECK(!if_then_else->else_case.defined())
            << "Pipeline_Planning: Can't handle the body of the loop because "
               "it is not a SeqStmt";
        pipeline_body = if_then_else->then_case;
        predicate_condition = if_then_else->condition;
      } else {
        pipeline_body = block->body;
      }
1005
1006
1007
1008
1009
      pipeline_allocs = block->alloc_buffers;
    } else {
      pipeline_body = for_node->body;
    }

1010
1011
1012
1013
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << pipeline_body->GetTypeKey();
1014

1015
1016
    // Step 3: Blockize the components of the pipeline. Each child of the
    // pipelined loop will be converted into a block.
1017
    PipelineInfo pipeline_info;
1018
    Array<Block> original_order; // pipeline body blocks in the original order
1019

1020
    auto f_add_child = [&](const Stmt &child) {
1021
1022
1023
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
1024
1025
      const auto *nested_block_realize =
          pipeline_body_seq->seq[i].as<BlockRealizeNode>();
1026
1027
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
1028
1029
1030
1031
        const Block &nested_pipeline_block = nested_block_realize->block;
        ICHECK(nested_pipeline_block->match_buffers
                   .empty()); // match_buffer should have been lowered
        for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
1032
1033
1034
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
1035
        const auto *nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
1036
1037
1038
1039
1040
1041
1042
1043
        for (size_t j = 0; j < nested_seq->seq.size(); j++) {
          f_add_child(nested_seq->seq[j]);
        }
      } else {
        f_add_child(pipeline_body_seq->seq[i]);
      }
    }

1044
1045
1046
1047
    auto pipeline_stages = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_stage));
    auto pipeline_orders = Downcast<Array<Integer>>(
        op->annotations.at(tir::attr::software_pipeline_order));
1048
1049
    CHECK_EQ(pipeline_stages.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
1050
1051
1052
1053
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_stages
        << " with different size";
1054
1055
    CHECK_EQ(pipeline_orders.size(), original_order.size())
        << "PrimFunc " << global_symbol_ << " has original order "
1056
1057
1058
1059
        << original_order.Map(
               [](const auto &block) { return block->name_hint; })
        << ", but pipeline annotation is " << pipeline_orders
        << " with different size";
1060
1061

    std::unordered_set<int> pipeline_async_stages;
1062
1063
    if (auto annot =
            op->annotations.Get(tir::attr::software_pipeline_async_stages)) {
1064
1065
1066
1067
1068
1069
1070
      for (auto s : Downcast<Array<Integer>>(annot)) {
        pipeline_async_stages.insert(s->value);
      }
    }

    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
1071
1072
1073
1074
1075
      bool is_async =
          pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{
          stage,
          /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
1076
1077
1078
1079
1080
1081
      pipeline_info.emplace(original_order[i], stage_order);
    }

    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
1082
1083
1084
1085
    Stmt pipeline =
        PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
                         GetRef<For>(op), pipeline_info, predicate_condition)
            .BuildPipeline();
1086

1087
1088
1089
    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
1090
1091
1092
1093
1094
1095
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }

1096
1097
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
1098
1099
1100
1101
1102
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

1103
    for (const auto &buffer : op->alloc_buffers) {
1104
1105
1106
1107
1108
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }

1109
  bool HasPipelineAnnotation(const ForNode *op) const {
1110
1111
1112
1113
1114
1115
1116
1117
    auto it1 = op->annotations.find(tir::attr::software_pipeline_stage);
    auto it2 = op->annotations.find(tir::attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
    bool has_order = it2 != op->annotations.end();
    if (has_stage && has_order) {
      return true;
    }
    if (has_stage) {
1118
      LOG(FATAL)
1119
          << "ValueError: Stage of the software pipeline is not defined.";
1120
1121
    }
    if (has_order) {
1122
      LOG(FATAL)
1123
          << "ValueError: Order of the software pipeline is not defined.";
1124
1125
1126
1127
1128
1129
1130
1131
1132
    }
    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Optional<String> global_symbol_;
};

/*!
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 * \return The IR transform pass.
 */
tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto *fptr = f.CopyOnWrite();
    // Rewrite every annotated loop, then restore SSA form: the pipeline
    // rewrite duplicates statements (prologue/body/epilogue), which can
    // introduce repeated variable definitions.
    fptr->body = PipelineInjector::Inject(f);
    fptr->body = ConvertSSA(std::move(fptr->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
}

// Register the pass in TVM's global function registry so it can be invoked
// from the frontend as `tl.transform.InjectSoftwarePipeline`.
TVM_REGISTER_GLOBAL("tl.transform.InjectSoftwarePipeline")
    .set_body_typed(InjectSoftwarePipeline);

1150
1151
} // namespace tl
} // namespace tvm