multi_version_buffer_rewriter.cc 10.7 KB
Newer Older
1
2
3
4
5
/*!
 * \file multi_version_buffer_rewriter.cc
 * \brief Multi-version the shared buffers of software-pipelined loops (one
 *        copy per pipeline stage) for CUDA GPUs (sm90+)
 */

6
#include <tvm/ffi/reflection/registry.h>
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Role assigned to a statement of a software-pipeline body.
 */
enum class Role {
  kConsumer, ///< Any statement not recognised as a producer.
  kProducer, ///< A TMA load, or a shared-memory store reading only globals.
  kBoth      ///< A compound statement whose children have differing roles.
};

class WarpSpecializedRoleMarker_ : public StmtVisitor {
23
public:
24
25
26
  WarpSpecializedRoleMarker_(Map<Var, Buffer> buffer_data_to_buffer)
      : buffer_data_to_buffer_(buffer_data_to_buffer) {}

27
  Role GetRole(const StmtNode *stmt) const {
28
29
30
31
32
    auto it = map_.find(stmt);
    ICHECK(it != map_.end());
    return it->second;
  }

33
  Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
34

35
  void VisitStmt_(const EvaluateNode *op) final {
36
37
    Role role = Role::kConsumer;
    if (auto call = op->value.as<CallNode>()) {
38
      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
39
40
41
42
43
44
45
        role = Role::kProducer;
        has_bulk_copy_ = true;
      }
    }
    SetRole(op, role);
  }

46
47
48
  void VisitStmt_(const BufferStoreNode *op) final {
    bool is_shared_store =
        op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
    if (!is_shared_store) {
      SetRole(op, Role::kConsumer);
      return;
    }

    // Check reads from global
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ GetRef<Stmt>(op));
    auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    auto reads = access[0];
    Role role = Role::kProducer;
    for (auto read : reads) {
      if (read->buffer.scope() != "global") {
        role = Role::kConsumer;
        break;
      }
    }
66
67
    if (role == Role::kProducer)
      has_simt_copy_ = true;
68
69
70
    SetRole(op, role);
  }

71
  void VisitStmt_(const SeqStmtNode *op) final {
72
73
74
75
76
77
78
79
80
81
82
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->seq[0]);
    for (auto stmt : op->seq) {
      if (role != GetRole(stmt)) {
        role = Role::kBoth;
        break;
      }
    }
    SetRole(op, role);
  }

83
  void VisitStmt_(const IfThenElseNode *op) final {
84
85
86
87
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->then_case);
    if (op->else_case.defined()) {
      auto role_else = GetRole(op->else_case.value());
88
89
      if (role != role_else)
        role = Role::kBoth;
90
91
92
93
    }
    SetRole(op, role);
  }

94
  void VisitStmt_(const BlockRealizeNode *op) final {
95
96
97
98
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->block));
  }

99
  template <class NodeType> void HandleBodyStmt(const NodeType *op) {
100
101
102
103
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->body));
  }

104
105
106
107
108
  void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
109
110
111
112
113

  bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }

  bool HasSimtCopy() { return has_simt_copy_; }

114
115
private:
  void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
116
  Map<Var, Buffer> buffer_data_to_buffer_;
117
  std::unordered_map<const StmtNode *, Role> map_;
118
119
120
121
122
  bool has_simt_copy_ = false;
  bool has_bulk_copy_ = false;
};

class MultiVersionBufferRewriter : public StmtExprMutator {
123
124
public:
  static PrimFunc Substitute(PrimFunc &f) {
125
126
127
128
129
130
131
132
133
134
    auto rewriter = MultiVersionBufferRewriter();
    rewriter.buffer_lca_ = DetectBufferAccessLCA(f);
    for (auto [buffer, _] : rewriter.buffer_lca_) {
      Var buffer_var = buffer->data;
      rewriter.buffer_data_to_buffer_.Set(buffer_var, buffer);
    }
    f.CopyOnWrite()->body = rewriter(f->body);
    return f;
  }

135
private:
136
137
  MultiVersionBufferRewriter() = default;

138
139
  Array<Buffer> GetVersionedBuffers(Array<Stmt> seq_stmt,
                                    Array<Buffer> scoped_buffers) {
140
141
142
143
144
    std::vector<Role> roles;
    Array<Array<BufferRegion>> reads, writes;
    auto marker = WarpSpecializedRoleMarker_(buffer_data_to_buffer_);
    for (auto stmt : seq_stmt) {
      marker(stmt);
145
146
      Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                  /*name_hint=*/"", /*body*/ stmt);
147
148
149
150
151
152
      auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
      reads.push_back(std::move(access[0]));
      writes.push_back(std::move(access[1]));
      roles.push_back(marker.GetRole(stmt));
    }

153
    std::unordered_set<const BufferNode *> consumer_used, producer_used;
154
155
    for (size_t i = 0; i < seq_stmt.size(); i++) {
      if (roles[i] == Role::kProducer) {
156
157
        for (BufferRegion br : writes[i])
          producer_used.insert(br->buffer.get());
158
      } else {
159
160
        for (BufferRegion br : reads[i])
          consumer_used.insert(br->buffer.get());
161
162
163
164
      }
    }
    Array<Buffer> versioned_buffers;
    for (Buffer buffer : scoped_buffers) {
165
166
      if (consumer_used.count(buffer.get()) &&
          producer_used.count(buffer.get())) {
167
168
169
170
171
172
        versioned_buffers.push_back(buffer);
      }
    }
    return versioned_buffers;
  }

