layout_inference.cc 9.77 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file layout_inference.cc
 * \brief infer the fragment/shared memory layout
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../op/parallel.h"
34
35
#include "arith/ir_mutator_with_analyzer.h"
#include "common/loop_fusion_utils.h"
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include "loop_partition.h"
#include "loop_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRMutatorWithAnalyzer;

/*!
 * \brief Result of BufferUseDefCollector::Run(), consumed by LayoutInferencer.
 *
 * NOTE: member order matters — the struct is built by aggregate
 * (brace) initialisation `{layout_map, for_map, predicate_map}` in
 * BufferUseDefCollector::Run().
 */
struct LayoutInferenceResult {
  // Layout for each buffer: the Block-annotated layouts plus all inferred ones.
  Map<Buffer, Layout> layout_map;
  // Loop layout of each ParallelOp, keyed on the loop's root For node.
  Map<For, Fragment> for_map;
  // Thread predicate for ParallelOp loops whose GetPredicate() returned one.
  Map<For, PrimExpr> predicate_map;
};

class BufferUseDefCollector : public StmtExprVisitor {
52
public:
53
54
55
56
57
58
59
60
61
  BufferUseDefCollector() = default;

  LayoutInferenceResult Run() {
    Map<Buffer, Layout> layout_map = annotated_layout_map_;
    int num_infer = infer_list_.size();

    // maintain a bfs queue and infer common layout
    std::queue<int> q;
    std::vector<bool> in_queue(num_infer, true);
62
63
    for (int i = 0; i < num_infer; i++)
      q.push(i);
64

65
66
67
    auto run_infer_step = [&](int cur_infer_id, InferLevel level,
                              bool update_queue) {
      auto &next = infer_list_[cur_infer_id];
68
69
      auto iter_var = thread_var_vec_[cur_infer_id];
      auto updates = next->InferLayout(
70
71
72
73
          LayoutInferArgs{
              target_,
              static_cast<size_t>(*as_const_int(iter_var->dom->extent)),
              layout_map},
74
          level);
75
      for (const auto &[buffer, layout] : updates) {
76
77
78
79
80
        if (layout_map.count(buffer)) {
          ICHECK(StructuralEqual()(layout, layout_map[buffer]))
              << "Get different layout for " << buffer;
        } else {
          layout_map.Set(buffer, layout);
81
82
          if (!update_queue)
            continue;
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
          for (int idx : use_list_[buffer]) {
            if (!in_queue[idx] && idx != cur_infer_id) {
              in_queue[idx] = true;
              q.push(idx);
            }
          }
        }
      }
    };
    auto finish_infer_queue = [&]() {
      while (!q.empty()) {
        int cur_infer_id = q.front();
        q.pop();
        in_queue[cur_infer_id] = false;
        run_infer_step(cur_infer_id, InferLevel::kCommon, true);
      }
    };

    // step 1, infer strict layout
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kStrict, false);
    }

    // step2, infer common layout with bfs
    finish_infer_queue();

    // step 3, relax the infer constraint to free and rerun.
    for (int i = 0; i < num_infer; i++) {
      run_infer_step(i, InferLevel::kFree, true);
      finish_infer_queue();
    }

    // Check that all fragments have been inferred
116
    for (const auto &[buffer, _] : use_list_) {
117
      if (buffer.scope() == "local.fragment" && layout_map.count(buffer) == 0)
118
119
        LOG_ERROR << "The layout for fragment " << buffer
                  << " can not be inferred correctly.";
120
121
122
123
124
    }

    // Collect the layout for for nodes
    Map<For, Fragment> for_map;
    Map<For, PrimExpr> predicate_map;
125
126
    for (auto &base_infer : infer_list_) {
      if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
127
128
129
130
131
132
133
134
135
136
137
138
        ICHECK(for_infer->GetLoopLayout().defined())
            << "The Layout for Parallel for can not be inferred correctly : \n"
            << for_infer->GetRoot();
        for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
        if (auto predicate = for_infer->GetPredicate(thread_var_->var))
          predicate_map.Set(for_infer->GetRoot(), predicate.value());
      }
    }

    return {layout_map, for_map, predicate_map};
  }

139
140
  void Collect(const PrimFunc &f) {
    for (const auto &[_, buffer] : f->buffer_map) {
141
142
143
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
144
145
    ICHECK(target.defined())
        << "Layout_Inference: Require the target attribute";
146
147
148
149
    target_ = target.value();
    this->operator()(f->body);
  }

150
151
private:
  void VisitExpr_(const CallNode *op) final {
152
153
    StmtExprVisitor::VisitExpr_(op);
    // Do not analysis the call node to the global function.
154
155
    if (op->op.as<GlobalVarNode>())
      return;
156
157
158

    auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
    if (p != nullptr) {
159
      for (const auto &arg : op->args) {
160
161
162
163
164
165
166
167
168
        if (auto buffer = getBufferFromAccessPtr(arg)) {
          addToUseList(buffer.value());
        }
      }
      infer_list_.push_back(std::move(p));
      thread_var_vec_.push_back(thread_var_);
    }
  }

169
  Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
170
171
172
173
174
175
176
177
    auto call = expr.as<CallNode>();
    if (call && call->op.same_as(builtin::tvm_access_ptr())) {
      auto var = call->args[1].as<Var>().value();
      return buffer_data_to_buffer_[var];
    }
    return NullOpt;
  }

