pipeline_planning.cc 14 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file pipeline_planning.cc
 * \brief Plan the software pipeline
 */

#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <vector>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Check whether two regions have intersections.
 * \param region1 The first region.
 * \param region2 The second region.
 * \return Whether region1 and region2 have intersections.
 */
bool MayConflict(Region region1, Region region2) {
  // Regions of different rank cannot be compared dimension by dimension.
  ICHECK(region1.size() == region2.size());
  const size_t ndim = region1.size();
  for (size_t dim = 0; dim < ndim; ++dim) {
    const arith::IntSet lhs = arith::IntSet::FromRange(region1[dim]);
    const arith::IntSet rhs = arith::IntSet::FromRange(region2[dim]);
    // One dimension with an empty intersection is enough to separate the
    // two regions entirely.
    if (arith::Intersect({lhs, rhs}).IsNothing()) {
      return false;
    }
  }
  // No dimension could be shown to be disjoint: assume the regions overlap.
  return true;
}

58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
/*!
 * \brief Detect if a statement follows the global memory copy pattern:
 *        1. Contains exactly one buffer store operation
 *        2. Source buffer must be in global memory scope
 *        3. Destination buffer must be in local or shared memory scope
 */
class GlobalCopyPatternDetector : public StmtExprVisitor {
public:
  static bool Detect(const Stmt &stmt) {
    GlobalCopyPatternDetector detector;
    detector.VisitStmt(stmt);
    return detector.is_global_copy_pattern_;
  }

private:
  void VisitStmt_(const BufferStoreNode *op) final {
    Buffer store_buffer = op->buffer;
    is_global_read_ = false;
    this->VisitExpr(op->value);
    if (is_global_read_ && (store_buffer.scope() == "shared" ||
                            store_buffer.scope() == "shared.dyn" ||
                            store_buffer.scope() == "local")) {
      is_global_copy_pattern_ = true;
    }
    is_global_read_ = false;
  }

  void VisitExpr_(const BufferLoadNode *op) final {
    if (op->buffer.scope() == "global") {
      is_global_read_ = true;
    }
  }

  void VisitExpr_(const CallNode *op) final {
    auto args = op->args;
    if (op->op.same_as(tir::builtin::if_then_else())) {
      // Simplify nested if_then_else
      // if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr
      // } } else { else_expr }
      // => if (cond && inner_cond) { inner_then_expr } else { else_expr }
      const PrimExpr &cond = op->args[0];
      const PrimExpr &then_expr = op->args[1];
      const PrimExpr &else_expr = op->args[2];
      this->VisitExpr(then_expr);
      this->VisitExpr(else_expr);
    }
  }

  void VisitStmt_(const IfThenElseNode *op) final {
    // Skip condition
    this->VisitStmt(op->then_case);
    if (op->else_case.defined()) {
      this->VisitStmt(op->else_case.value());
    }
  }

private:
  bool is_global_read_ = false;
  bool under_buffer_store_ = false;
  bool is_global_copy_pattern_ = false;
};

120
class PipelinePlanner : public StmtExprMutator {
121
122
public:
  static Stmt Substitute(const PrimFunc &f) {
123
    PipelinePlanner substituter;
124
    for (const auto &[_, buffer] : f->buffer_map) {
125
126
127
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
128
129
    ICHECK(target.defined())
        << "Pipeline_Planning: Require the target attribute";
130
131
132
133
    substituter.target_ = target.value();
    return substituter.VisitStmt(f->body);
  }

134
private:
135
136
  PipelinePlanner() = default;

137
138
139
140
141
142
143
144
145
146
147
148
  /*! \brief Information about a pipeline stage
   *
   * \param reads Array of buffer regions read by this stage
   * \param writes Array of buffer regions written by this stage
   * \param original_order Original position of this stage in the pipeline
   * before reordering \param order Current position of this stage in the
   * pipeline after reordering (-1 if not yet assigned) \param stage Pipeline
   * stage number this operation belongs to (-1 if not yet assigned) \param
   * copy_stage Whether this stage is a memory copy operation \param
   * last_use_stage Last pipeline stage that uses the results of this stage (-1
   * if not yet determined)
   */
149
150
151
152
153
154
155
156
157
  struct PipelineStageInfo {
    Array<BufferRegion> reads, writes;
    int original_order;
    int order = -1, stage = -1;
    bool copy_stage = false;
    int last_use_stage = -1;
  };

  PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
158
159
160
161
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ stmt);
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
162

163
164
165
166
    PipelineStageInfo pinfo;
    pinfo.reads = std::move(access[0]);
    pinfo.writes = std::move(access[1]);
    pinfo.original_order = idx;
167
    pinfo.copy_stage = GlobalCopyPatternDetector::Detect(stmt);
168
169
170
171

