// lower_dcu_resource.cc
//
// Lowers shared <- global BufferStores found inside `async_scope` attributes
// into `tl.dcu_async_copy` calls whose global side is addressed through a
// `tl.make_dcu_resource` descriptor bound once per source buffer.
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/memory.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/arith/analyzer.h>

#include <vector>
#include <unordered_map>
#include <unordered_set>

using tvm::ffi::GetRef;
using tvm::ffi::make_object;

namespace tvm {
namespace tl {
using namespace tir;
using ffi::Array;
using ffi::String;

/*!
 * \brief One shared <- global BufferStore selected for lowering to an async copy.
 *
 * Aggregate-initialized in field order by the collector, so the member order
 * below is part of its construction contract.
 */
struct CopyInfo {
  Buffer dst_buffer;            // Buffer written by the store (shared scope).
  Buffer src_buffer;            // Buffer read by the stored value (global scope).
  Array<PrimExpr> dst_indices;  // Indices of the original store.
  Array<PrimExpr> src_indices;  // Load indices as filtered by the collector.
  Stmt store_stmt;              // The original BufferStore, matched by identity later.
};

/*!
 * \brief Everything gathered by CollectResources for a single function body.
 */
struct CollectResult {
  // Every qualifying store, in visitation order.
  std::vector<CopyInfo> copies;
  // Source buffer name -> resource variable standing in for that buffer.
  std::unordered_map<String, Var> global_to_res_var;
  // Source buffer name -> (resource variable, initializer expression).
  std::unordered_map<String, std::pair<Var, PrimExpr>> shared_alloc_to_binding;

  // Node to wrap with the resource bindings; nullptr when nothing was found.
  const StmtNode* inject_target{nullptr};
};

/*!
 * \brief Expression mutator that rewrites each variable found in a given set to
 *        a zero constant of the same dtype; all other variables are left alone.
 *
 * The set is borrowed, not copied: it must outlive the mutator.
 */
class VariableEliminator : public tvm::tir::ExprMutator {
 public:
  explicit VariableEliminator(const std::unordered_set<const tvm::tir::VarNode*>& vars)
      : vars_to_remove_(vars) {}

  PrimExpr VisitExpr_(const tvm::tir::VarNode* op) override {
    const bool eliminate = vars_to_remove_.count(op) != 0;
    return eliminate ? tvm::tir::make_zero(op->dtype) : GetRef<PrimExpr>(op);
  }

 private:
  const std::unordered_set<const tvm::tir::VarNode*>& vars_to_remove_;
};

/*!
 * \brief Expression mutator that keeps only the variables in a given set; every
 *        other variable is rewritten to a zero constant of its dtype.
 *
 * The set is borrowed, not copied: it must outlive the mutator.
 */
class VariableKeeper : public tvm::tir::ExprMutator {
 public:
  explicit VariableKeeper(const std::unordered_set<const tvm::tir::VarNode*>& keep_vars)
      : keep_vars_(keep_vars) {}

  PrimExpr VisitExpr_(const tvm::tir::VarNode* op) override {
    const bool keep = keep_vars_.find(op) != keep_vars_.end();
    if (!keep) {
      return tvm::tir::make_zero(op->dtype);
    }
    return GetRef<PrimExpr>(op);
  }

  // Plain delegation to the base mutator (loads are traversed normally, so
  // variables inside their indices are filtered too).
  PrimExpr VisitExpr_(const tvm::tir::BufferLoadNode* op) override {
    return ExprMutator::VisitExpr_(op);
  }

