/*!
 * \file pipeline_planning.cc
 * \brief Software-pipeline planning pass for TileLang loops.
 */
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Conservatively decide whether two buffer regions may overlap.
 * \param region1 The first region.
 * \param region2 The second region.
 * \return true when the regions may intersect; false only when at least one
 *         dimension is provably disjoint.
 */
bool MayConflict(Region region1, Region region2) {
  ICHECK(region1.size() == region2.size());
  for (size_t dim = 0; dim < region1.size(); ++dim) {
    auto lhs = arith::IntSet::FromRange(region1[dim]);
    auto rhs = arith::IntSet::FromRange(region2[dim]);
    // An empty intersection along any single dimension is enough to prove
    // the two regions are disjoint.
    if (arith::Intersect({lhs, rhs}).IsNothing()) {
      return false;
    }
  }
  return true;
}

35
36
37
38
39
40
/*!
 * \brief Detect if a statement follows the global memory copy pattern:
 *        1. Contains exactly one buffer store operation
 *        2. Source buffer must be in global memory scope
 *        3. Destination buffer must be in local or shared memory scope
 */
41
class BufferRegionCollector : public StmtExprVisitor {
42
public:
43
44
45
46
47
48
49
50
51
52
  BufferRegionCollector(Map<Var, Buffer> buffer_data_to_buffer)
      : buffer_data_to_buffer_(buffer_data_to_buffer) {}

  Array<BufferRegion> GetReads() const { return reads_; }

  Array<BufferRegion> GetWrites() const { return writes_; }

  bool GetGlobalCopyPattern() const { return is_global_copy_pattern_; }

  PrimExpr GetConditonalExpr() const { return conditonal_expr; }
53
54
55
56

private:
  void VisitStmt_(const BufferStoreNode *op) final {
    Buffer store_buffer = op->buffer;
57
58
59
60
61
62
63
64
65
    Array<PrimExpr> indices = op->indices;
    // convert indices to region
    Array<Range> region;
    for (const auto &index : indices) {
      region.push_back(Range::FromMinExtent(index, 1));
    }
    auto store_region = BufferRegion(store_buffer, region);
    writes_.push_back(store_region);

66
67
68
    is_global_read_ = false;
    this->VisitExpr(op->value);
    if (is_global_read_ && (store_buffer.scope() == "shared" ||
69
                            store_buffer.scope() == "shared.dyn")) {
70
71
72
73
74
75
      is_global_copy_pattern_ = true;
    }
    is_global_read_ = false;
  }

  void VisitExpr_(const BufferLoadNode *op) final {
76
77
78
79
80
81
82
83
84
85
    auto load_buffer = op->buffer;
    Array<PrimExpr> indices = op->indices;
    // convert indices to region
    Array<Range> region;
    for (const auto &index : indices) {
      region.push_back(Range::FromMinExtent(index, 1));
    }
    auto load_region = BufferRegion(load_buffer, region);
    reads_.push_back(load_region);

86
87
88
89
90
91
92
    if (op->buffer.scope() == "global") {
      is_global_read_ = true;
    }
  }

  void VisitExpr_(const CallNode *op) final {
    auto args = op->args;
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
    if (op->op.same_as(builtin::address_of())) {
      const BufferLoad load = Downcast<BufferLoad>(op->args[0]);
      const BufferRegion buffer_region = BufferRegion::FullRegion(load->buffer);
      // because we only care about the buffer itself instead of indices
      reads_.push_back(buffer_region);
    } else if (op->op.same_as(builtin::tvm_access_ptr())) {
      const VarNode *buffer_var = op->args[1].as<VarNode>();
      ICHECK(buffer_var);
      auto it = buffer_data_to_buffer_.find(GetRef<Var>(buffer_var));
      if (it != buffer_data_to_buffer_.end()) {
        const Buffer &buffer = (*it).second;
        const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
        // because we only care about the buffer itself instead of indices
        reads_.push_back(buffer_region);
      }
    } else if (op->op.same_as(tir::builtin::if_then_else())) {
109
110
111
112
113
114
115
      // Simplify nested if_then_else
      // if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr
      // } } else { else_expr }
      // => if (cond && inner_cond) { inner_then_expr } else { else_expr }
      const PrimExpr &cond = op->args[0];
      const PrimExpr &then_expr = op->args[1];
      const PrimExpr &else_expr = op->args[2];
116
      conditonal_expr = cond;
117
118
      this->VisitExpr(then_expr);
      this->VisitExpr(else_expr);
119
120
    } else {
      StmtExprVisitor::VisitExpr_(op);
121
122
123
124
125
126
    }
  }

