pipeline_planning.cc 21.7 KB
Newer Older
1
#include <tvm/arith/analyzer.h>
2
#include <tvm/ffi/reflection/registry.h>
3
#include <tvm/tir/analysis.h>
4
#include <tvm/tir/builtin.h>
5
6
7
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

8
9
#include <utility>

10
#include "../target/utils.h"
11
#include "tvm/ir/expr.h"
12
13
14
15
16
17
18
19
20
21
22
23

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Check whether two regions have intersections.
 * \param region1 The first region.
 * \param region2 The second region.
 * \return Whether region1 and region2 have intersections.
 */
24
bool MayConflict(const Region &region1, const Region &region2) {
25
26
27
28
29
30
31
32
33
34
35
36
37
  ICHECK(region1.size() == region2.size());
  for (size_t i = 0; i < region1.size(); i++) {
    Range dim1 = region1[i];
    Range dim2 = region2[i];
    auto int_set1 = arith::IntSet::FromRange(dim1);
    auto int_set2 = arith::IntSet::FromRange(dim2);
    if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
      return false;
    }
  }
  return true;
}

38
39
40
41
42
43
/*!
 * \brief Detect if a statement follows the global memory copy pattern:
 *        1. Contains exactly one buffer store operation
 *        2. Source buffer must be in global memory scope
 *        3. Destination buffer must be in local or shared memory scope
 */
44
class BufferRegionCollector : public StmtExprVisitor {
45
public:
46
  BufferRegionCollector(Map<Var, Buffer> buffer_data_to_buffer)
47
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {}
48
49
50
51
52
53
54

  Array<BufferRegion> GetReads() const { return reads_; }

  Array<BufferRegion> GetWrites() const { return writes_; }

  bool GetGlobalCopyPattern() const { return is_global_copy_pattern_; }

55
56
57
private:
  void VisitStmt_(const BufferStoreNode *op) final {
    Buffer store_buffer = op->buffer;
58
59
60
61
62
63
64
65
66
    Array<PrimExpr> indices = op->indices;
    // convert indices to region
    Array<Range> region;
    for (const auto &index : indices) {
      region.push_back(Range::FromMinExtent(index, 1));
    }
    auto store_region = BufferRegion(store_buffer, region);
    writes_.push_back(store_region);

67
68
69
    is_global_read_ = false;
    this->VisitExpr(op->value);
    if (is_global_read_ && (store_buffer.scope() == "shared" ||
70
                            store_buffer.scope() == "shared.dyn")) {
71
72
73
74
75
76
      is_global_copy_pattern_ = true;
    }
    is_global_read_ = false;
  }

  void VisitExpr_(const BufferLoadNode *op) final {
77
78
79
80
81
82
83
84
85
86
    auto load_buffer = op->buffer;
    Array<PrimExpr> indices = op->indices;
    // convert indices to region
    Array<Range> region;
    for (const auto &index : indices) {
      region.push_back(Range::FromMinExtent(index, 1));
    }
    auto load_region = BufferRegion(load_buffer, region);
    reads_.push_back(load_region);

87
88
89
90
91
    if (op->buffer.scope() == "global" && !within_condition_expr_) {
      // skip condition expr of if_then_else node
      // shared[i] = T.if_then_else(global[i] < n, register_a[i], register_b[i])
      // is not a global read shared[i] = T.if_then_else(global[i] < n,
      // global_a[i], global_b[i]) is a global read
92
93
94
95
96
97
      is_global_read_ = true;
    }
  }

