/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file pipeline_planning.cc
 * \brief Plan the software pipeline
 */

#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Check whether two regions have intersections.
 * \param region1 The first region.
 * \param region2 The second region.
 * \return Whether region1 and region2 have intersections.
 */
bool MayConflict(Region region1, Region region2) {
  ICHECK(region1.size() == region2.size());
  // Two regions can only overlap when the ranges of *every* dimension
  // intersect; a single provably-empty intersection rules out a conflict.
  for (size_t dim = 0; dim < region1.size(); ++dim) {
    const auto set_a = arith::IntSet::FromRange(region1[dim]);
    const auto set_b = arith::IntSet::FromRange(region2[dim]);
    if (arith::Intersect({set_a, set_b}).IsNothing())
      return false;
  }
  // No dimension was disjoint, so the regions may touch the same elements.
  return true;
}

/*!
 * \brief Detect if a statement follows the global memory copy pattern:
 *        1. Contains exactly one buffer store operation
 *        2. Source buffer must be in global memory scope
 *        3. Destination buffer must be in local or shared memory scope
 */
class BufferRegionCollector : public StmtExprVisitor {
65
public:
66
67
68
69
70
71
72
73
74
75
  BufferRegionCollector(Map<Var, Buffer> buffer_data_to_buffer)
      : buffer_data_to_buffer_(buffer_data_to_buffer) {}

  Array<BufferRegion> GetReads() const { return reads_; }

  Array<BufferRegion> GetWrites() const { return writes_; }

  bool GetGlobalCopyPattern() const { return is_global_copy_pattern_; }

  PrimExpr GetConditonalExpr() const { return conditonal_expr; }
76
77
78
79

private:
  void VisitStmt_(const BufferStoreNode *op) final {
    Buffer store_buffer = op->buffer;
80
81
82
83
84
85
86
87
88
    Array<PrimExpr> indices = op->indices;
    // convert indices to region
    Array<Range> region;
    for (const auto &index : indices) {
      region.push_back(Range::FromMinExtent(index, 1));
    }
    auto store_region = BufferRegion(store_buffer, region);
    writes_.push_back(store_region);

89
90
91
92
93
94
95
96
97
98
99
    is_global_read_ = false;
    this->VisitExpr(op->value);
    if (is_global_read_ && (store_buffer.scope() == "shared" ||
                            store_buffer.scope() == "shared.dyn" ||
                            store_buffer.scope() == "local")) {
      is_global_copy_pattern_ = true;
    }
    is_global_read_ = false;
  }

  void VisitExpr_(const BufferLoadNode *op) final {
100
101
102
103
104
105
106
107
108
109
    auto load_buffer = op->buffer;
    Array<PrimExpr> indices = op->indices;
    // convert indices to region
    Array<Range> region;
    for (const auto &index : indices) {
      region.push_back(Range::FromMinExtent(index, 1));
    }
    auto load_region = BufferRegion(load_buffer, region);
    reads_.push_back(load_region);

110
111
112
113
114
115
116
    if (op->buffer.scope() == "global") {
      is_global_read_ = true;
    }
  }

  void VisitExpr_(const CallNode *op) final {
    auto args = op->args;
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    if (op->op.same_as(builtin::address_of())) {
      const BufferLoad load = Downcast<BufferLoad>(op->args[0]);
      const BufferRegion buffer_region = BufferRegion::FullRegion(load->buffer);
      // because we only care about the buffer itself instead of indices
      reads_.push_back(buffer_region);
    } else if (op->op.same_as(builtin::tvm_access_ptr())) {
      const VarNode *buffer_var = op->args[1].as<VarNode>();
      ICHECK(buffer_var);
      auto it = buffer_data_to_buffer_.find(GetRef<Var>(buffer_var));
      if (it != buffer_data_to_buffer_.end()) {
        const Buffer &buffer = (*it).second;
        const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
        // because we only care about the buffer itself instead of indices
        reads_.push_back(buffer_region);
      }
    } else if (op->op.same_as(tir::builtin::if_then_else())) {
133
134
135
136
137
138
139
      // Simplify nested if_then_else
      // if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr
      // } } else { else_expr }
      // => if (cond && inner_cond) { inner_then_expr } else { else_expr }
      const PrimExpr &cond = op->args[0];
      const PrimExpr &then_expr = op->args[1];
      const PrimExpr &else_expr = op->args[2];
140
      conditonal_expr = cond;
141
142
      this->VisitExpr(then_expr);
      this->VisitExpr(else_expr);
143
144
    } else {
      StmtExprVisitor::VisitExpr_(op);
145
146
147
148
149
150
    }
  }

