// lower_dcu_resource.cc
//
// TileLang pass `tl.LowerSharedGlobalCopy`: lowers shared <- global element
// copies inside async_scope regions to DCU resource-based tl.dcu_async_copy calls.
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/memory.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/arith/analyzer.h>

#include <vector>
#include <unordered_map>
#include <unordered_set>

using tvm::ffi::GetRef;
using tvm::ffi::make_object;

namespace tvm {
namespace tl {
using namespace tir;
using ffi::Array;
using ffi::String;

// ============================================================================
// 数据结构
// ============================================================================
// One element copy `shared[dst_indices] = global[...]` found inside an
// async_scope. Instances are brace-initialized positionally, so the field
// order below is part of the contract.
struct CopyInfo {
  Buffer dst_buffer;            // shared-scope destination buffer
  Buffer src_buffer;            // global-scope source buffer
  Array<PrimExpr> dst_indices;  // indices of the original BufferStore
  Array<PrimExpr> src_indices;  // source indices with only loop/thread vars kept
  Stmt store_stmt;              // the original BufferStore statement itself
};

// Output of Phase 1 (CollectResources), consumed by Phases 2 and 3.
struct CollectResult {
    // Every qualifying shared <- global copy, in visit order.
    std::vector<CopyInfo> copies;
    // Global (source) buffer name -> Var holding its DCU resource descriptor.
    // Phase 2 uses it as the source operand of tl.dcu_async_copy.
    std::unordered_map<String, Var> global_to_res_var;
    // Global (source) buffer name -> (resource Var, tl.make_dcu_resource call)
    // that Phase 3 binds with a LetStmt. Despite the field name, the key is the
    // SOURCE global buffer's name, not the shared destination buffer's, and the
    // injection point is `inject_target`, not the shared allocation.
    std::unordered_map<String, std::pair<Var, PrimExpr>> shared_alloc_to_binding;

    // The single statement Phase 3 wraps with all bindings; nullptr when no
    // copy was found.
    const StmtNode* inject_target = nullptr;
};

// Expression rewriter that turns every Var contained in the given set into a
// zero constant of the same dtype; any other Var is returned as-is.
// The set is held by reference and must outlive this mutator.
class VariableEliminator : public tvm::tir::ExprMutator {
 public:
  explicit VariableEliminator(const std::unordered_set<const tvm::tir::VarNode*>& vars)
      : vars_to_remove_(vars) {}

  PrimExpr VisitExpr_(const tvm::tir::VarNode* op) override {
    const bool eliminate = vars_to_remove_.find(op) != vars_to_remove_.end();
    return eliminate ? tvm::tir::make_zero(op->dtype) : GetRef<PrimExpr>(op);
  }

 private:
  // Non-owning view of the vars to zero out.
  const std::unordered_set<const tvm::tir::VarNode*>& vars_to_remove_;
};

// Complement of VariableEliminator: keeps every Var contained in `keep_vars`
// and replaces every other Var with a zero of the same dtype. Nested
// expressions (including BufferLoad indices) are traversed by ExprMutator.
// The set is held by reference and must outlive this mutator.
class VariableKeeper : public tvm::tir::ExprMutator {
 public:
  explicit VariableKeeper(const std::unordered_set<const tvm::tir::VarNode*>& keep_vars)
      : keep_vars_(keep_vars) {}

  PrimExpr VisitExpr_(const tvm::tir::VarNode* op) override {
    // Per-variable tracing is verbose-only; it used to be unconditional
    // LOG(INFO) and flooded the log for every visited Var.
    if (keep_vars_.count(op)) {
      VLOG(2) << "[KEEP] Found var in list: " << op->name_hint << " (" << op << ")";
      return GetRef<PrimExpr>(op);
    }
    VLOG(2) << "[ERASE] Var not in list: " << op->name_hint << " (" << op << ")";
    return tvm::tir::make_zero(op->dtype);
  }