173
  static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
174
175
176
177
178
179
180
181
182
183
    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
    if (new_buffer->strides.size()) {
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

184
185
186
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    BlockRealize block_realize =
        Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
    Block block = block_realize->block;
    Array<Buffer> alloc_buffers;
    for (auto buffer : block->alloc_buffers) {
      if (buffer_remap_.count(buffer)) {
        Buffer new_buffer = buffer_remap_[buffer];
        alloc_buffers.push_back(new_buffer);
      } else {
        alloc_buffers.push_back(buffer);
      }
    }
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    block_realize.CopyOnWrite()->block = block;
    return block_realize;
  }

202
  Stmt VisitStmt_(const ForNode *op) final {
203
    loop_stack_.emplace_back(op->loop_var, op->extent);
204
    auto num_stages_anno = op->annotations.Get("num_stages");
205
    if (!num_stages_anno) {
206
207
208
209
      auto for_node = StmtExprMutator::VisitStmt_(op);
      loop_stack_.pop_back();
      return for_node;
    }
210

211
212
    ICHECK(num_stages_anno->as<IntImmNode>());
    int num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
213

214
215
216
217
    const SeqStmtNode *pipeline_body_seq = op->body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << op->body->GetTypeKey();
218
219
220

    Array<Buffer> scoped_buffers = {};
    for (auto [buffer, stmt] : buffer_lca_) {
221
222
      if (stmt.defined() && stmt.value().get() == op)
        scoped_buffers.push_back(buffer);
223
224
    }

225
226
    Array<Buffer> versioned_buffers =
        GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers);
227
228
229
230
231
232

    for (auto buffer : versioned_buffers) {
      Var buffer_var = buffer->data;
      Buffer new_buffer = RewriteAllocBuffer(buffer, num_stages);
      buffer_remap_.Set(buffer, new_buffer);
    }
233
234
235
236
237
238
    PrimExpr linear_index = loop_stack_[0].first;
    for (size_t i = 1; i < loop_stack_.size(); ++i) {
      linear_index =
          linear_index * loop_stack_[i].second + loop_stack_[i].first;
    }
    version_index_ = FloorMod(linear_index, num_stages);
239
    auto for_node = StmtExprMutator::VisitStmt_(op);
240
    loop_stack_.pop_back();
241
242
243
244

    return for_node;
  }

245
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
246
247
248
249
250
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(load);
    }
251
252
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
253
254
255
256
257
    n->buffer = new_buffer;
    n->indices.insert(n->indices.begin(), version_index_);
    return std::move(load);
  }

258
  Stmt VisitStmt_(const BufferStoreNode *op) final {
259
260
261
262
263
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(store);
    }
264
265
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
266
267
268
269
270
    n->buffer = new_buffer;
    n->indices.insert(n->indices.begin(), version_index_);
    return std::move(store);
  }

271
  PrimExpr VisitExpr_(const CallNode *op) final {
272
273
274
275
276
277
278
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
  }

279
280
281
282
283
284
  PrimExpr RewriteBufferAccess(const Call &call,
                               const std::vector<int> arg_indices) {
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
          [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
          make_const(DataType::Int(32), 1), input);
285
286
287
288
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      auto buffer_var = Downcast<Var>(call->args[i]);
289
290
291
      if (!buffer_data_to_buffer_.count(buffer_var))
        continue;
      const Buffer &buffer = buffer_data_to_buffer_[buffer_var];
292
293
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
294
295
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
296
297
298
299
300
301
302
303
304
305
306
307
308
309
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index = old_index + version_index_ * offset;
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

  PrimExpr version_index_;
310
  std::vector<std::pair<Var, PrimExpr>> loop_stack_;
311
312
313
314
315
316
317
318
319
320
321
322
323
324
  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Optional<Stmt>> buffer_lca_;
  Map<Buffer, Buffer> buffer_remap_;
};

using namespace tir::transform;

/*!
 * \brief Create the PrimFunc pass that runs MultiVersionBufferRewriter.
 * \return A pass named "tl.MultiVersionBuffer" at opt level 0 with no
 *         required passes.
 */
tvm::transform::Pass MultiVersionBuffer() {
  auto rewrite = [](PrimFunc func, IRModule /*mod*/, PassContext /*ctx*/) {
    return MultiVersionBufferRewriter::Substitute(func);
  };
  return CreatePrimFuncPass(rewrite, /*opt_level=*/0, "tl.MultiVersionBuffer",
                            /*required=*/{});
}

325
326
327
328
// Register the pass factory as the global "tl.transform.MultiVersionBuffer".
TVM_FFI_STATIC_INIT_BLOCK({
  tvm::ffi::reflection::GlobalDef().def("tl.transform.MultiVersionBuffer",
                                        MultiVersionBuffer);
});
329

330
331
} // namespace tl
} // namespace tvm