  void VisitStmt_(const IfThenElseNode *op) final {
    // Skip condition
    this->VisitStmt(op->then_case);
151
    conditonal_expr = op->condition;
152
153
154
155
156
157
    if (op->else_case.defined()) {
      this->VisitStmt(op->else_case.value());
    }
  }

private:
158
159
160
  Map<Var, Buffer> buffer_data_to_buffer_;
  Array<BufferRegion> reads_;
  Array<BufferRegion> writes_;
161
162
163
  bool is_global_read_ = false;
  bool under_buffer_store_ = false;
  bool is_global_copy_pattern_ = false;
164
  PrimExpr conditonal_expr;
165
166
};

class PipelinePlanner : public StmtExprMutator {
168
169
public:
  static Stmt Substitute(const PrimFunc &f) {
170
    PipelinePlanner substituter;
171
    for (const auto &[_, buffer] : f->buffer_map) {
172
173
174
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
175
176
    ICHECK(target.defined())
        << "Pipeline_Planning: Require the target attribute";
177
178
179
180
    substituter.target_ = target.value();
    return substituter.VisitStmt(f->body);
  }

181
private:
182
183
  PipelinePlanner() = default;

184
185
186
187
188
189
190
191
192
193
194
195
  /*! \brief Information about a pipeline stage
   *
   * \param reads Array of buffer regions read by this stage
   * \param writes Array of buffer regions written by this stage
   * \param original_order Original position of this stage in the pipeline
   * before reordering \param order Current position of this stage in the
   * pipeline after reordering (-1 if not yet assigned) \param stage Pipeline
   * stage number this operation belongs to (-1 if not yet assigned) \param
   * copy_stage Whether this stage is a memory copy operation \param
   * last_use_stage Last pipeline stage that uses the results of this stage (-1
   * if not yet determined)
   */
196
197
198
199
200
  struct PipelineStageInfo {
    Array<BufferRegion> reads, writes;
    int original_order;
    int order = -1, stage = -1;
    bool copy_stage = false;
201
    bool prepare_for_condition = false;
202
    int last_use_stage = -1;
203
204
    // represent the stage is used in a conditional statement
    PrimExpr conditonal_expr;
205
206
207
  };

  PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
208
209
210
211
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ stmt);
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
212
213
    auto collector = BufferRegionCollector(buffer_data_to_buffer_);
    collector(block);
214
    PipelineStageInfo pinfo;
215
216
    pinfo.reads = std::move(collector.GetReads());
    pinfo.writes = std::move(collector.GetWrites());
217
    pinfo.original_order = idx;
218
219
    pinfo.copy_stage = collector.GetGlobalCopyPattern();
    pinfo.conditonal_expr = collector.GetConditonalExpr();
220
221
222
    return std::move(pinfo);
  }

223
  Stmt VisitStmt_(const ForNode *loop) final {
224
225
    auto order_anno = loop->annotations.Get("tl_pipeline_order");
    auto stage_anno = loop->annotations.Get("tl_pipeline_stage");
226
    auto num_stages_anno = loop->annotations.Get("num_stages");
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
    if (order_anno.defined() && stage_anno.defined()) {
      Map<String, ObjectRef> annotations;
      for (const auto &[key, value] : loop->annotations) {
        if (key != "tl_pipeline_order") {
          annotations.Set(key, value);
        }
      }
      annotations.Set(tir::attr::software_pipeline_order, order_anno);

      for (const auto &[key, value] : loop->annotations) {
        if (key != "tl_pipeline_stage") {
          annotations.Set(key, value);
        }
      }
      annotations.Set(tir::attr::software_pipeline_stage, stage_anno);
      if (TargetHasAsyncCopy(target_))
        annotations.Set(tir::attr::software_pipeline_async_stages,
                        Array<Integer>{0});
      auto for_node = GetRef<For>(loop);
      for_node.CopyOnWrite()->annotations = annotations;
      return for_node;
    }

250
251
    if (!num_stages_anno.defined())
      return StmtExprMutator::VisitStmt_(loop);
252
253
    int num_stages = num_stages_anno.as<IntImmNode>()->value;
    Stmt pipeline_body{nullptr};
254
255
256
    if (const auto *realize = loop->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
257
258
259
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
260
261
262
263
264
265
266
267
268
269
270
271
      if (const auto *seq_stmt = block->body.as<SeqStmtNode>()) {
        pipeline_body = block->body;
      } else if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
        // should assert else case is nullptr
        ICHECK(!if_then_else->else_case.defined())
            << "Pipeline_Planning: Can't handle the body of the loop because "
               "it is not a SeqStmt";
        pipeline_body = if_then_else->then_case;
      } else {
        LOG(FATAL) << "Pipeline_Planning: Can't handle the body of the loop "
                      "because it is not a SeqStmt or IfThenElse";
      }
272
273
274
    } else {
      pipeline_body = loop->body;
    }