 private:
  // Non-owning view of the vars to preserve.
  const std::unordered_set<const tvm::tir::VarNode*>& keep_vars_;
};

// ============================================================================
// Phase 1: 收集拷贝信息 & 生成资源绑定
// ============================================================================
// Phase 1: scans `body` for element copies `shared[...] = global[...]` located
// inside an `async_scope` attribute and gathers what Phases 2 and 3 need:
//   * one CopyInfo per such BufferStore, in visit order;
//   * per distinct global source buffer, a fresh Var `<name>_dcu_res`
//     (int32x4) bound to
//       tl.make_dcu_resource(src.data, base_0, ..., base_{d-1})
//     where base_k is source index k with every enclosing For loop var and
//     threadIdx.* var replaced by 0;
//   * `inject_target`, the single statement Phase 3 wraps with all bindings.
// Per-copy source indices keep ONLY those loop/thread vars (every other var
// becomes 0). For indices that are a sum of a loop-dependent part and a
// loop-invariant part, base + per-copy index therefore equals the original.
// NOTE(review): each global buffer's base comes from the FIRST copy seen; a
// later copy of the same buffer with a different loop-invariant offset, or
// indices built from LetStmt-bound vars, would not decompose correctly.
CollectResult CollectResources(const Stmt& body) {
    class Collector : public StmtExprVisitor {
    public:
        CollectResult result;

    private:
        // True while visiting the body of an `async_scope` AttrStmt.
        bool in_async{false};
        // Vars bound by enclosing For loops and threadIdx.* thread_extent attrs.
        std::unordered_set<const tvm::tir::VarNode*> loop_vars_;
        // Tracked ancestors of the current node, outermost first. Only AttrStmt,
        // SeqStmt and For nodes are recorded.
        std::vector<const tvm::tir::StmtNode*> scope_stack_;

        static bool IsSharedScope(const Buffer& buf) {
            auto s = buf.scope();
            return s == "shared" || s == "shared.dyn";
        }
        static bool IsGlobalScope(const Buffer& buf) {
            auto s = buf.scope();
            return s == "global" || s == "";
        }

        void VisitStmt_(const AttrStmtNode* attr) override {
            scope_stack_.push_back(attr);
            if (attr->attr_key == tir::attr::thread_extent) {
                const auto* iv = attr->node.as<tvm::tir::IterVarNode>();
                // Previously dereferenced unchecked.
                ICHECK(iv != nullptr) << "thread_extent attribute expects an IterVar node";
                const std::string& tag = iv->thread_tag;
                if (tag.find("threadIdx") != std::string::npos) {
                    // threadIdx vars are treated like loop vars: zeroed in the
                    // resource base, kept in the per-copy index.
                    tvm::tir::Var thread_var = iv->var;
                    VLOG(2) << "Entering thread scope: " << tag << " with var " << thread_var->name_hint;
                    loop_vars_.insert(thread_var.get());
                    StmtExprVisitor::VisitStmt_(attr);
                    loop_vars_.erase(thread_var.get());
                } else {
                    // blockIdx and other tags: just keep walking.
                    StmtExprVisitor::VisitStmt_(attr);
                }
            } else if (attr->attr_key == tir::attr::async_scope) {
                ICHECK(!in_async) << "Nested async scopes not supported";
                in_async = true;
                StmtExprVisitor::VisitStmt_(attr);
                in_async = false;
            } else {
                StmtExprVisitor::VisitStmt_(attr);
            }
            scope_stack_.pop_back();
        }

        void VisitStmt_(const SeqStmtNode* op) override {
            scope_stack_.push_back(op);
            StmtExprVisitor::VisitStmt_(op);
            scope_stack_.pop_back();
        }

        void VisitStmt_(const ForNode* op) override {
            scope_stack_.push_back(op);
            loop_vars_.insert(op->loop_var.get());
            StmtExprVisitor::VisitStmt_(op);
            loop_vars_.erase(op->loop_var.get());
            scope_stack_.pop_back();
        }

        void VisitStmt_(const BufferStoreNode* op) final {
            VLOG(2) << "Visiting BufferStore: " << op->buffer->name;
            if (IsSharedScope(op->buffer) && op->value.defined() && in_async) {
                if (const auto* load = op->value.as<BufferLoadNode>()) {
                    if (IsGlobalScope(load->buffer)) {
                        RecordCopy(op, load);
                    }
                }
            }
            StmtExprVisitor::VisitStmt_(op);
        }

        // Chooses where Phase 3 inserts the resource LetStmts. Only the first
        // qualifying copy decides; all bindings are injected at that one node.
        void SelectInjectTarget(const BufferStoreNode* op) {
            if (result.inject_target != nullptr) return;
            // Preferred: the tracked node directly inside the innermost
            // thread_extent attribute.
            for (size_t i = scope_stack_.size(); i-- > 0;) {
                const StmtNode* node = scope_stack_[i];
                if (!node->IsInstance<AttrStmtNode>()) continue;
                const auto* attr = static_cast<const AttrStmtNode*>(node);
                if (attr->attr_key != tvm::tir::attr::thread_extent) continue;
                if (i + 1 < scope_stack_.size()) {
                    result.inject_target = scope_stack_[i + 1];
                }
                break;
            }
            // Fallback: the outermost tracked For / SeqStmt.
            if (result.inject_target == nullptr) {
                for (const auto* node : scope_stack_) {
                    if (node->IsInstance<ForNode>() || node->IsInstance<SeqStmtNode>()) {
                        result.inject_target = node;
                        break;
                    }
                }
            }
            // Last resort: the store itself.
            if (result.inject_target == nullptr) result.inject_target = op;
        }

        // Records one shared <- global copy and, for a not-yet-seen source
        // buffer, its resource binding.
        void RecordCopy(const BufferStoreNode* op, const BufferLoadNode* load) {
            SelectInjectTarget(op);

            Buffer src = load->buffer;
            tvm::arith::Analyzer analyzer;

            // Per-copy source indices: keep only loop/thread vars.
            VariableKeeper keeper(loop_vars_);
            Array<PrimExpr> for_var_only_indices;
            for (const auto& idx : load->indices) {
                for_var_only_indices.push_back(analyzer.Simplify(keeper(idx)));
                VLOG(2) << "ONLY Index: " << idx;
            }
            result.copies.push_back(
                CopyInfo{op->buffer, src, op->indices, for_var_only_indices, GetRef<Stmt>(op)});

            // One binding per distinct global buffer (first copy wins).
            if (result.global_to_res_var.find(src->name) == result.global_to_res_var.end()) {
                MakeResourceBinding(src, load->indices, &analyzer);
            }
            VLOG(2) << "result.copies.size() = " << result.copies.size();
        }

        // Builds `<src>_dcu_res = tl.make_dcu_resource(src.data, base...)` where
        // each base index has all loop/thread vars replaced by 0.
        void MakeResourceBinding(const Buffer& src, const Array<PrimExpr>& indices,
                                 tvm::arith::Analyzer* analyzer) {
            Var res_var(src->name + "_dcu_res", DataType::Int(32, 4));

            VLOG(2) << loop_vars_.size() << " loop vars in context.";
            for (const auto* loop_var : loop_vars_) {
                VLOG(2) << "Loop Var: " << loop_var->name_hint;
            }

            VariableEliminator eliminator(loop_vars_);
            Array<PrimExpr> args;
            args.push_back(src->data);
            for (const auto& idx : indices) {
                PrimExpr base = analyzer->Simplify(eliminator(idx));
                VLOG(2) << "Clean Index: " << base;
                args.push_back(base);
            }
            PrimExpr make_res = Call(DataType::Int(32, 4), Op::Get("tl.make_dcu_resource"), args);

            result.global_to_res_var[src->name] = res_var;
            // Keyed by the global source buffer name (see CollectResult).
            result.shared_alloc_to_binding[src->name] = {res_var, make_res};
        }
    };

    VLOG(1) << "Starting resource collection...";
    Collector col;
    col(body);
    VLOG(1) << "Finished resource collection. Found " << col.result.copies.size() << " copy(s).";
    return col.result;
}

// ============================================================================
// Phase 2: 替换 BufferStore -> dcu_async_copy
// ============================================================================
class StoreReplacer : public StmtExprMutator {
public:
    static Stmt Run(Stmt body, const std::vector<CopyInfo>& copies, 
                    const std::unordered_map<String, Var>& global_to_var) {
        StoreReplacer replacer(copies, global_to_var);
        return replacer(std::move(body));
    }

private:
    StoreReplacer(const std::vector<CopyInfo>& copies, 
                  const std::unordered_map<String, Var>& global_to_var)
        : copies_(copies), global_to_var_(global_to_var) {}

qisan's avatar
qisan committed
270
271
272
273
274
275
276
277
278
    Stmt VisitStmt_(const AttrStmtNode *attr) {
        if (attr->attr_key == tir::attr::async_scope) {
            auto body = this->VisitStmt(attr->body);
            return body;
        }
        return StmtMutator::VisitStmt_(attr); // ③ 其他属性:默认保留
    }


qisan's avatar
qisan committed
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
    Stmt VisitStmt_(const BufferStoreNode* op) final {
        for (const auto& copy : copies_) {
            if (copy.store_stmt.same_as(GetRef<Stmt>(op))) {
                // Global 取 resource var (A_dcu_res)
                Var src_res = global_to_var_.at(copy.src_buffer->name);
                // Shared 取 data pointer (A_shared.data)
                PrimExpr dst_res = copy.dst_buffer->data; 
                
                PrimExpr copy_size = IntImm(DataType::Int(32), 1);
                PrimExpr predicate = Bool(true);
                
                return Evaluate(
                    Call(DataType::Int(32), Op::Get("tl.dcu_async_copy"),
                         {dst_res, Flatten(copy.dst_indices),
                          src_res, Flatten(copy.src_indices),
                          copy_size, predicate}));
            }
        }
        return StmtExprMutator::VisitStmt_(op);
    }