    return std::move(pinfo);
  }

172
  Stmt VisitStmt_(const ForNode *loop) final {
173
174
    auto order_anno = loop->annotations.Get("tl_pipeline_order");
    auto stage_anno = loop->annotations.Get("tl_pipeline_stage");
175
    auto num_stages_anno = loop->annotations.Get("num_stages");
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
    if (order_anno.defined() && stage_anno.defined()) {
      Map<String, ObjectRef> annotations;
      for (const auto &[key, value] : loop->annotations) {
        if (key != "tl_pipeline_order") {
          annotations.Set(key, value);
        }
      }
      annotations.Set(tir::attr::software_pipeline_order, order_anno);

      for (const auto &[key, value] : loop->annotations) {
        if (key != "tl_pipeline_stage") {
          annotations.Set(key, value);
        }
      }
      annotations.Set(tir::attr::software_pipeline_stage, stage_anno);
      if (TargetHasAsyncCopy(target_))
        annotations.Set(tir::attr::software_pipeline_async_stages,
                        Array<Integer>{0});
      auto for_node = GetRef<For>(loop);
      for_node.CopyOnWrite()->annotations = annotations;
      return for_node;
    }

199
200
    if (!num_stages_anno.defined())
      return StmtExprMutator::VisitStmt_(loop);
201
202
    int num_stages = num_stages_anno.as<IntImmNode>()->value;
    Stmt pipeline_body{nullptr};
203
204
205
    if (const auto *realize = loop->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
206
207
208
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
209
210
211
212
213
214
215
216
217
218
219
220
      if (const auto *seq_stmt = block->body.as<SeqStmtNode>()) {
        pipeline_body = block->body;
      } else if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
        // should assert else case is nullptr
        ICHECK(!if_then_else->else_case.defined())
            << "Pipeline_Planning: Can't handle the body of the loop because "
               "it is not a SeqStmt";
        pipeline_body = if_then_else->then_case;
      } else {
        LOG(FATAL) << "Pipeline_Planning: Can't handle the body of the loop "
                      "because it is not a SeqStmt or IfThenElse";
      }
221
222
223
    } else {
      pipeline_body = loop->body;
    }
224
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
225
226
227
228
    CHECK(pipeline_body_seq)
        << "ValueError: The body of the software pipeline "
           "should be SeqStmt, got "
        << pipeline_body->GetTypeKey() << " " << pipeline_body;
229
230
231
232
233
234
235
236
237
238
    CHECK(num_stages >= 1);
    CHECK(loop->kind == ForKind::kSerial);

    std::vector<PipelineStageInfo> pipeline_stage_infos;
    for (size_t i = 0; i < pipeline_body_seq->size(); i++) {
      auto pinfo = MakePipelineStageInfo(pipeline_body_seq->seq[i], i);
      pipeline_stage_infos.push_back(std::move(pinfo));
    }

    // analysis use-def chain
239
240
241
242
243
244
245
246
247
248
249
    for (auto &pinfo : pipeline_stage_infos) {
      for (int i = pinfo.original_order + 1;
           i < static_cast<int>(pipeline_body_seq->size()); i++) {
        if (!pinfo.copy_stage)
          continue;
        for (const BufferRegion &read : pipeline_stage_infos[i].reads) {
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == read->buffer &&
                                    MayConflict(r->region, read->region);
                           }) != pinfo.writes.end()) {
250
251
252
            pinfo.last_use_stage = std::max(pinfo.last_use_stage, i);
          }
        }
253
254
255
256
257
258
        for (const BufferRegion &write : pipeline_stage_infos[i].writes) {
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == write->buffer &&
                                    MayConflict(r->region, write->region);
                           }) != pinfo.writes.end()) {
259
260
261
262
263
264
            LOG(FATAL) << "Pipeline planning error: Multiple writes to "
                          "overlapping buffer regions detected. "
                       << "Stage " << pinfo.original_order << " and stage " << i
                       << " are both writing to buffer '" << write->buffer->name
                       << "' with overlapping regions. This is not supported "
                          "in pipeline planning.";
265
266
267
268
269
270
271
          }
        }
      }
    }

    // Making stages and orders
    int order_idx = 0;