275
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
276
277
278
279
    CHECK(pipeline_body_seq)
        << "ValueError: The body of the software pipeline "
           "should be SeqStmt, got "
        << pipeline_body->GetTypeKey() << " " << pipeline_body;
280
281
282
283
284
285
286
287
288
    CHECK(num_stages >= 1);
    CHECK(loop->kind == ForKind::kSerial);

    std::vector<PipelineStageInfo> pipeline_stage_infos;
    for (size_t i = 0; i < pipeline_body_seq->size(); i++) {
      auto pinfo = MakePipelineStageInfo(pipeline_body_seq->seq[i], i);
      pipeline_stage_infos.push_back(std::move(pinfo));
    }

289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
    // process the conditional stage
    // assign conditional stage (analysis the copy stage)
    for (auto &pinfo : pipeline_stage_infos) {
      for (const auto &write : pinfo.writes) {
        for (const auto &other : pipeline_stage_infos) {
          if (other.conditonal_expr.defined()) {
            auto check_var = [&](const ObjectRef &n) {
              if (const auto *buffer_load = n.as<BufferLoadNode>()) {
                if (buffer_load->buffer == write->buffer) {
                  pinfo.prepare_for_condition = true;
                }
              }
            };
            PostOrderVisit(other.conditonal_expr, check_var);
          }
        }
      }
    }

308
    // analysis use-def chain
309
310
311
312
313
314
315
316
317
318
319
    for (auto &pinfo : pipeline_stage_infos) {
      for (int i = pinfo.original_order + 1;
           i < static_cast<int>(pipeline_body_seq->size()); i++) {
        if (!pinfo.copy_stage)
          continue;
        for (const BufferRegion &read : pipeline_stage_infos[i].reads) {
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == read->buffer &&
                                    MayConflict(r->region, read->region);
                           }) != pinfo.writes.end()) {
320
321
322
            pinfo.last_use_stage = std::max(pinfo.last_use_stage, i);
          }
        }
323
324
325
326
327
328
        for (const BufferRegion &write : pipeline_stage_infos[i].writes) {
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == write->buffer &&
                                    MayConflict(r->region, write->region);
                           }) != pinfo.writes.end()) {
329
330
331
332
333
334
            LOG(FATAL) << "Pipeline planning error: Multiple writes to "
                          "overlapping buffer regions detected. "
                       << "Stage " << pinfo.original_order << " and stage " << i
                       << " are both writing to buffer '" << write->buffer->name
                       << "' with overlapping regions. This is not supported "
                          "in pipeline planning.";
335
336
337
338
339
340
341
          }
        }
      }
    }

    // Making stages and orders
    int order_idx = 0;
342
    // Create pipeline stages and assign order
343
    for (auto &pinfo : pipeline_stage_infos) {
344
345
346
347
348
      // Skip elements that must be in first stage:
      // 1. Copy stages (with active last_use_stage)
      // 2. Condition preparation stages
      if ((pinfo.copy_stage && pinfo.last_use_stage != -1) ||
          pinfo.prepare_for_condition)
349
        continue;
350

351
352
353
      // Main logic stage assignment:
      // - Increment order index
      // - Assign to new stage (current num_stages)
354
355
      pinfo.order = order_idx++;
      pinfo.stage = num_stages;
356

357
      for (auto &pinfo_1 : pipeline_stage_infos) {
358
359
        if ((pinfo_1.copy_stage &&
             pinfo_1.last_use_stage == pinfo.original_order)) {
360
361
362
363
364
          pinfo_1.order = order_idx++;
          pinfo_1.stage = 0;
        }
      }
    }