    PrimExpr Flatten(const Array<PrimExpr>& idx) {
        if (idx.empty()) return IntImm(DataType::Int(32), 0);
        if (idx.size() == 1) return idx[0];
        PrimExpr r = idx[0];
        for (size_t i = 1; i < idx.size(); ++i) r = r + idx[i];
        return r;
    }

    const std::vector<CopyInfo>& copies_;
    const std::unordered_map<String, Var>& global_to_var_;
};

// ============================================================================
// Phase 3: 根据 Shared Alloc 位置进行精准注入
// ============================================================================
// Phase 3 mutator. Finds the statement whose node is `target` (compared by node
// identity), rewrites its subtree as usual, and then wraps the result in one
// LetStmt per entry of `bindings`. All other statements are rebuilt unchanged.
class ResourceInjector : public tvm::tir::StmtExprMutator {
public:
    // Returns `body` untouched when there is no target or nothing to bind.
    static Stmt Run(Stmt body,
                    const std::unordered_map<String, std::pair<Var, PrimExpr>>& bindings,
                    const tvm::tir::StmtNode* target) {
        if (target == nullptr || bindings.empty()) {
            return body;
        }
        ResourceInjector injector(bindings, target);
        return injector(std::move(body));
    }

private:
    ResourceInjector(const std::unordered_map<String, std::pair<Var, PrimExpr>>& bindings,
                     const tvm::tir::StmtNode* target)
        : bindings_(bindings), target_(target) {}

