/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file pipeline_planning.cc
 * \brief Plan the software pipeline
 */

#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <utility>
#include <vector>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

namespace {

/*!
 * \brief Check whether two regions have intersections.
 * \param region1 The first region.
 * \param region2 The second region.
 * \return Whether region1 and region2 have intersections.
 */
bool MayConflict(Region region1, Region region2) {
  ICHECK(region1.size() == region2.size());
  for (size_t i = 0; i < region1.size(); i++) {
    Range dim1 = region1[i];
    Range dim2 = region2[i];
    auto int_set1 = arith::IntSet::FromRange(dim1);
    auto int_set2 = arith::IntSet::FromRange(dim2);
    if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
      return false;
    }
  }
  return true;
}

} // namespace

class PipelinePlanner : public StmtExprMutator {
62
63
public:
  static Stmt Substitute(const PrimFunc &f) {
64
    PipelinePlanner substituter;
65
    for (const auto &[_, buffer] : f->buffer_map) {
66
67
68
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
69
70
    ICHECK(target.defined())
        << "Pipeline_Planning: Require the target attribute";
71
72
73
74
    substituter.target_ = target.value();
    return substituter.VisitStmt(f->body);
  }

75
private:
76
77
78
79
80
81
82
83
84
85
86
  PipelinePlanner() = default;

  struct PipelineStageInfo {
    Array<BufferRegion> reads, writes;
    int original_order;
    int order = -1, stage = -1;
    bool copy_stage = false;
    int last_use_stage = -1;
  };

  PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
87
88
89
90
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ stmt);
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
91
92
93
94
95
96
97
98
    PipelineStageInfo pinfo;
    pinfo.reads = std::move(access[0]);
    pinfo.writes = std::move(access[1]);
    pinfo.original_order = idx;

    // copy stage should only have one reads and one writes
    if (pinfo.reads.size() == 1 && pinfo.writes.size() == 1) {
      for (auto region : pinfo.reads)
99
100
        if (region->buffer.scope() == "global")
          pinfo.copy_stage = true;
101
      for (auto region : pinfo.writes)
102
103
        if (region->buffer.scope() == "global")
          pinfo.copy_stage = true;
104
105
106
107
108
    }

    return std::move(pinfo);
  }

109
  Stmt VisitStmt_(const ForNode *loop) final {
110
    auto num_stages_anno = loop->annotations.Get("num_stages");
111
112
    if (!num_stages_anno.defined())
      return StmtExprMutator::VisitStmt_(loop);
113
114
    int num_stages = num_stages_anno.as<IntImmNode>()->value;
    Stmt pipeline_body{nullptr};
115
116
117
    if (const auto *realize = loop->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
118
119
120
121
122
123
124
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
      pipeline_body = block->body;
    } else {
      pipeline_body = loop->body;
    }
125
126
127
128
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << loop->body->GetTypeKey();
129
130
131
132
133
134
135
136
137
138
    CHECK(num_stages >= 1);
    CHECK(loop->kind == ForKind::kSerial);

    std::vector<PipelineStageInfo> pipeline_stage_infos;
    for (size_t i = 0; i < pipeline_body_seq->size(); i++) {
      auto pinfo = MakePipelineStageInfo(pipeline_body_seq->seq[i], i);
      pipeline_stage_infos.push_back(std::move(pinfo));
    }

    // analysis use-def chain
139
140
141
142
143
144
145
146
147
148
149
    for (auto &pinfo : pipeline_stage_infos) {
      for (int i = pinfo.original_order + 1;
           i < static_cast<int>(pipeline_body_seq->size()); i++) {
        if (!pinfo.copy_stage)
          continue;
        for (const BufferRegion &read : pipeline_stage_infos[i].reads) {
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == read->buffer &&
                                    MayConflict(r->region, read->region);
                           }) != pinfo.writes.end()) {
150
151
152
            pinfo.last_use_stage = std::max(pinfo.last_use_stage, i);
          }
        }
153
154
155
156
157
158
159
160
        for (const BufferRegion &write : pipeline_stage_infos[i].writes) {
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == write->buffer &&
                                    MayConflict(r->region, write->region);
                           }) != pinfo.writes.end()) {
            CHECK(false) << "Can't handle multiple write on overlap buffer "
                            "region in the pipeline "
161
162
163
164
165
166
167
168
169
                            "planning pass: "
                         << pipeline_body_seq->seq[pinfo.original_order];
          }
        }
      }
    }

    // Making stages and orders
    int order_idx = 0;