 private:
  const std::unordered_set<const tvm::tir::VarNode*>& keep_vars_;
};

/*!
 * \brief Find shared <- global BufferStores inside `async_scope` attributes and
 *        gather what is needed to lower them.
 *
 * A store qualifies when (a) it is inside an `async_scope` AttrStmt, (b) its
 * buffer scope is "shared" or "shared.dyn", and (c) its value is a BufferLoad,
 * optionally wrapped in one Cast, from a buffer whose scope is "global" or "".
 *
 * For each qualifying store a CopyInfo is recorded whose src_indices keep only
 * the enclosing For-loop / threadIdx.* variables (all others are zeroed). For
 * the first store reading each distinct source buffer name, a Var
 * `<name>_dcu_res` of dtype int32x4 is bound to
 * `tl.make_dcu_resource(src->data, base_indices...)`, where base_indices are the
 * load indices with those loop variables zeroed.
 *
 * \param body Statement to scan; it is not modified.
 * \return Collected copies, resource bindings and the injection point (chosen
 *         from the first qualifying store only).
 */
CollectResult CollectResources(const Stmt& body) {
    class Collector : public StmtExprVisitor {
    public:
        CollectResult result;

    private:
        // Set while visiting the body of an `async_scope` AttrStmt.
        bool in_async{false};
        // Vars of the enclosing For loops and threadIdx.* thread_extent attributes.
        std::unordered_set<const tvm::tir::VarNode*> loop_vars_;
        // Enclosing AttrStmt / SeqStmt / For nodes, outermost first.
        std::vector<const tvm::tir::StmtNode*> scope_stack_;

        static bool IsSharedScope(const Buffer& buf) {
            auto s = buf.scope();
            return s == "shared" || s == "shared.dyn";
        }

        // An empty scope string is treated as global.
        static bool IsGlobalScope(const Buffer& buf) {
            auto s = buf.scope();
            return s == "global" || s == "";
        }

        void VisitStmt_(const AttrStmtNode* attr) override {
            scope_stack_.push_back(attr);
            if (attr->attr_key == tir::attr::thread_extent) {
                VisitThreadExtent(attr);
            } else if (attr->attr_key == tir::attr::async_scope) {
                ICHECK(!in_async) << "Nested async scopes not supported";
                in_async = true;
                StmtExprVisitor::VisitStmt_(attr);
                in_async = false;
            } else {
                StmtExprVisitor::VisitStmt_(attr);
            }
            scope_stack_.pop_back();
        }

        // threadIdx.* variables are treated like loop variables while the
        // attribute body is visited; other thread tags are visited unchanged.
        void VisitThreadExtent(const AttrStmtNode* attr) {
            const auto* iv = attr->node.as<tvm::tir::IterVarNode>();
            ICHECK(iv != nullptr) << "thread_extent attribute must annotate an IterVar";
            const std::string& tag = iv->thread_tag;
            const bool is_thread_idx = tag.find("threadIdx") != std::string::npos;
            const tvm::tir::VarNode* thread_var = iv->var.get();

            if (is_thread_idx) loop_vars_.insert(thread_var);
            StmtExprVisitor::VisitStmt_(attr);
            if (is_thread_idx) loop_vars_.erase(thread_var);
        }

        void VisitStmt_(const SeqStmtNode* op) override {
            scope_stack_.push_back(op);
            StmtExprVisitor::VisitStmt_(op);
            scope_stack_.pop_back();
        }

        void VisitStmt_(const ForNode* op) override {
            scope_stack_.push_back(op);
            loop_vars_.insert(op->loop_var.get());
            StmtExprVisitor::VisitStmt_(op);
            loop_vars_.erase(op->loop_var.get());
            scope_stack_.pop_back();
        }

        // Returns the BufferLoad that `v` is, or that a single Cast wraps.
        // NOTE(review): when a Cast is peeled, the lowered `tl.dcu_async_copy`
        // does not reproduce the conversion — confirm src/dst dtypes always
        // match for stores that reach this path.
        static const BufferLoadNode* PeelGlobalLoadValue(const PrimExpr& v) {
            if (const auto* load = v.as<BufferLoadNode>()) {
                return load;
            }
            if (const auto* cast = v.as<CastNode>()) {
                return cast->value.as<BufferLoadNode>();
            }
            return nullptr;
        }

        void VisitStmt_(const BufferStoreNode* op) final {
            if (in_async && op->value.defined() && IsSharedScope(op->buffer)) {
                const BufferLoadNode* load = PeelGlobalLoadValue(op->value);
                if (load != nullptr && IsGlobalScope(load->buffer)) {
                    RecordCopy(op, load);
                }
            }
            StmtExprVisitor::VisitStmt_(op);
        }

        // Picks where the resource LetStmts will be injected:
        //  1. the scope-stack entry right after the innermost thread_extent attr;
        //  2. otherwise the outermost enclosing For or SeqStmt;
        //  3. otherwise the store itself.
        const StmtNode* ChooseInjectTarget(const BufferStoreNode* store) const {
            const size_t depth = scope_stack_.size();
            for (size_t i = depth; i > 0; --i) {
                const StmtNode* node = scope_stack_[i - 1];
                if (node->IsInstance<AttrStmtNode>() &&
                    static_cast<const AttrStmtNode*>(node)->attr_key == tir::attr::thread_extent) {
                    if (i < depth) return scope_stack_[i];
                    break;  // innermost thread_extent has no tracked child
                }
            }
            for (const StmtNode* node : scope_stack_) {
                if (node->IsInstance<ForNode>() || node->IsInstance<SeqStmtNode>()) {
                    return node;
                }
            }
            return store;
        }

        // Records one qualifying store and, the first time its source buffer
        // name is seen, the resource variable and its initializer.
        void RecordCopy(const BufferStoreNode* store, const BufferLoadNode* load) {
            if (result.inject_target == nullptr) {
                result.inject_target = ChooseInjectTarget(store);
            }

            const Buffer& src = load->buffer;
            tvm::arith::Analyzer analyzer;

            // Source offset relative to the resource base: keep loop vars only.
            VariableKeeper keeper(loop_vars_);
            Array<PrimExpr> loop_only_indices;
            for (const PrimExpr& idx : load->indices) {
                loop_only_indices.push_back(analyzer.Simplify(keeper(idx)));
            }
            CopyInfo info{store->buffer, src, store->indices, loop_only_indices,
                          GetRef<Stmt>(store)};
            result.copies.push_back(std::move(info));

            if (result.global_to_res_var.count(src->name) != 0) {
                return;  // resource for this source buffer already bound
            }

            // Resource base: the load indices with every loop var zeroed.
            VariableEliminator eliminator(loop_vars_);
            Array<PrimExpr> args;
            args.push_back(src->data);
            for (const PrimExpr& idx : load->indices) {
                args.push_back(analyzer.Simplify(eliminator(idx)));
            }

            Var res_var(src->name + "_dcu_res", DataType::Int(32, 4));
            PrimExpr init = Call(DataType::Int(32, 4), Op::Get("tl.make_dcu_resource"), args);
            result.global_to_res_var[src->name] = res_var;
            result.shared_alloc_to_binding[src->name] = {res_var, init};
        }
    };

    Collector col;
    col(body);
    return col.result;
}

class StoreReplacer : public StmtExprMutator {
public:
    static Stmt Run(Stmt body, const std::vector<CopyInfo>& copies, 
                    const std::unordered_map<String, Var>& global_to_var) {
        StoreReplacer replacer(copies, global_to_var);
        return replacer(std::move(body));
    }

private:
    StoreReplacer(const std::vector<CopyInfo>& copies, 
                  const std::unordered_map<String, Var>& global_to_var)
        : copies_(copies), global_to_var_(global_to_var) {}

qisan's avatar
qisan committed
239
240
241
242
243
    Stmt VisitStmt_(const AttrStmtNode *attr) {
        if (attr->attr_key == tir::attr::async_scope) {
            auto body = this->VisitStmt(attr->body);
            return body;
        }
qisan's avatar
qisan committed
244
        return StmtMutator::VisitStmt_(attr); 
qisan's avatar
qisan committed
245
246
247
    }


qisan's avatar
qisan committed
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
    Stmt VisitStmt_(const BufferStoreNode* op) final {
        for (const auto& copy : copies_) {
            if (copy.store_stmt.same_as(GetRef<Stmt>(op))) {
                Var src_res = global_to_var_.at(copy.src_buffer->name);
                PrimExpr dst_res = copy.dst_buffer->data; 
                
                PrimExpr copy_size = IntImm(DataType::Int(32), 1);
                PrimExpr predicate = Bool(true);
                
                return Evaluate(
                    Call(DataType::Int(32), Op::Get("tl.dcu_async_copy"),
                         {dst_res, Flatten(copy.dst_indices),
                          src_res, Flatten(copy.src_indices),
                          copy_size, predicate}));
            }
        }
        return StmtExprMutator::VisitStmt_(op);
    }

