"testing/vscode:/vscode.git/clone" did not exist on "e59e7f9adc570ca5c7330b418df8a0e867e58d32"
multi_version_buffer_rewriter.cc 10.8 KB
Newer Older
1
2
3
4
5
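/*!
 * \file multi_version_buffer_rewriter.cc
 * \brief Multi-version buffer rewriting for software-pipelined loops on CUDA
 *        GPUs (sm90+): buffers written by producer statements and read by
 *        consumer statements of a "num_stages"-annotated loop get an extra
 *        leading version dimension.
 */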

6
#include <tvm/ffi/reflection/registry.h>
7
8
9
10
11
12
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

13
14
#include <utility>

15
16
17
18
19
20
21
#include "../op/builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Role of a statement inside a pipelined loop body, as assigned by
 *        WarpSpecializedRoleMarker_.
 */
enum class Role : uint8_t {
  // Any statement not recognized as a producer.
  kConsumer,
  // A TMA load, or a SIMT copy that stores into shared memory and only reads
  // global buffers.
  kProducer,
  // A compound statement whose children do not all share one role.
  kBoth
};

class WarpSpecializedRoleMarker_ : public StmtVisitor {
25
public:
26
  WarpSpecializedRoleMarker_(Map<Var, Buffer> buffer_data_to_buffer)
27
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {}
28

29
  Role GetRole(const StmtNode *stmt) const {
30
31
32
33
34
    auto it = map_.find(stmt);
    ICHECK(it != map_.end());
    return it->second;
  }

35
  Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
36

37
  void VisitStmt_(const EvaluateNode *op) final {
38
39
    Role role = Role::kConsumer;
    if (auto call = op->value.as<CallNode>()) {
40
      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
41
42
43
44
45
46
47
        role = Role::kProducer;
        has_bulk_copy_ = true;
      }
    }
    SetRole(op, role);
  }

48
49
50
  void VisitStmt_(const BufferStoreNode *op) final {
    bool is_shared_store =
        op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
    if (!is_shared_store) {
      SetRole(op, Role::kConsumer);
      return;
    }

    // Check reads from global
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ GetRef<Stmt>(op));
    auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    auto reads = access[0];
    Role role = Role::kProducer;
    for (auto read : reads) {
      if (read->buffer.scope() != "global") {
        role = Role::kConsumer;
        break;
      }
    }
68
69
    if (role == Role::kProducer)
      has_simt_copy_ = true;
70
71
72
    SetRole(op, role);
  }

73
  void VisitStmt_(const SeqStmtNode *op) final {
74
75
76
77
78
79
80
81
82
83
84
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->seq[0]);
    for (auto stmt : op->seq) {
      if (role != GetRole(stmt)) {
        role = Role::kBoth;
        break;
      }
    }
    SetRole(op, role);
  }

85
  void VisitStmt_(const IfThenElseNode *op) final {
86
87
88
89
    StmtVisitor::VisitStmt_(op);
    auto role = GetRole(op->then_case);
    if (op->else_case.defined()) {
      auto role_else = GetRole(op->else_case.value());
90
91
      if (role != role_else)
        role = Role::kBoth;
92
93
94
95
    }
    SetRole(op, role);
  }

96
  void VisitStmt_(const BlockRealizeNode *op) final {
97
98
99
100
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->block));
  }

101
  template <class NodeType> void HandleBodyStmt(const NodeType *op) {
102
103
104
105
    StmtVisitor::VisitStmt_(op);
    SetRole(op, GetRole(op->body));
  }

106
107
108
109
110
  void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
  void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
111
112
113
114
115

  bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }

  bool HasSimtCopy() { return has_simt_copy_; }

116
117
private:
  void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
118
  Map<Var, Buffer> buffer_data_to_buffer_;
119
  std::unordered_map<const StmtNode *, Role> map_;
120
121
122
123
124
  bool has_simt_copy_ = false;
  bool has_bulk_copy_ = false;
};

class MultiVersionBufferRewriter : public StmtExprMutator {
125
126
public:
  static PrimFunc Substitute(PrimFunc &f) {
127
128
129
130
131
132
133
134
135
136
    auto rewriter = MultiVersionBufferRewriter();
    rewriter.buffer_lca_ = DetectBufferAccessLCA(f);
    for (auto [buffer, _] : rewriter.buffer_lca_) {
      Var buffer_var = buffer->data;
      rewriter.buffer_data_to_buffer_.Set(buffer_var, buffer);
    }
    f.CopyOnWrite()->body = rewriter(f->body);
    return f;
  }

137
private:
138
139
  MultiVersionBufferRewriter() = default;

140
141
  Array<Buffer> GetVersionedBuffers(const Array<Stmt> &seq_stmt,
                                    const Array<Buffer> &scoped_buffers) {
142
143
144
145
146
    std::vector<Role> roles;
    Array<Array<BufferRegion>> reads, writes;
    auto marker = WarpSpecializedRoleMarker_(buffer_data_to_buffer_);
    for (auto stmt : seq_stmt) {
      marker(stmt);
147
148
      Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                  /*name_hint=*/"", /*body*/ stmt);
149
      auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
150
151
      reads.push_back(access[0]);
      writes.push_back(access[1]);
152
153
154
      roles.push_back(marker.GetRole(stmt));
    }

155
    std::unordered_set<const BufferNode *> consumer_used, producer_used;
156
157
    for (size_t i = 0; i < seq_stmt.size(); i++) {
      if (roles[i] == Role::kProducer) {
158
159
        for (BufferRegion br : writes[i])
          producer_used.insert(br->buffer.get());
160
      } else {
161
162
        for (BufferRegion br : reads[i])
          consumer_used.insert(br->buffer.get());
163
164
165
166
      }
    }
    Array<Buffer> versioned_buffers;
    for (Buffer buffer : scoped_buffers) {
167
168
      if (consumer_used.count(buffer.get()) &&
          producer_used.count(buffer.get())) {
169
170
171
172
173
174
        versioned_buffers.push_back(buffer);
      }
    }
    return versioned_buffers;
  }

175
  static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
176
177
    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
178
    if (!new_buffer->strides.empty()) {
179
180
181
182
183
184
185
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

186
187
188
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    BlockRealize block_realize =
        Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
    Block block = block_realize->block;
    Array<Buffer> alloc_buffers;
    for (auto buffer : block->alloc_buffers) {
      if (buffer_remap_.count(buffer)) {
        Buffer new_buffer = buffer_remap_[buffer];
        alloc_buffers.push_back(new_buffer);
      } else {
        alloc_buffers.push_back(buffer);
      }
    }
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    block_realize.CopyOnWrite()->block = block;
    return block_realize;
  }

204
  Stmt VisitStmt_(const ForNode *op) final {
205
    loop_stack_.emplace_back(op->loop_var, op->extent);
206
    auto num_stages_anno = op->annotations.Get("num_stages");
207
    if (!num_stages_anno) {
208
209
210
211
      auto for_node = StmtExprMutator::VisitStmt_(op);
      loop_stack_.pop_back();
      return for_node;
    }
212

213
214
    ICHECK(num_stages_anno->as<IntImmNode>());
    int num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
215

216
217
218
219
    const SeqStmtNode *pipeline_body_seq = op->body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << op->body->GetTypeKey();
220
221
222

    Array<Buffer> scoped_buffers = {};
    for (auto [buffer, stmt] : buffer_lca_) {
223
224
      if (stmt.defined() && stmt.value().get() == op)
        scoped_buffers.push_back(buffer);
225
226
    }

227
228
    Array<Buffer> versioned_buffers =
        GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers);
229
230
231
232
233
234

    for (auto buffer : versioned_buffers) {
      Var buffer_var = buffer->data;
      Buffer new_buffer = RewriteAllocBuffer(buffer, num_stages);
      buffer_remap_.Set(buffer, new_buffer);
    }
235
236
237
238
239
240
    PrimExpr linear_index = loop_stack_[0].first;
    for (size_t i = 1; i < loop_stack_.size(); ++i) {
      linear_index =
          linear_index * loop_stack_[i].second + loop_stack_[i].first;
    }
    version_index_ = FloorMod(linear_index, num_stages);
241
    auto for_node = StmtExprMutator::VisitStmt_(op);
242
    loop_stack_.pop_back();
243
244
245
246

    return for_node;
  }

247
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
248
249
250
251
252
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(load);
    }
253
254
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
255
256
257
258
259
    n->buffer = new_buffer;
    n->indices.insert(n->indices.begin(), version_index_);
    return std::move(load);
  }

260
  Stmt VisitStmt_(const BufferStoreNode *op) final {
261
262
263
264
265
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(store);
    }
266
267
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
268
269
270
271
272
    n->buffer = new_buffer;
    n->indices.insert(n->indices.begin(), version_index_);
    return std::move(store);
  }

273
  PrimExpr VisitExpr_(const CallNode *op) final {
274
275
276
277
278
279
280
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});
    }
    return call;
  }

281
  PrimExpr RewriteBufferAccess(const Call &call,
282
                               const std::vector<int> &arg_indices) {
283
284
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
285
286
287
          [](PrimExpr a, PrimExpr b, Span span) {
            return mul(std::move(a), std::move(b), std::move(span));
          },
288
          make_const(DataType::Int(32), 1), input);
289
290
291
292
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      auto buffer_var = Downcast<Var>(call->args[i]);
293
294
295
      if (!buffer_data_to_buffer_.count(buffer_var))
        continue;
      const Buffer &buffer = buffer_data_to_buffer_[buffer_var];
296
297
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
298
299
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
300
301
302
303
304
305
306
307
308
309
310
311
312
313
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index = old_index + version_index_ * offset;
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

  PrimExpr version_index_;
314
  std::vector<std::pair<Var, PrimExpr>> loop_stack_;
315
316
317
318
319
320
321
322
  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Optional<Stmt>> buffer_lca_;
  Map<Buffer, Buffer> buffer_remap_;
};

using namespace tir::transform;

/*!
 * \brief Build the "tl.MultiVersionBuffer" PrimFunc pass.
 * \return A pass at opt level 0 with no required passes. It runs
 *         MultiVersionBufferRewriter::Substitute on each PrimFunc.
 */
tvm::transform::Pass MultiVersionBuffer() {
  auto rewrite_prim_func = [](PrimFunc func, const IRModule & /*mod*/,
                              const PassContext & /*ctx*/) -> PrimFunc {
    return MultiVersionBufferRewriter::Substitute(func);
  };
  return CreatePrimFuncPass(rewrite_prim_func, 0, "tl.MultiVersionBuffer", {});
}

329
330
331
332
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.MultiVersionBuffer", MultiVersionBuffer);
});
333

334
335
} // namespace tl
} // namespace tvm