  void VisitStmt_(const IfThenElseNode *op) final {
    // Skip condition
    this->VisitStmt(op->then_case);
127
    conditonal_expr = op->condition;
128
129
130
131
132
133
    if (op->else_case.defined()) {
      this->VisitStmt(op->else_case.value());
    }
  }

private:
134
135
136
  Map<Var, Buffer> buffer_data_to_buffer_;
  Array<BufferRegion> reads_;
  Array<BufferRegion> writes_;
137
138
139
  bool is_global_read_ = false;
  bool under_buffer_store_ = false;
  bool is_global_copy_pattern_ = false;
140
  PrimExpr conditonal_expr;
141
142
};

/*!
 * \brief Mutator that turns TileLang pipeline annotations ("num_stages",
 *        "tl_pipeline_order", "tl_pipeline_stage") on serial loops into the
 *        standard TIR software-pipeline annotations, computing a stage/order
 *        assignment that overlaps global->shared copies with compute.
 */
class PipelinePlanner : public StmtExprMutator {
public:
  /*!
   * \brief Run pipeline planning over the body of \p f.
   * \param f The PrimFunc; must carry a kTarget attribute.
   * \param use_async_copy Whether async-copy stages may be annotated when the
   *        target supports asynchronous copies.
   * \return The transformed function body.
   */
  static Stmt Substitute(const PrimFunc &f, bool use_async_copy = true) {
    PipelinePlanner substituter(use_async_copy);
    // Seed the var->buffer map with the function's I/O buffers.
    for (const auto &[_, buffer] : f->buffer_map) {
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "Pipeline_Planning: Require the target attribute";
    substituter.target_ = target.value();
    return substituter.VisitStmt(f->body);
  }

private:
  PipelinePlanner() = default;
  PipelinePlanner(bool use_async_copy) : use_async_copy_(use_async_copy) {}

  /*! \brief Information about a pipeline stage
   *
   * \param reads Array of buffer regions read by this stage
   * \param writes Array of buffer regions written by this stage
   * \param original_order Original position of this stage in the pipeline
   * before reordering \param order Current position of this stage in the
   * pipeline after reordering (-1 if not yet assigned) \param stage Pipeline
   * stage number this operation belongs to (-1 if not yet assigned) \param
   * copy_stage Whether this stage is a memory copy operation \param
   * last_use_stage Last pipeline stage that uses the results of this stage (-1
   * if not yet determined)
   */
  struct PipelineStageInfo {
    Array<BufferRegion> reads, writes;
    int original_order;
    int order = -1, stage = -1;
    bool copy_stage = false;
    // True when this stage writes a buffer that appears in some other
    // stage's condition expression.
    bool prepare_for_condition = false;
    int last_use_stage = -1;
    // represent the stage is used in a conditional statement
    PrimExpr conditonal_expr;
  };

  /*!
   * \brief Analyze one statement of the pipelined SeqStmt and build its
   *        PipelineStageInfo (reads/writes, copy-pattern flag, condition).
   * \param stmt The stage body.
   * \param idx The stage's position in the original SeqStmt.
   */
  PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
    // Wrap in a block so read/write region analysis can run on it.
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ stmt);
    // NOTE(review): `access` is computed but never used below — the collector
    // results are used instead; confirm whether this call is still needed.
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    auto collector = BufferRegionCollector(buffer_data_to_buffer_);
    collector(block);
    PipelineStageInfo pinfo;
    pinfo.reads = std::move(collector.GetReads());
    pinfo.writes = std::move(collector.GetWrites());
    pinfo.original_order = idx;
    pinfo.copy_stage = collector.GetGlobalCopyPattern();
    pinfo.conditonal_expr = collector.GetConditonalExpr();
    return std::move(pinfo);
  }

  Stmt VisitStmt_(const ForNode *loop) final {
    auto order_anno = loop->annotations.Get("tl_pipeline_order");
    auto stage_anno = loop->annotations.Get("tl_pipeline_stage");
    auto num_stages_anno = loop->annotations.Get("num_stages");
    // Case 1: user supplied an explicit order/stage assignment — translate
    // the TileLang annotations into the standard TIR ones verbatim.
    if (order_anno && stage_anno) {
      // Check if order_anno or stage_anno contains -1, which means TMA+WS is
      // enabled
      bool ws_tma_enabled = false;
      auto order_array = Downcast<Array<Integer>>(order_anno.value());
      auto stage_array = Downcast<Array<Integer>>(stage_anno.value());
      for (const auto &val : order_array) {
        if (val->value == -1) {
          ws_tma_enabled = true;
          break;
        }
      }
      if (!ws_tma_enabled) {
        for (const auto &val : stage_array) {
          if (val->value == -1) {
            ws_tma_enabled = true;
            break;
          }
        }
      }

      // With warp-specialized TMA the annotations are handled elsewhere;
      // leave the loop untouched.
      if (ws_tma_enabled) {
        return StmtExprMutator::VisitStmt_(loop);
      }

      // Copy all annotations except the TileLang-specific ones, which are
      // re-emitted under their tir::attr names.
      Map<String, Any> annotations;
      for (const auto &[key, value] : loop->annotations) {
        if (key != "tl_pipeline_order") {
          annotations.Set(key, value);
        }
      }
      annotations.Set(tir::attr::software_pipeline_order, order_anno.value());

      for (const auto &[key, value] : loop->annotations) {
        if (key != "tl_pipeline_stage") {
          annotations.Set(key, value);
        }
      }
      annotations.Set(tir::attr::software_pipeline_stage, stage_anno.value());
      if (TargetHasAsyncCopy(target_) && use_async_copy_)
        annotations.Set(tir::attr::software_pipeline_async_stages,
                        Array<Integer>{0});
      auto for_node = GetRef<For>(loop);
      for_node.CopyOnWrite()->annotations = annotations;
      return for_node;
    }

    // Case 2: automatic planning, driven by the "num_stages" annotation.
    if (!num_stages_anno)
      return StmtExprMutator::VisitStmt_(loop);
    int num_stages = num_stages_anno->as<IntImmNode>()->value;
    Stmt pipeline_body{nullptr};
    if (const auto *realize = loop->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
      if (const auto *seq_stmt = block->body.as<SeqStmtNode>()) {
        pipeline_body = block->body;
      } else if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
        // should assert else case is nullptr
        ICHECK(!if_then_else->else_case.defined())
            << "Pipeline_Planning: Can't handle the body of the loop because "
               "it is not a SeqStmt";
        pipeline_body = if_then_else->then_case;
      } else {
        LOG(FATAL) << "Pipeline_Planning: Can't handle the body of the loop "
                      "because it is not a SeqStmt or IfThenElse";
      }
    } else {
      pipeline_body = loop->body;
    }
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq)
        << "ValueError: The body of the software pipeline "
           "should be SeqStmt, got "
        << pipeline_body->GetTypeKey() << " " << pipeline_body;
    CHECK(num_stages >= 1);
    CHECK(loop->kind == ForKind::kSerial);

    // Analyze each statement of the sequence as a candidate pipeline stage.
    std::vector<PipelineStageInfo> pipeline_stage_infos;
    for (size_t i = 0; i < pipeline_body_seq->size(); i++) {
      auto pinfo = MakePipelineStageInfo(pipeline_body_seq->seq[i], i);
      pipeline_stage_infos.push_back(std::move(pinfo));
    }

    // process the conditional stage
    // assign conditional stage (analysis the copy stage)
    // A stage that writes a buffer loaded inside any other stage's condition
    // must be scheduled with the copies (stage 0).
    for (auto &pinfo : pipeline_stage_infos) {
      for (const auto &write : pinfo.writes) {
        for (const auto &other : pipeline_stage_infos) {
          if (other.conditonal_expr.defined()) {
            auto check_var = [&](const ObjectRef &n) {
              if (const auto *buffer_load = n.as<BufferLoadNode>()) {
                if (buffer_load->buffer == write->buffer) {
                  pinfo.prepare_for_condition = true;
                }
              }
            };
            PostOrderVisit(other.conditonal_expr, check_var);
          }
        }
      }
    }

    // analysis use-def chain
    // For each copy stage, find the last later stage that reads a region the
    // copy writes (its last use). Overlapping writes between stages are a
    // hard error since the planner cannot order them safely.
    for (auto &pinfo : pipeline_stage_infos) {
      for (int i = pinfo.original_order + 1;
           i < static_cast<int>(pipeline_body_seq->size()); i++) {
        if (!pinfo.copy_stage)
          continue;
        for (const BufferRegion &read : pipeline_stage_infos[i].reads) {
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == read->buffer &&
                                    MayConflict(r->region, read->region);
                           }) != pinfo.writes.end()) {
            pinfo.last_use_stage = std::max(pinfo.last_use_stage, i);
          }
        }
        for (const BufferRegion &write : pipeline_stage_infos[i].writes) {
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == write->buffer &&
                                    MayConflict(r->region, write->region);
                           }) != pinfo.writes.end()) {
            LOG(FATAL) << "Pipeline planning error: Multiple writes to "
                          "overlapping buffer regions detected. "
                       << "Stage " << pinfo.original_order << " and stage " << i
                       << " are both writing to buffer '" << write->buffer->name
                       << "' with overlapping regions. This is not supported "
                          "in pipeline planning.";
          }
        }
      }
    }

    // Making stages and orders
    int order_idx = 0;
    // Create pipeline stages and assign order
    for (auto &pinfo : pipeline_stage_infos) {
      // Skip elements that must be in first stage:
      // 1. Copy stages (with active last_use_stage)
      // 2. Condition preparation stages
      if ((pinfo.copy_stage && pinfo.last_use_stage != -1) ||
          pinfo.prepare_for_condition)
        continue;

      // Main logic stage assignment:
      // - Increment order index
      // - Assign to new stage (current num_stages)
      pinfo.order = order_idx++;
      pinfo.stage = num_stages;

      // Place each copy stage whose last consumer is this stage right after
      // it in the order, but in stage 0 so the copy runs ahead.
      for (auto &pinfo_1 : pipeline_stage_infos) {
        if ((pinfo_1.copy_stage &&
             pinfo_1.last_use_stage == pinfo.original_order)) {
          pinfo_1.order = order_idx++;
          pinfo_1.stage = 0;
        }
      }
    }

    // Handle trailing unassigned copy stages:
    // These are typically final copy operations needing post-main-stage
    // insertion
    // NOTE(review): head_pinfo appears unused below — confirm it can be
    // removed (it also hard-fails on an empty stage list via .at(0)).
    auto &head_pinfo = pipeline_stage_infos.at(0);
    int unassigned_order_elem = -1;

    // Process dependent copy stages:
    // Insert copy stages after current stage but assign to stage 0
    // and adjust the order index
    // NOTE(review): this matches stages whose order is still -1, assigns
    // order via `unassigned_order_elem++` (writing -1 first) and then shifts
    // every stage from index `unassigned_order_elem` up by one — the net
    // effect depends on container order; verify against multi-unassigned
    // cases.
    for (auto &pinfo : pipeline_stage_infos) {
      if (pinfo.order == unassigned_order_elem) {
        pinfo.order = unassigned_order_elem++;
        // traverse the from the next info
        for (auto it = pipeline_stage_infos.begin() + unassigned_order_elem;
             it != pipeline_stage_infos.end(); it++) {
          it->order += 1;
        }
        pinfo.stage = 0;
        order_idx++;
      }
    }

    ICHECK(size_t(order_idx) == pipeline_stage_infos.size())
        << "The number of stages should be equal to the number of pipeline "
           "stages. "
        << "Got " << order_idx << " stages and " << pipeline_stage_infos.size()
        << " pipeline stages.";

    // if all the copy is at the end of the order, we can move these copy to the
    // beginning of the order and shrink the stage offset by 1.
    int copy_stage_at_end = [&]() {
      int copy_stage_cnt = 0;
      int copy_order_min = pipeline_stage_infos.size();
      int non_copy_order_max = 0;
      for (auto &pinfo : pipeline_stage_infos) {
        if (pinfo.copy_stage || pinfo.prepare_for_condition) {
          copy_stage_cnt++;
          copy_order_min = std::min(copy_order_min, pinfo.order);
        } else {
          non_copy_order_max = std::max(non_copy_order_max, pinfo.order);
        }
      }
      // All copy orders strictly follow every non-copy order.
      if (copy_order_min > non_copy_order_max)
        return copy_stage_cnt;
      return -1;
    }();
    if (copy_stage_at_end > 0 && num_stages >= 2) {
      for (auto &pinfo : pipeline_stage_infos) { // move copy to the beginning
        // Rotate orders so trailing copies wrap to the front; non-copy
        // stages drop one stage to keep the pipeline depth tight.
        pinfo.order =
            (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
        if (!pinfo.copy_stage && !pinfo.prepare_for_condition)
          pinfo.stage--;
      }
    }

    // Finally, make the pipeline annotation
    Map<String, Any> annotations;
    for (const auto &[key, value] : loop->annotations) {
      if (key != "num_stages") {
        annotations.Set(key, value);
      }
    }

    std::vector<Integer> orders, stages;
    orders.reserve(pipeline_stage_infos.size());
    stages.reserve(pipeline_stage_infos.size());
    for (auto &pinfo : pipeline_stage_infos) {
      orders.push_back(pinfo.order);
      stages.push_back(pinfo.stage);
    }

    annotations.Set(tir::attr::software_pipeline_stage, Array<Integer>(stages));
    annotations.Set(tir::attr::software_pipeline_order, Array<Integer>(orders));
    if (TargetHasAsyncCopy(target_) && use_async_copy_)
      annotations.Set(tir::attr::software_pipeline_async_stages,
                      Array<Integer>{0});

    return For(loop->loop_var, loop->min, loop->extent, loop->kind, loop->body,
               loop->thread_binding, annotations);
  }

  Stmt VisitStmt_(const BlockNode *op) final {
    // Track buffers allocated by this block while visiting its body, then
    // drop them so the map only ever holds in-scope buffers.
    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }

  // Maps buffer data vars to their Buffer objects for region analysis.
  Map<Var, Buffer> buffer_data_to_buffer_;
  Target target_;
  bool use_async_copy_;
};

/*!
 * \brief Build the TileLang pipeline-planning pass.
 *
 * The pass rewrites each PrimFunc body via PipelinePlanner::Substitute,
 * honoring the "tir.use_async_copy" pass-context config (default: true).
 * \return A PrimFunc-level transform pass named "tl.PipelinePlanning".
 */
tvm::transform::Pass PipelinePlanning() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc func, IRModule mod, PassContext ctx) {
    const bool use_async_copy =
        ctx->GetConfig<Bool>("tir.use_async_copy", Bool(true)).value();
    PrimFuncNode *node = func.CopyOnWrite();
    node->body = PipelinePlanner::Substitute(func, use_async_copy);
    return func;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
}

// Register the pass constructor with the FFI reflection registry so it is
// reachable from the frontend as "tl.transform.PipelinePlanning".
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.PipelinePlanning", PipelinePlanning);
});
} // namespace tl
} // namespace tvm