  void VisitExpr_(const CallNode *op) final {
    auto args = op->args;
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
    if (op->op.same_as(builtin::address_of())) {
      const BufferLoad load = Downcast<BufferLoad>(op->args[0]);
      const BufferRegion buffer_region = BufferRegion::FullRegion(load->buffer);
      // because we only care about the buffer itself instead of indices
      reads_.push_back(buffer_region);
    } else if (op->op.same_as(builtin::tvm_access_ptr())) {
      const VarNode *buffer_var = op->args[1].as<VarNode>();
      ICHECK(buffer_var);
      auto it = buffer_data_to_buffer_.find(GetRef<Var>(buffer_var));
      if (it != buffer_data_to_buffer_.end()) {
        const Buffer &buffer = (*it).second;
        const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
        // because we only care about the buffer itself instead of indices
        reads_.push_back(buffer_region);
      }
113
114
115
116
117
118
119
    } else if (op->op.same_as(builtin::if_then_else())) {
      within_condition_expr_ = true;
      this->VisitExpr(op->args[0]);
      within_condition_expr_ = false;
      for (auto i = 1; i < op->args.size(); i++) {
        this->VisitExpr(op->args[i]);
      }
120
121
    } else {
      StmtExprVisitor::VisitExpr_(op);
122
123
124
    }
  }

125
126
127
128
129
130
131
132
133
134
135
136
  void VisitStmt_(const IfThenElseNode *op) final {
    within_condition_expr_ = true;
    this->VisitExpr(op->condition);
    within_condition_expr_ = false;
    this->VisitStmt(op->then_case);
    if (op->else_case.defined()) {
      within_condition_expr_ = true;
      this->VisitStmt(op->else_case.value());
      within_condition_expr_ = false;
    }
  }

137
private:
138
139
140
  Map<Var, Buffer> buffer_data_to_buffer_;
  Array<BufferRegion> reads_;
  Array<BufferRegion> writes_;
141
142
143
  bool is_global_read_ = false;
  bool under_buffer_store_ = false;
  bool is_global_copy_pattern_ = false;
144
  bool within_condition_expr_ = false;
145
146
};

147
class PipelinePlanner : public StmtExprMutator {
148
public:
149
150
  static Stmt Substitute(const PrimFunc &f, bool use_async_copy = true) {
    PipelinePlanner substituter(use_async_copy);
151
    for (const auto &[_, buffer] : f->buffer_map) {
152
153
154
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
155
156
    ICHECK(target.defined())
        << "Pipeline_Planning: Require the target attribute";
157
158
159
160
    substituter.target_ = target.value();
    return substituter.VisitStmt(f->body);
  }

161
private:
162
  PipelinePlanner() = default;
163
  PipelinePlanner(bool use_async_copy) : use_async_copy_(use_async_copy) {}
164

165
166
167
168
  /*! \brief Information about a pipeline stage
   *
   * \param reads Array of buffer regions read by this stage
   * \param writes Array of buffer regions written by this stage
169
   * \param original_stmt_index Original position of this stage in the pipeline
170
171
172
173
   * before reordering \param order Current position of this stage in the
   * pipeline after reordering (-1 if not yet assigned) \param stage Pipeline
   * stage number this operation belongs to (-1 if not yet assigned) \param
   * copy_stage Whether this stage is a memory copy operation \param
174
175
176
177
178
179
180
181
182
183
   * last_use_stmt_index Index of the last statement (in original order) that
   * uses the results of this stage (-1 if not yet determined). This field is
   * crucial for pipeline optimization:
   * - For copy stages: indicates the index of the last statement that reads
   * from the copied data, helping determine optimal placement of copy
   * operations
   * - Used to ensure copy operations are scheduled before their consumers
   * - A value of -1 means no subsequent statement uses this stage's output
   * - This information enables better pipeline scheduling by minimizing data
   *   dependencies and maximizing parallelism
184
   */
185
186
  struct PipelineStageInfo {
    Array<BufferRegion> reads, writes;
187
    int original_stmt_index{};
188
189
    int order = -1, stage = -1;
    bool copy_stage = false;
190
191
192
193
194
195
196
197
198
199
200
    bool producer_for_copy = false;
    int last_use_stmt_index =
        -1; // Initialized to -1, indicating no consumers found yet

  public:
    bool is_first_stage() const { return copy_stage || producer_for_copy; }
    bool is_copy_stage() const { return copy_stage; }
    bool is_producer_for_copy() const { return producer_for_copy; }
    bool is_last_use_stmt_index_valid() const {
      return last_use_stmt_index != -1;
    }
201
202
203
  };

  PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
204
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
205
                /*body*/ std::move(stmt));
206
207
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
208
209
    auto collector = BufferRegionCollector(buffer_data_to_buffer_);
    collector(block);
210
    PipelineStageInfo pinfo;
211
212
    pinfo.reads = std::move(collector.GetReads());
    pinfo.writes = std::move(collector.GetWrites());
213
    pinfo.original_stmt_index = idx;
214
    pinfo.copy_stage = collector.GetGlobalCopyPattern();
215
216
217
    return std::move(pinfo);
  }