272
273
274
    for (auto &pinfo : pipeline_stage_infos) {
      if (pinfo.copy_stage && pinfo.last_use_stage != -1)
        continue;
275

276
277
      pinfo.order = order_idx++;
      pinfo.stage = num_stages;
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296

      bool used_by_copy = false;
      for (const auto &write : pinfo.writes) {
        for (const auto &other : pipeline_stage_infos) {
          if (other.copy_stage) {
            for (const auto &read : other.reads) {
              if (write->buffer == read->buffer &&
                  MayConflict(write->region, read->region)) {
                used_by_copy = true;
                break;
              }
            }
          }
        }
      }
      if (used_by_copy) {
        pinfo.stage = 0;
      }

297
298
299
      for (auto &pinfo_1 : pipeline_stage_infos) {
        if (pinfo_1.copy_stage &&
            pinfo_1.last_use_stage == pinfo.original_order) {
300
301
302
303
304
          pinfo_1.order = order_idx++;
          pinfo_1.stage = 0;
        }
      }
    }
305
306
307
308
309
310
311
312
313
314
    // process the tail copy stage
    auto &head_pinfo = pipeline_stage_infos.at(0);
    if (head_pinfo.order == -1) {
      for (auto &pinfo : pipeline_stage_infos) {
        pinfo.order++;
      }
      head_pinfo.stage = 0;
      order_idx++;
    }

315
316
317
318
319
    ICHECK(size_t(order_idx) == pipeline_stage_infos.size())
        << "The number of stages should be equal to the number of pipeline "
           "stages. "
        << "Got " << order_idx << " stages and " << pipeline_stage_infos.size()
        << " pipeline stages.";
320

321
322
    // if all the copy is at the end of the order, we can move these copy to the
    // beginning of the order and shrink the stage offset by 1.
323
324
325
326
    int copy_stage_at_end = [&]() {
      int copy_stage_cnt = 0;
      int copy_order_min = pipeline_stage_infos.size();
      int non_copy_order_max = 0;
327
      for (auto &pinfo : pipeline_stage_infos) {
328
329
330
331
332
333
334
        if (pinfo.copy_stage) {
          copy_stage_cnt++;
          copy_order_min = std::min(copy_order_min, pinfo.order);
        } else {
          non_copy_order_max = std::max(non_copy_order_max, pinfo.order);
        }
      }
335
336
      if (copy_order_min > non_copy_order_max)
        return copy_stage_cnt;
337
338
339
      return -1;
    }();
    if (copy_stage_at_end > 0 && num_stages >= 2) {
340
341
342
343
344
      for (auto &pinfo : pipeline_stage_infos) { // move copy to the beginning
        pinfo.order =
            (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
        if (!pinfo.copy_stage)
          pinfo.stage--;
345
346
347
348
349
      }
    }

    // Finally, make the pipeline annotation
    Map<String, ObjectRef> annotations;
350
    for (const auto &[key, value] : loop->annotations) {
351
352
353
354
355
356
357
358
      if (key != "num_stages") {
        annotations.Set(key, value);
      }
    }

    std::vector<Integer> orders, stages;
    orders.reserve(pipeline_stage_infos.size());
    stages.reserve(pipeline_stage_infos.size());
359
    for (auto &pinfo : pipeline_stage_infos) {
360
361
362
363
364
365
366
      orders.push_back(pinfo.order);
      stages.push_back(pinfo.stage);
    }

    annotations.Set(tir::attr::software_pipeline_stage, Array<Integer>(stages));
    annotations.Set(tir::attr::software_pipeline_order, Array<Integer>(orders));
    if (TargetHasAsyncCopy(target_))
367
368
      annotations.Set(tir::attr::software_pipeline_async_stages,
                      Array<Integer>{0});
369
370
371
372
373

    return For(loop->loop_var, loop->min, loop->extent, loop->kind, loop->body,
               loop->thread_binding, annotations);
  }

374
375
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
376
377
378
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
379
    for (const auto &buffer : op->alloc_buffers) {
380
381
382
383
384
385
386
387
388
389
390
391
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Target target_;
};

tvm::transform::Pass PipelinePlanning() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
392
    PrimFuncNode *fptr = f.CopyOnWrite();
393
394
395
396
397
398
    fptr->body = PipelinePlanner::Substitute(f);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
}

399
400
TVM_REGISTER_GLOBAL("tl.transform.PipelinePlanning")
    .set_body_typed(PipelinePlanning);
401

402
403
} // namespace tl
} // namespace tvm