178
  void addToUseList(const Buffer &buffer) {
179
180
181
182
183
184
185
    int infer_idx = infer_list_.size();
    if (use_list_.find(buffer) == use_list_.end()) {
      use_list_[buffer] = {};
    }
    use_list_[buffer].push_back(infer_idx);
  }

186
  void VisitStmt_(const ForNode *op) final {
187
188
    if (op->kind == ForKind::kParallel) {
      auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
189
      for (const auto &[buffer, _] : infer->GetIndiceMap()) {
190
191
192
193
194
195
196
197
198
        addToUseList(buffer);
      }
      infer_list_.push_back(std::move(infer));
      thread_var_vec_.push_back(thread_var_);
    } else {
      StmtExprVisitor::VisitStmt(op->body);
    }
  }

199
  void VisitStmt_(const BlockNode *op) final {
200
201
202
203
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    if (op->annotations.count(attr::kLayoutMap)) {
204
205
206
      auto map =
          op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
      for (const auto &[var, layout] : map) {
207
208
209
210
211
212
213
214
        auto buffer = buffer_data_to_buffer_[var];
        ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
        annotated_layout_map_.Set(buffer, layout);
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

215
  void VisitStmt_(const AttrStmtNode *op) final {
216
217
218
219
220
221
222
223
224
225
226
227
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_var_ = iv;
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::vector<std::unique_ptr<Operator>> infer_list_;
228
229
  std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
      use_list_;
230
231
232
233
234
235
236
  IterVar thread_var_;
  std::vector<IterVar> thread_var_vec_;
  Target target_;
  LayoutMap annotated_layout_map_;
};

class LayoutInferencer : public IRMutatorWithAnalyzer {
237
public:
238
239
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
240
    PrimFuncNode *fptr = f.CopyOnWrite();
241
242
243
244
245
246
247
248
249
    fptr->body = ParallelLoopFuser::Fuse(f->body);
    BufferUseDefCollector collector;
    collector.Collect(f);
    auto result = collector.Run();
    LayoutInferencer substituter(result, &analyzer);
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

250
251
252
253
private:
  LayoutInferencer(const LayoutInferenceResult result,
                   arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer), result_(result){};
254

255
  Stmt VisitStmt_(const BlockNode *op) final {
256
257
258
259
260
261
262
263
264
265
266
267
268
    Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));

    for (auto buffer : block->alloc_buffers) {
      if (buffer.scope() == "local.framgent") {
        ICHECK(result_.layout_map.count(buffer))
            << "Cannot inference fragment layout for " << buffer;
      }
    }
    auto block_ptr = block.CopyOnWrite();
    block_ptr->annotations.Set(attr::kLayoutMap, result_.layout_map);
    return block;
  }

269
  Stmt VisitStmt_(const ForNode *op) final {
270
271
272
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (result_.for_map.count(GetRef<For>(op))) {
      auto loop_layout = result_.for_map[GetRef<For>(op)];
273
274
      for_node =
          PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
275
276
277
278
279
280
281
282
283
284
      for_node = VectorizeLoop(for_node);
      if (result_.predicate_map.count(GetRef<For>(op))) {
        return IfThenElse(result_.predicate_map[GetRef<For>(op)], for_node);
      } else {
        return for_node;
      }
    }
    return for_node;
  }

285
  Stmt VisitStmt_(const AttrStmtNode *op) final {
286
287
288
289
290
291
292
293
294
295
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
      }
    }
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }

296
private:
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
  const LayoutInferenceResult result_;
  IterVar thread_var_;
};

/*!
 * \brief Build the `tl.LayoutInference` PrimFunc pass.
 *
 * Each PrimFunc is rewritten by LayoutInferencer::Substitute; the enclosing
 * IRModule and the PassContext are not consulted.
 */
tvm::transform::Pass LayoutInference() {
  using namespace tir::transform;
  auto infer_layouts = [](PrimFunc func, IRModule /*mod*/,
                          PassContext /*pass_ctx*/) -> PrimFunc {
    return LayoutInferencer::Substitute(std::move(func));
  };
  constexpr int kOptLevel = 0;
  return CreatePrimFuncPass(infer_layouts, kOptLevel, "tl.LayoutInference",
                            {});
}

TVM_REGISTER_GLOBAL("tl.transform.LayoutInference")
    .set_body_typed(LayoutInference);

312
313
} // namespace tl
} // namespace tvm