    PrimExpr Flatten(const Array<PrimExpr>& idx) {
        if (idx.empty()) return IntImm(DataType::Int(32), 0);
        if (idx.size() == 1) return idx[0];
        PrimExpr r = idx[0];
        for (size_t i = 1; i < idx.size(); ++i) r = r + idx[i];
        return r;
    }

    const std::vector<CopyInfo>& copies_;
    const std::unordered_map<String, Var>& global_to_var_;
};

class ResourceInjector : public tvm::tir::StmtExprMutator {
public:
    static Stmt Run(Stmt body, 
                    const std::unordered_map<String, std::pair<Var, PrimExpr>>& bindings,
                    const tvm::tir::StmtNode* target) {
        if (!target || bindings.empty()) return body;
        ResourceInjector mutator(bindings, target);
        return mutator(std::move(body));
    }

private:
    ResourceInjector(const std::unordered_map<String, std::pair<Var, PrimExpr>>& bindings,
                     const tvm::tir::StmtNode* target)
        : bindings_(bindings), target_(target) {}

    Stmt VisitStmt(const Stmt& stmt) override {
        if (stmt.get() == target_) {
            Stmt new_stmt = StmtExprMutator::VisitStmt(stmt);
            
            for (const auto& item : bindings_) {
                Var res_var = item.second.first;
                PrimExpr init_expr = item.second.second;
                new_stmt = tvm::tir::LetStmt(res_var, init_expr, new_stmt);
            }
qisan's avatar
qisan committed
303
            return new_stmt; 
qisan's avatar
qisan committed
304
305
306
307
308
309
310
311
312
313
314
315
        }
        return StmtExprMutator::VisitStmt(stmt);
    }