218
  Stmt VisitStmt_(const ForNode *loop) final {
219
220
    auto order_anno = loop->annotations.Get("tl_pipeline_order");
    auto stage_anno = loop->annotations.Get("tl_pipeline_stage");
221
    auto num_stages_anno = loop->annotations.Get("num_stages");
222
    if (order_anno && stage_anno) {
223
224
225
      // Check if order_anno or stage_anno contains -1, which means TMA+WS is
      // enabled
      bool ws_tma_enabled = false;
226
227
      auto order_array = Downcast<Array<Integer>>(order_anno.value());
      auto stage_array = Downcast<Array<Integer>>(stage_anno.value());
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
      for (const auto &val : order_array) {
        if (val->value == -1) {
          ws_tma_enabled = true;
          break;
        }
      }
      if (!ws_tma_enabled) {
        for (const auto &val : stage_array) {
          if (val->value == -1) {
            ws_tma_enabled = true;
            break;
          }
        }
      }

      if (ws_tma_enabled) {
        return StmtExprMutator::VisitStmt_(loop);
      }

247
      Map<String, Any> annotations;
248
249
250
251
252
      for (const auto &[key, value] : loop->annotations) {
        if (key != "tl_pipeline_order") {
          annotations.Set(key, value);
        }
      }
253
      annotations.Set(tir::attr::software_pipeline_order, order_anno.value());
254
255
256
257
258
259

      for (const auto &[key, value] : loop->annotations) {
        if (key != "tl_pipeline_stage") {
          annotations.Set(key, value);
        }
      }
260
      annotations.Set(tir::attr::software_pipeline_stage, stage_anno.value());
261
      if (TargetHasAsyncCopy(target_) && use_async_copy_)
262
263
264
265
266
267
268
        annotations.Set(tir::attr::software_pipeline_async_stages,
                        Array<Integer>{0});
      auto for_node = GetRef<For>(loop);
      for_node.CopyOnWrite()->annotations = annotations;
      return for_node;
    }

269
    if (!num_stages_anno)
270
      return StmtExprMutator::VisitStmt_(loop);
271
    int num_stages = num_stages_anno->as<IntImmNode>()->value;
272
    Stmt pipeline_body{nullptr};
273
274
275
    if (const auto *realize = loop->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
276
277
278
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
279
280
281
282
283
284
285
286
287
288
289
290
      if (const auto *seq_stmt = block->body.as<SeqStmtNode>()) {
        pipeline_body = block->body;
      } else if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
        // should assert else case is nullptr
        ICHECK(!if_then_else->else_case.defined())
            << "Pipeline_Planning: Can't handle the body of the loop because "
               "it is not a SeqStmt";
        pipeline_body = if_then_else->then_case;
      } else {
        LOG(FATAL) << "Pipeline_Planning: Can't handle the body of the loop "
                      "because it is not a SeqStmt or IfThenElse";
      }
291
292
293
    } else {
      pipeline_body = loop->body;
    }
294
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
295
296
297
298
    CHECK(pipeline_body_seq)
        << "ValueError: The body of the software pipeline "
           "should be SeqStmt, got "
        << pipeline_body->GetTypeKey() << " " << pipeline_body;
299
300
301
302
303
304
305
306
307
    CHECK(num_stages >= 1);
    CHECK(loop->kind == ForKind::kSerial);

    std::vector<PipelineStageInfo> pipeline_stage_infos;
    for (size_t i = 0; i < pipeline_body_seq->size(); i++) {
      auto pinfo = MakePipelineStageInfo(pipeline_body_seq->seq[i], i);
      pipeline_stage_infos.push_back(std::move(pinfo));
    }

308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
    // For every copy stage, mark all its dependency stages as producer_for_copy
    // Helper struct to manage copy stage dependency reads
    struct CopyStageDependencyReadsManager {
      std::vector<BufferRegion> regions;

      // Add a region if not already present (by structural equality)
      void AddUnique(const BufferRegion &region) {
        for (const BufferRegion &copy_read : regions) {
          if (region->buffer.same_as(copy_read->buffer)) {
            return;
          }
        }
        regions.push_back(region);
      }

      // Check if a region is present (by structural equality)
      bool Contains(const BufferRegion &region) const {
        for (const BufferRegion &copy_read : regions) {
          if (region->buffer.same_as(copy_read->buffer)) {
            return true;
          }
        }
        return false;
      }

      size_t Size() const { return regions.size(); }
    };

    CopyStageDependencyReadsManager copy_stage_dependency_reads_mgr;

    // Step 1. Collect Copy reads
    for (const auto &pinfo : pipeline_stage_infos) {
      if (pinfo.is_copy_stage()) {
        for (const BufferRegion &read : pinfo.reads) {
          copy_stage_dependency_reads_mgr.AddUnique(read);
        }
      }
    }

    // Step 2. find if pinfo write the copy reads, then update the
    // copy_stage_dependency_reads To prevent infinite loops, we set a maximum
    // number of iterations. In theory, the number of possible updates is
    // bounded by the number of pipeline stages, since each stage can only be
    // marked as producer_for_copy once, and each read can only be added once.
    // But for safety, we add a hard limit.
    const size_t max_iterations = (pipeline_stage_infos.size() * 4) + 16;
    size_t iter_count = 0;