    Stmt VisitStmt(const Stmt& stmt) override {
        // Rewrite children first (normal mutator behaviour).
        Stmt visited = StmtExprMutator::VisitStmt(stmt);
        if (stmt.get() != target_) {
            return visited;
        }
        // Wrap the target: each binding becomes an enclosing LetStmt.
        for (const auto& entry : bindings_) {
            const auto& [res_var, init_value] = entry.second;
            visited = tvm::tir::LetStmt(res_var, init_value, visited);
        }
        return visited;
    }

    // Owned copy of the bindings; iteration order decides LetStmt nesting.
    std::unordered_map<String, std::pair<Var, PrimExpr>> bindings_;
    // Node to wrap; it must belong to the tree passed to Run.
    const tvm::tir::StmtNode* target_;
};

// ============================================================================
// Pass 入口
// ============================================================================
// Pass body. Rewrites shared <- global element copies inside async_scope
// attributes into tl.dcu_async_copy calls fed by per-buffer DCU resource
// LetStmts. Returns `f` as-is (without copying it) when no such copy exists.
PrimFunc LowerSharedGlobalCopy(PrimFunc f) {
    VLOG(1) << "Starting LowerSharedGlobalCopy transformation...";

    // Phase 1: collect copies, resource bindings and the injection point.
    // This only reads the body, so it runs before any copy-on-write.
    CollectResult res = CollectResources(f->body);
    if (res.copies.empty()) {
        VLOG(1) << "No shared-global copy patterns detected. Skipping transformation.";
        return f;
    }

    // Take a writable function only once we know it will change. The body
    // object is shared with the original, so the node pointers gathered above
    // (inject_target, store_stmt) stay valid.
    auto* n = f.CopyOnWrite();

    // Phase 3 first: wrap inject_target with the resource LetStmts while the
    // original node identities are still present in the tree.
    Stmt injected = ResourceInjector::Run(n->body, res.shared_alloc_to_binding, res.inject_target);

    // Phase 2: replace the recorded stores with tl.dcu_async_copy.
    Stmt replaced = StoreReplacer::Run(injected, res.copies, res.global_to_res_var);
    // Logged only after the replacement actually happened.
    VLOG(1) << "Replaced " << res.copies.size() << " copy(s) with dcu_async_copy.";

    n->body = std::move(replaced);
    return f;
}

namespace transform {
using namespace tir::transform;

// Wraps tl::LowerSharedGlobalCopy as a PrimFunc-level pass
// (opt_level 0, no required passes).
tvm::transform::Pass LowerSharedGlobalCopy() {
    auto lower = [](PrimFunc func, const IRModule& /*mod*/, PassContext /*ctx*/) {
        return tl::LowerSharedGlobalCopy(std::move(func));
    };
    return CreatePrimFuncPass(lower, 0, "tl.LowerSharedGlobalCopy", {});
}

// Publishes the pass factory in the FFI global registry.
TVM_FFI_STATIC_INIT_BLOCK() {
    namespace refl = tvm::ffi::reflection;
    refl::GlobalDef().def("tl.transform.LowerSharedGlobalCopy", LowerSharedGlobalCopy);
}

}  // namespace transform
}  // namespace tl
}  // namespace tvm