    std::unordered_map<String, std::pair<Var, PrimExpr>> bindings_;
    const tvm::tir::StmtNode* target_;
};

/*!
 * \brief Lower async shared <- global copies in \p f to DCU resource copies.
 *
 * 1. CollectResources scans the body. If no qualifying store exists, \p f is
 *    returned as is (no copy-on-write is triggered).
 * 2. ResourceInjector wraps the chosen target with the resource LetStmts.
 * 3. StoreReplacer turns each recorded store into `tl.dcu_async_copy` and
 *    strips the `async_scope` attributes.
 */
PrimFunc LowerSharedGlobalCopy(PrimFunc f) {
    CollectResult res = CollectResources(f->body);
    if (res.copies.empty()) {
        return f;
    }

    // Node pointers in `res` refer to f->body's nodes; CopyOnWrite below only
    // shallow-copies the PrimFuncNode, so those nodes are unaffected.
    Stmt injected = ResourceInjector::Run(f->body, res.shared_alloc_to_binding, res.inject_target);
    Stmt replaced = StoreReplacer::Run(injected, res.copies, res.global_to_res_var);

    f.CopyOnWrite()->body = std::move(replaced);
    return f;
}

namespace transform {
using namespace tir::transform;

/*!
 * \brief Pass wrapper that applies tl::LowerSharedGlobalCopy to each PrimFunc
 *        (opt level 0, no required passes).
 */
tvm::transform::Pass LowerSharedGlobalCopy() {
    auto lower = [](PrimFunc func, const IRModule& /*mod*/, PassContext /*ctx*/) {
        return tl::LowerSharedGlobalCopy(std::move(func));
    };
    return CreatePrimFuncPass(lower, /*opt_level=*/0, "tl.LowerSharedGlobalCopy", {});
}

// Registers the pass factory with the FFI global registry.
TVM_FFI_STATIC_INIT_BLOCK() {
    namespace refl = tvm::ffi::reflection;
    refl::GlobalDef().def("tl.transform.LowerSharedGlobalCopy", LowerSharedGlobalCopy);
}

}  // namespace transform
}  // namespace tl
}  // namespace tvm