356
    for (auto &pinfo : pipeline_stage_infos) {
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
      if (!pinfo.is_copy_stage()) {
        continue;
      }
      auto original_copy_stmt_index = pinfo.original_stmt_index;
      bool updated = true;
      while (updated) {
        updated = false;
        for (auto &pinfo_inner : pipeline_stage_infos) {
          if (pinfo_inner.is_copy_stage()) {
            continue;
          }
          if (pinfo_inner.original_stmt_index >= original_copy_stmt_index) {
            break;
          }

          bool should_prepare = false;
          for (const BufferRegion &write : pinfo_inner.writes) {
            if (copy_stage_dependency_reads_mgr.Contains(write)) {
              should_prepare = true;
              break;
            }
          }
          if (should_prepare && !pinfo_inner.is_producer_for_copy()) {
            pinfo_inner.producer_for_copy = true;
            updated = true;
          }
          if (should_prepare) {
            for (const BufferRegion &read : pinfo_inner.reads) {
              size_t before = copy_stage_dependency_reads_mgr.Size();
              copy_stage_dependency_reads_mgr.AddUnique(read);
              if (copy_stage_dependency_reads_mgr.Size() > before) {
                updated = true;
389
              }
390
            }
391
392
          }
        }
393
394
395
396
397
398
399
        iter_count++;
        if (iter_count > max_iterations) {
          LOG(FATAL)
              << "Pipeline planning: Exceeded maximum iterations ("
              << max_iterations << ") in copy stage dependency propagation. "
              << "This may indicate a cyclic or pathological dependency graph.";
        }
400
401
402
      }
    }

403
404
405
406
407
    // Analysis use-def chain to determine last_use_stmt_index for copy
    // operations This step is critical for pipeline optimization as it
    // identifies the index of the last statement that consumes data produced by
    // copy stages, enabling optimal placement of copy operations in the
    // pipeline schedule.
408
    for (auto &pinfo : pipeline_stage_infos) {
409
410
411
412
413
414
      // Only analyze copy stages (memory copy operations)
      if (!pinfo.is_first_stage())
        continue;

      // Check all subsequent statements to find the latest consumer
      for (int i = pinfo.original_stmt_index + 1;
415
           i < static_cast<int>(pipeline_body_seq->size()); i++) {
416
417
418

        // Check if any read operation in statement 'i' uses data written by
        // this copy stage
419
        for (const BufferRegion &read : pipeline_stage_infos[i].reads) {
420
421
          // Look for overlapping buffer regions between this stage's writes and
          // stage 'i's reads
422
423
424
425
426
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == read->buffer &&
                                    MayConflict(r->region, read->region);
                           }) != pinfo.writes.end()) {
427
428
429
430
            // Update last_use_stmt_index to the maximum (latest) statement
            // index that uses this data This ensures we capture the final
            // consumer of the copied data
            pinfo.last_use_stmt_index = std::max(pinfo.last_use_stmt_index, i);
431
432
          }
        }
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
        // Check for write-after-write conflicts (multiple stages writing to
        // same buffer region) This is important for pipeline correctness and
        // affects last_use_stmt_index analysis
        if (pinfo.is_copy_stage()) {
          for (const BufferRegion &write : pipeline_stage_infos[i].writes) {
            if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                             [&](const BufferRegion &r) {
                               return r->buffer == write->buffer &&
                                      MayConflict(r->region, write->region);
                             }) != pinfo.writes.end()) {
              LOG(FATAL) << "Pipeline planning error: Multiple writes to "
                            "overlapping buffer regions detected. "
                         << "Stage " << pinfo.original_stmt_index
                         << " and stage " << i
                         << " are both writing to buffer '"
                         << write->buffer->name
                         << "' with overlapping regions. This is not supported "
                            "in pipeline planning.";
            }
452
453
454
455
456
457
458
          }
        }
      }
    }

    // Making stages and orders
    int order_idx = 0;
459
    // Stage 1. Create pipeline stages and assign order