365
366
367
368
369

    // Handle trailing unassigned copy stages:
    // These are typically final copy operations needing post-main-stage
    // insertion

370
    auto &head_pinfo = pipeline_stage_infos.at(0);
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
    int unassigned_order_elem = -1;

    // Process dependent copy stages:
    // Insert copy stages after current stage but assign to stage 0
    // and adjust the order index
    for (auto &pinfo : pipeline_stage_infos) {
      if (pinfo.order == unassigned_order_elem) {
        pinfo.order = unassigned_order_elem++;
        // traverse the from the next info
        for (auto it = pipeline_stage_infos.begin() + unassigned_order_elem;
             it != pipeline_stage_infos.end(); it++) {
          it->order += 1;
        }
        pinfo.stage = 0;
        order_idx++;
386
387
388
      }
    }

389
390
391
392
393
    ICHECK(size_t(order_idx) == pipeline_stage_infos.size())
        << "The number of stages should be equal to the number of pipeline "
           "stages. "
        << "Got " << order_idx << " stages and " << pipeline_stage_infos.size()
        << " pipeline stages.";
394

395
396
    // if all the copy is at the end of the order, we can move these copy to the
    // beginning of the order and shrink the stage offset by 1.
397
398
399
400
    int copy_stage_at_end = [&]() {
      int copy_stage_cnt = 0;
      int copy_order_min = pipeline_stage_infos.size();
      int non_copy_order_max = 0;
401
      for (auto &pinfo : pipeline_stage_infos) {
402
403
404
405
406
407
408
        if (pinfo.copy_stage) {
          copy_stage_cnt++;
          copy_order_min = std::min(copy_order_min, pinfo.order);
        } else {
          non_copy_order_max = std::max(non_copy_order_max, pinfo.order);
        }
      }
409
410
      if (copy_order_min > non_copy_order_max)
        return copy_stage_cnt;
411
412
413
      return -1;
    }();
    if (copy_stage_at_end > 0 && num_stages >= 2) {
414
415
416
417
418
      for (auto &pinfo : pipeline_stage_infos) { // move copy to the beginning
        pinfo.order =
            (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
        if (!pinfo.copy_stage)
          pinfo.stage--;
419
420
421
422
423
      }
    }

    // Finally, make the pipeline annotation
    Map<String, ObjectRef> annotations;
424
    for (const auto &[key, value] : loop->annotations) {
425
426
427
428
429
430
431
432
      if (key != "num_stages") {
        annotations.Set(key, value);
      }
    }

    std::vector<Integer> orders, stages;
    orders.reserve(pipeline_stage_infos.size());
    stages.reserve(pipeline_stage_infos.size());
433
    for (auto &pinfo : pipeline_stage_infos) {
434
435
436
437
438
439
440
      orders.push_back(pinfo.order);
      stages.push_back(pinfo.stage);
    }

    annotations.Set(tir::attr::software_pipeline_stage, Array<Integer>(stages));
    annotations.Set(tir::attr::software_pipeline_order, Array<Integer>(orders));
    if (TargetHasAsyncCopy(target_))
441
442
      annotations.Set(tir::attr::software_pipeline_async_stages,
                      Array<Integer>{0});
443
444
445
446
447

    return For(loop->loop_var, loop->min, loop->extent, loop->kind, loop->body,
               loop->thread_binding, annotations);
  }

448
449
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
450
451
452
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
453
    for (const auto &buffer : op->alloc_buffers) {
454
455
456
457
458
459
460
461
462
463
464
465
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Target target_;
};

tvm::transform::Pass PipelinePlanning() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
466
    PrimFuncNode *fptr = f.CopyOnWrite();
467
468
469
470
471
472
    fptr->body = PipelinePlanner::Substitute(f);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
}

473
474
TVM_REGISTER_GLOBAL("tl.transform.PipelinePlanning")
    .set_body_typed(PipelinePlanning);
} // namespace tl
} // namespace tvm