170
171
172
    for (auto &pinfo : pipeline_stage_infos) {
      if (pinfo.copy_stage && pinfo.last_use_stage != -1)
        continue;
173
174
      pinfo.order = order_idx++;
      pinfo.stage = num_stages;
175
176
177
      for (auto &pinfo_1 : pipeline_stage_infos) {
        if (pinfo_1.copy_stage &&
            pinfo_1.last_use_stage == pinfo.original_order) {
178
179
180
181
182
          pinfo_1.order = order_idx++;
          pinfo_1.stage = 0;
        }
      }
    }
183
184
185
186
187
    ICHECK(size_t(order_idx) == pipeline_stage_infos.size())
        << "The number of stages should be equal to the number of pipeline "
           "stages. "
        << "Got " << order_idx << " stages and " << pipeline_stage_infos.size()
        << " pipeline stages.";
188

189
190
    // if all the copy is at the end of the order, we can move these copy to the
    // beginning of the order and shrink the stage offset by 1.
191
192
193
194
    int copy_stage_at_end = [&]() {
      int copy_stage_cnt = 0;
      int copy_order_min = pipeline_stage_infos.size();
      int non_copy_order_max = 0;
195
      for (auto &pinfo : pipeline_stage_infos) {
196
197
198
199
200
201
202
        if (pinfo.copy_stage) {
          copy_stage_cnt++;
          copy_order_min = std::min(copy_order_min, pinfo.order);
        } else {
          non_copy_order_max = std::max(non_copy_order_max, pinfo.order);
        }
      }
203
204
      if (copy_order_min > non_copy_order_max)
        return copy_stage_cnt;
205
206
207
      return -1;
    }();
    if (copy_stage_at_end > 0 && num_stages >= 2) {
208
209
210
211
212
      for (auto &pinfo : pipeline_stage_infos) { // move copy to the beginning
        pinfo.order =
            (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
        if (!pinfo.copy_stage)
          pinfo.stage--;
213
214
215
216
217
      }
    }

    // Finally, make the pipeline annotation
    Map<String, ObjectRef> annotations;
218
    for (const auto &[key, value] : loop->annotations) {
219
220
221
222
223
224
225
226
      if (key != "num_stages") {
        annotations.Set(key, value);
      }
    }

    std::vector<Integer> orders, stages;
    orders.reserve(pipeline_stage_infos.size());
    stages.reserve(pipeline_stage_infos.size());
227
    for (auto &pinfo : pipeline_stage_infos) {
228
229
230
231
232
233
234
      orders.push_back(pinfo.order);
      stages.push_back(pinfo.stage);
    }

    annotations.Set(tir::attr::software_pipeline_stage, Array<Integer>(stages));
    annotations.Set(tir::attr::software_pipeline_order, Array<Integer>(orders));
    if (TargetHasAsyncCopy(target_))
235
236
      annotations.Set(tir::attr::software_pipeline_async_stages,
                      Array<Integer>{0});
237
238
239
240
241

    return For(loop->loop_var, loop->min, loop->extent, loop->kind, loop->body,
               loop->thread_binding, annotations);
  }

242
243
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
244
245
246
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
247
    for (const auto &buffer : op->alloc_buffers) {
248
249
250
251
252
253
254
255
256
257
258
259
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Target target_;
};

/*!
 * \brief Create the "tl.PipelinePlanning" PrimFunc pass.
 * \return A pass that replaces each function body with the result of
 *         PipelinePlanner::Substitute.
 */
tvm::transform::Pass PipelinePlanning() {
  using namespace tir::transform;
  const auto plan_func = [](PrimFunc func, IRModule /*mod*/,
                            PassContext /*ctx*/) -> PrimFunc {
    // Obtain a writable node first, then install the planned body on it.
    PrimFuncNode *writable = func.CopyOnWrite();
    writable->body = PipelinePlanner::Substitute(func);
    return func;
  };
  return CreatePrimFuncPass(plan_func, /*opt_level=*/0, "tl.PipelinePlanning",
                            /*required=*/{});
}

267
268
TVM_REGISTER_GLOBAL("tl.transform.PipelinePlanning")
    .set_body_typed(PipelinePlanning);
269

} // namespace tl
} // namespace tvm