460
    for (auto &pinfo : pipeline_stage_infos) {
461
      // Skip elements that must be in first stage:
462
463
464
465
466
      // 1. Copy stages (with active last_use_stmt_index) - these need special
      // handling
      //    because they have consumers that depend on their data
      // 2. All Producer stages for copy stages.
      if (pinfo.is_first_stage() && pinfo.is_last_use_stmt_index_valid()) {
467
        continue;
468
      }
469

470
471
472
      // Main logic stage assignment:
      // - Increment order index
      // - Assign to new stage (current num_stages)
473
474
      pinfo.order = order_idx++;
      pinfo.stage = num_stages;
475

476
477
478
      // Schedule copy stages that have this stage as their last consumer
      // This ensures copy operations are placed right before their final
      // consumer for optimal pipeline efficiency
479
      for (auto &pinfo_1 : pipeline_stage_infos) {
480
481
        if ((pinfo_1.is_first_stage() &&
             pinfo_1.last_use_stmt_index == pinfo.original_stmt_index)) {
482
          pinfo_1.order = order_idx++;
483
          pinfo_1.stage = 0; // Copy stages are typically assigned to stage 0
484
        }
485
486
487
      }
    }

488
489
490
491
492
    ICHECK(size_t(order_idx) == pipeline_stage_infos.size())
        << "The number of stages should be equal to the number of pipeline "
           "stages. "
        << "Got " << order_idx << " stages and " << pipeline_stage_infos.size()
        << " pipeline stages.";
493

494
495
    // Step 2. if all the copy is at the end of the order, we can move these
    // copy to the beginning of the order and shrink the stage offset by 1.
496
497
498
499
    int copy_stage_at_end = [&]() {
      int copy_stage_cnt = 0;
      int copy_order_min = pipeline_stage_infos.size();
      int non_copy_order_max = 0;
500
      for (auto &pinfo : pipeline_stage_infos) {
501
        if (pinfo.is_first_stage()) {
502
503
504
505
506
507
          copy_stage_cnt++;
          copy_order_min = std::min(copy_order_min, pinfo.order);
        } else {
          non_copy_order_max = std::max(non_copy_order_max, pinfo.order);
        }
      }
508
509
      if (copy_order_min > non_copy_order_max)
        return copy_stage_cnt;
510
511
512
      return -1;
    }();
    if (copy_stage_at_end > 0 && num_stages >= 2) {
513
514
515
      for (auto &pinfo : pipeline_stage_infos) { // move copy to the beginning
        pinfo.order =
            (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
516
        if (!pinfo.is_copy_stage() && !pinfo.is_producer_for_copy())
517
          pinfo.stage--;
518
519
520
521
      }
    }

    // Finally, make the pipeline annotation
522
    Map<String, Any> annotations;
523
    for (const auto &[key, value] : loop->annotations) {
524
525
526
527
528
529
530
531
      if (key != "num_stages") {
        annotations.Set(key, value);
      }
    }

    std::vector<Integer> orders, stages;
    orders.reserve(pipeline_stage_infos.size());
    stages.reserve(pipeline_stage_infos.size());
532
    for (auto &pinfo : pipeline_stage_infos) {
533
534
535
536
537
538
      orders.push_back(pinfo.order);
      stages.push_back(pinfo.stage);
    }

    annotations.Set(tir::attr::software_pipeline_stage, Array<Integer>(stages));
    annotations.Set(tir::attr::software_pipeline_order, Array<Integer>(orders));
539
    if (TargetHasAsyncCopy(target_) && use_async_copy_)
540
541
      annotations.Set(tir::attr::software_pipeline_async_stages,
                      Array<Integer>{0});
542
543
544
545
546

    return For(loop->loop_var, loop->min, loop->extent, loop->kind, loop->body,
               loop->thread_binding, annotations);
  }

547
548
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
549
550
551
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
552
    for (const auto &buffer : op->alloc_buffers) {
553
554
555
556
557
558
559
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Target target_;
560
  bool use_async_copy_{};
561
562
563
564
};

tvm::transform::Pass PipelinePlanning() {
  using namespace tir::transform;
565
  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
566
567
    bool use_async_copy =
        ctx->GetConfig<Bool>("tir.use_async_copy", Bool(true)).value();
568
    PrimFuncNode *fptr = f.CopyOnWrite();
569
    fptr->body = PipelinePlanner::Substitute(f, use_async_copy);
570
571
572
573
574
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
}

575
576
577
578
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.PipelinePlanning", PipelinePlanning);
});
579

580
581
} // namespace tl
} // namespace tvm