Commit b14f201e authored by qisan
Browse files

Feats: Add async, pipeline and ds_read

parent 44cc93c7
...@@ -827,14 +827,12 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -827,14 +827,12 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
func_name += "_trans"; func_name += "_trans";
print_extern_call_stmt(func_name, 2); print_extern_call_stmt(func_name, 2);
}else if(op->op.same_as(tl::ds_read_vector())){ }else if(op->op.same_as(tl::ds_read_vector())){
//ds_read_b64 %1, %2 offset:%3
// ds_read_m32x16_b16 %0, %1 offset:%2
std::string dst = this->PrintExpr(op->args[0]); std::string dst = this->PrintExpr(op->args[0]);
std::string local_offset = this->PrintExpr(op->args[1]); std::string local_offset = this->PrintExpr(op->args[1]);
std::string lds_offset = this->PrintExpr(op->args[2]); std::string lds_offset = this->PrintExpr(op->args[2]);
os << "tl::ds_read_vector(" os << "tl::ds_read_vector(*(float4_ *)("
<< dst << " + " << local_offset << dst << " + " << local_offset
<< ", " << "), "
<< lds_offset << lds_offset
<< ")"; << ")";
}else if (op->op.same_as(tl::wait_wgmma())) { }else if (op->op.same_as(tl::wait_wgmma())) {
......
...@@ -133,12 +133,11 @@ TL_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr, ...@@ -133,12 +133,11 @@ TL_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr,
// } // }
// } // }
TL_DEVICE void ds_read_vector(void* dst, uint32_t lds_base_ptr) TL_DEVICE void ds_read_vector(float4_& dst, uint32_t lds_base_ptr)
{ {
asm volatile("ds_read_m32x16_b16 %0, %1 offset:0\n\t" asm volatile("ds_read_m32x16_b16 %0, %1 offset:0\n\t"
: "+v"(dst) : "+v"(dst)
: "v"(lds_base_ptr), : "v"(lds_base_ptr));
: "memory");
} }
// template <int M, int N, int offset> // template <int M, int N, int offset>
......
...@@ -57,13 +57,15 @@ class BLocalLayoutTransformer : public StmtExprMutator { ...@@ -57,13 +57,15 @@ class BLocalLayoutTransformer : public StmtExprMutator {
int expand_; int expand_;
Stmt VisitStmt_(const ForNode* op) final { Stmt VisitStmt_(const ForNode* op) final {
// 只处理 serial 外层循环 // 1. 先递归处理子节点(重要:确保处理了嵌套的 For 或 Attr)
Stmt new_body = this->VisitStmt(op->body);
// 2. 检查当前循环是否是目标循环
// 即使 body 变了,我们也尝试看看能不能在这个 loop 层级做变换
auto store = new_body.as<BufferStoreNode>();
if (op->kind != ForKind::kSerial) { if (op->kind != ForKind::kSerial) {
return StmtExprMutator::VisitStmt_(op); return StmtExprMutator::VisitStmt_(op);
} }
// 判断是否是 B_local 写循环
auto store = op->body.as<BufferStoreNode>();
if (!store) { if (!store) {
return StmtExprMutator::VisitStmt_(op); return StmtExprMutator::VisitStmt_(op);
} }
...@@ -79,7 +81,6 @@ class BLocalLayoutTransformer : public StmtExprMutator { ...@@ -79,7 +81,6 @@ class BLocalLayoutTransformer : public StmtExprMutator {
int64_t new_extent = old_extent / expand_; int64_t new_extent = old_extent / expand_;
// 修改循环范围
For new_for = For new_for =
For(op->loop_var, For(op->loop_var,
op->min, op->min,
...@@ -94,100 +95,42 @@ class BLocalLayoutTransformer : public StmtExprMutator { ...@@ -94,100 +95,42 @@ class BLocalLayoutTransformer : public StmtExprMutator {
std::string name = buffer->name; std::string name = buffer->name;
return name.find("B_local") != std::string::npos; return name.find("B_local") != std::string::npos;
} }
PrimExpr UpdateIndexBase(PrimExpr base, const Var& loop_var, int expand) {
Stmt MutateStore(const BufferStoreNode* store, if (const auto* add = base.as<AddNode>()) {
const Var& loop_var) { return UpdateIndexBase(add->a, loop_var, expand) +
UpdateIndexBase(add->b, loop_var, expand);
Array<PrimExpr> new_indices = store->indices; } else if (const auto* mul = base.as<MulNode>()) {
PrimExpr new_value = store->value;
// 修改切片跨度:
// 原来 j*vec : j*vec+vec
// 改为 j*vec : j*vec*expand + vec
PrimExpr idx = store->indices[0];
//T.Ramp(j * 4, 1, 4) -> Ramp(j*8, 1, 4)
std::cout << idx << std::endl;
// 解析 j*vec 结构
// 假设结构为 j * vec + const
// 不改 RHS
// PrimExpr value = store->value;
// 修改写入向量宽度
// 原 value 是 Ramp(base=j*4, stride=1, lanes=4)
// 匹配 j * stride
// Ramp(base=j*8, stride=1, lanes=8)
if (const auto* ramp = idx.as<RampNode>()) {
PrimExpr base = ramp->base;
PrimExpr stride = ramp->stride;
int old_lanes = ramp->lanes.as<IntImmNode>()->value;
int new_lanes = old_lanes * expand_;
// 匹配 base = j * stride_val
if (const auto* mul = base.as<MulNode>()) {
if (mul->a.same_as(loop_var)) { if (mul->a.same_as(loop_var)) {
int64_t old_stride = return mul->a * (mul->b * expand);
mul->b.as<IntImmNode>()->value; } else if (mul->b.same_as(loop_var)) {
return (mul->a * expand) * mul->b;
int64_t new_stride =
old_stride * expand_;
PrimExpr new_base =
loop_var *
make_const(DataType::Int(32), new_stride);
new_indices.Set(
0,
Ramp(new_base, stride, new_lanes));
} }
else if (mul->b.same_as(loop_var)) { }
int64_t old_stride = return base;
mul->a.as<IntImmNode>()->value; }
Stmt MutateStore(const BufferStoreNode* store, const Var& loop_var) {
auto n = tvm::ffi::make_object<BufferStoreNode>(*store);
Array<PrimExpr> new_indices = store->indices;
int64_t new_stride = if (const auto* ramp = store->indices[0].as<RampNode>()) {
old_stride * expand_; PrimExpr new_base = UpdateIndexBase(ramp->base, loop_var, expand_);
PrimExpr new_base = int new_lanes = ramp->lanes.as<IntImmNode>()->value * expand_;
make_const(DataType::Int(32), new_stride) *
loop_var;
new_indices.Set( new_indices.Set(0, Ramp(new_base, ramp->stride, new_lanes));
0,
Ramp(new_base, stride, new_lanes));
}
}
} }
if (auto* load = new_value.as<BufferLoadNode>()) { PrimExpr new_value = store->value;
// BufferLoad with region access: B_shared[start : end] if (const auto* load = store->value.as<BufferLoadNode>()) {
// end - start = lanes,需要同步扩展 if (const auto* l_ramp = load->indices[0].as<RampNode>()) {
Array<PrimExpr> value_indices = load->indices; Array<PrimExpr> v_indices = load->indices;
if (auto* old_ramp = load->indices[0].as<RampNode>()) { int v_new_lanes = l_ramp->lanes.as<IntImmNode>()->value * expand_;
v_indices.Set(0, Ramp(l_ramp->base, l_ramp->stride, v_new_lanes));
PrimExpr scalar_base = old_ramp->base; // 必须是 scalar new_value = BufferLoad(load->buffer, v_indices);
PrimExpr stride = old_ramp->stride;
//RHS 4 lane
int old_lanes = old_ramp->lanes.as<IntImmNode>()->value;
//RHS 8 lane
int new_lanes = old_lanes * expand_;
value_indices.Set(
0,
Ramp(scalar_base, stride, new_lanes)
);
new_value = BufferLoad(load->buffer, value_indices);
} }
} }
return BufferStore(store->buffer, return BufferStore(store->buffer, new_value, new_indices);
new_value,
new_indices);
} }
}; };
......
...@@ -13,7 +13,6 @@ namespace tl { ...@@ -13,7 +13,6 @@ namespace tl {
using ffi::Array; using ffi::Array;
using namespace tir; using namespace tir;
// 1. 辅助类:统计 Shared -> Register 的加载量
class LoadCounter : public StmtExprVisitor { class LoadCounter : public StmtExprVisitor {
public: public:
int total_loads = 0; int total_loads = 0;
...@@ -39,12 +38,31 @@ public: ...@@ -39,12 +38,31 @@ public:
} }
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
} }
/*!
 * \brief Account for calls whose callee name contains "ds_read".
 *
 * The callee name is taken from the Op name or the GlobalVar name hint; any
 * other callee kind leaves the name empty and is therefore never counted.
 * Each matching call adds `current_multiplier` to `total_loads`.
 *
 * \param op The call expression being visited.
 */
void VisitExpr_(const CallNode* op) override {
  // Resolve a textual name for the callee, if it has one.
  std::string callee;
  if (const auto* as_op = op->op.as<OpNode>()) {
    callee = as_op->name;
  } else if (const auto* as_gv = op->op.as<GlobalVarNode>()) {
    callee = as_gv->name_hint;
  }

  const bool is_ds_read = callee.find("ds_read") != std::string::npos;
  if (is_ds_read) {
    total_loads += current_multiplier;
  }

  // Hand the node to the base visitor so traversal continues.
  ExprVisitor::VisitExpr_(op);
}
private:
/*!
 * \brief Heuristically decide whether a buffer refers to shared memory.
 *
 * A buffer qualifies when its storage scope is exactly "shared", or when its
 * name contains one of the substrings "shared", "shmem" or "LDS".
 *
 * \param buf The buffer to classify.
 * \return true if either the scope test or one of the name tests matches.
 */
bool IsSharedMem(const Buffer& buf) {
  const std::string mem_scope = buf.scope();
  if (mem_scope == "shared") {
    return true;
  }

  // Fall back to name-based matching when the scope is not plain "shared".
  const std::string buf_name = buf->name;
  const auto name_has = [&buf_name](const char* tag) {
    return buf_name.find(tag) != std::string::npos;
  };
  return name_has("shared") || name_has("shmem") || name_has("LDS");
}
}; };
// 2. 核心 Mutator namespace {
class MMABarrierMutator : public StmtExprMutator {
public: bool StmtContainsMMA(const Stmt& stmt) {
bool ContainsMMA(const Stmt& stmt) {
bool found = false; bool found = false;
PostOrderVisit(stmt, [&found](const ObjectRef& node) { PostOrderVisit(stmt, [&found](const ObjectRef& node) {
if (const CallNode* call = node.as<CallNode>()) { if (const CallNode* call = node.as<CallNode>()) {
...@@ -61,26 +79,126 @@ public: ...@@ -61,26 +79,126 @@ public:
} }
}); });
return found; return found;
} }
Stmt VisitStmt_(const SeqStmtNode* op) override { void ScanStmtDefault(const Stmt& s, std::vector<Stmt>* fence_targets);
// --- 步骤 1: 预扫描,确定最后一个需要插入 Fence 的位置 ---
int last_fence_idx = -1; void ScanSeqStmt(const SeqStmtNode* op, std::vector<Stmt>* fence_targets) {
int temp_pending_count = 0; int pending = 0;
for (size_t i = 0; i < op->seq.size(); ++i) { for (size_t i = 0; i < op->seq.size(); ++i) {
if (ContainsMMA(op->seq[i])) { const Stmt& stmt = op->seq[i];
if (temp_pending_count > 0) { if (StmtContainsMMA(stmt)) {
last_fence_idx = static_cast<int>(i); if (pending > 0) {
temp_pending_count = 0; // 模拟重置 fence_targets->push_back(stmt);
pending = 0;
} }
ScanStmtDefault(stmt, fence_targets);
} else { } else {
LoadCounter counter; LoadCounter counter;
counter(op->seq[i]); counter(stmt);
temp_pending_count += counter.total_loads; pending += counter.total_loads;
ScanStmtDefault(stmt, fence_targets);
} }
} }
}
// --- 步骤 2: 实际构造新的 Sequence --- void ScanStmtDefault(const Stmt& s, std::vector<Stmt>* fence_targets) {
if (const auto* seq = s.as<SeqStmtNode>()) {
ScanSeqStmt(seq, fence_targets);
return;
}
if (const auto* op = s.as<AttrStmtNode>()) {
ScanStmtDefault(op->body, fence_targets);
return;
}
if (const auto* op = s.as<LetStmtNode>()) {
ScanStmtDefault(op->body, fence_targets);
return;
}
if (const auto* op = s.as<IfThenElseNode>()) {
ScanStmtDefault(op->then_case, fence_targets);
if (op->else_case) {
ScanStmtDefault(op->else_case.value(), fence_targets);
}
return;
}
if (const auto* op = s.as<ForNode>()) {
ScanStmtDefault(op->body, fence_targets);
return;
}
if (const auto* op = s.as<WhileNode>()) {
ScanStmtDefault(op->body, fence_targets);
return;
}
if (const auto* op = s.as<AllocateNode>()) {
ScanStmtDefault(op->body, fence_targets);
return;
}
if (const auto* op = s.as<AllocateConstNode>()) {
ScanStmtDefault(op->body, fence_targets);
return;
}
if (const auto* op = s.as<DeclBufferNode>()) {
ScanStmtDefault(op->body, fence_targets);
return;
}
if (const auto* op = s.as<BufferRealizeNode>()) {
ScanStmtDefault(op->body, fence_targets);
return;
}
if (const auto* op = s.as<AssertStmtNode>()) {
ScanStmtDefault(op->body, fence_targets);
return;
}
if (const auto* op = s.as<BlockNode>()) {
if (op->init.defined()) {
ScanStmtDefault(op->init.value(), fence_targets);
}
ScanStmtDefault(op->body, fence_targets);
return;
}
if (const auto* op = s.as<BlockRealizeNode>()) {
ScanStmtDefault(op->block, fence_targets);
return;
}
}
/*!
 * \brief Find the last statement that ScanStmtDefault records as a fence target.
 *
 * \param root The statement tree to scan.
 * \return The final entry appended by the scan, or a default-constructed Stmt
 *         when the scan records nothing (callers test it with defined()).
 */
Stmt ComputeGlobalLastFenceMMAStmt(const Stmt& root) {
  std::vector<Stmt> collected;
  ScanStmtDefault(root, &collected);
  return collected.empty() ? Stmt() : collected.back();
}
}
class MMABarrierMutator : public StmtExprMutator {
public:
explicit MMABarrierMutator(const Stmt& root_body)
: global_last_fence_mma_(ComputeGlobalLastFenceMMAStmt(root_body)) {}
bool ContainsMMA(const Stmt& stmt) {
bool found = false;
PostOrderVisit(stmt, [&found](const ObjectRef& node) {
if (const CallNode* call = node.as<CallNode>()) {
std::string op_name = "";
if (const OpNode* op = call->op.as<OpNode>()) {
op_name = op->name;
} else if (const GlobalVarNode* gv = call->op.as<GlobalVarNode>()) {
op_name = gv->name_hint;
}
if (op_name.find("mmac") != std::string::npos ||
op_name.find("mma") != std::string::npos) {
found = true;
}
}
});
return found;
}
Stmt VisitStmt_(const SeqStmtNode* op) override {
Array<Stmt> new_seq; Array<Stmt> new_seq;
int pending_load_count = 0; int pending_load_count = 0;
...@@ -89,16 +207,16 @@ public: ...@@ -89,16 +207,16 @@ public:
if (ContainsMMA(stmt)) { if (ContainsMMA(stmt)) {
if (pending_load_count > 0) { if (pending_load_count > 0) {
// 判断是否是该序列中最后一个 Fence int fence_val =
int fence_val = (static_cast<int>(i) == last_fence_idx) ? 0 : pending_load_count; (global_last_fence_mma_.defined() && stmt.same_as(global_last_fence_mma_))
? 0
: pending_load_count;
Array<PrimExpr> args = {Integer(fence_val)}; Array<PrimExpr> args = {Integer(fence_val)};
// 构造 Fence
auto fence_call = Call(DataType::Void(), Op::Get("tl.async_gld_fence"), args); auto fence_call = Call(DataType::Void(), Op::Get("tl.async_gld_fence"), args);
new_seq.push_back(Evaluate(fence_call)); new_seq.push_back(Evaluate(fence_call));
// 构造 Barrier
auto barrier_call = Call(DataType::Void(), Op::Get("tl.wave_barrier"), {}); auto barrier_call = Call(DataType::Void(), Op::Get("tl.wave_barrier"), {});
new_seq.push_back(Evaluate(barrier_call)); new_seq.push_back(Evaluate(barrier_call));
...@@ -114,15 +232,17 @@ public: ...@@ -114,15 +232,17 @@ public:
} }
return SeqStmt(new_seq); return SeqStmt(new_seq);
} }
private:
Stmt global_last_fence_mma_;
}; };
// 3. Pass 包装
namespace transform { namespace transform {
using namespace tir::transform; using namespace tir::transform;
Pass InsertAsyncMMAFence() { Pass InsertAsyncMMAFence() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite(); auto* n = f.CopyOnWrite();
MMABarrierMutator mutator; MMABarrierMutator mutator(n->body);
n->body = mutator(n->body); n->body = mutator(n->body);
return f; return f;
}; };
......
...@@ -20,9 +20,6 @@ using namespace tir; ...@@ -20,9 +20,6 @@ using namespace tir;
using ffi::Array; using ffi::Array;
using ffi::String; using ffi::String;
// ============================================================================
// 数据结构
// ============================================================================
struct CopyInfo { struct CopyInfo {
Buffer dst_buffer; Buffer dst_buffer;
Buffer src_buffer; Buffer src_buffer;
...@@ -33,10 +30,7 @@ struct CopyInfo { ...@@ -33,10 +30,7 @@ struct CopyInfo {
struct CollectResult { struct CollectResult {
std::vector<CopyInfo> copies; std::vector<CopyInfo> copies;
// 映射: Global Buffer Name -> DCU Resource Var (用于替换Store)
std::unordered_map<String, Var> global_to_res_var; std::unordered_map<String, Var> global_to_res_var;
// 映射: Shared Buffer Name -> 要注入的LetStmt绑定 (Var, PrimExpr)
// 这样我们就可以根据 shared buffer 的位置来决定注入点
std::unordered_map<String, std::pair<Var, PrimExpr>> shared_alloc_to_binding; std::unordered_map<String, std::pair<Var, PrimExpr>> shared_alloc_to_binding;
const StmtNode* inject_target = nullptr; const StmtNode* inject_target = nullptr;
...@@ -64,7 +58,6 @@ class VariableKeeper : public tvm::tir::ExprMutator { ...@@ -64,7 +58,6 @@ class VariableKeeper : public tvm::tir::ExprMutator {
: keep_vars_(keep_vars) {} : keep_vars_(keep_vars) {}
PrimExpr VisitExpr_(const tvm::tir::VarNode* op) override { PrimExpr VisitExpr_(const tvm::tir::VarNode* op) override {
// 关键调试:打印每一个遇到的变量及其地址
if (keep_vars_.count(op)) { if (keep_vars_.count(op)) {
return GetRef<PrimExpr>(op); return GetRef<PrimExpr>(op);
} else { } else {
...@@ -72,10 +65,7 @@ class VariableKeeper : public tvm::tir::ExprMutator { ...@@ -72,10 +65,7 @@ class VariableKeeper : public tvm::tir::ExprMutator {
} }
} }
// 额外处理:防止 Load 节点中的变量丢失
PrimExpr VisitExpr_(const tvm::tir::BufferLoadNode* op) override { PrimExpr VisitExpr_(const tvm::tir::BufferLoadNode* op) override {
// 如果你的索引里嵌套了 BufferLoad,Load 本身不是 Var,
// 但它里面可能含有 Var。Mutator 默认会递归,但我们可以显式打印。
return ExprMutator::VisitExpr_(op); return ExprMutator::VisitExpr_(op);
} }
...@@ -83,9 +73,6 @@ class VariableKeeper : public tvm::tir::ExprMutator { ...@@ -83,9 +73,6 @@ class VariableKeeper : public tvm::tir::ExprMutator {
const std::unordered_set<const tvm::tir::VarNode*>& keep_vars_; const std::unordered_set<const tvm::tir::VarNode*>& keep_vars_;
}; };
// ============================================================================
// Phase 1: 收集拷贝信息 & 生成资源绑定
// ============================================================================
CollectResult CollectResources(const Stmt& body) { CollectResult CollectResources(const Stmt& body) {
class Collector : public StmtExprVisitor { class Collector : public StmtExprVisitor {
public: public:
...@@ -94,7 +81,7 @@ CollectResult CollectResources(const Stmt& body) { ...@@ -94,7 +81,7 @@ CollectResult CollectResources(const Stmt& body) {
private: private:
bool in_async{false}; bool in_async{false};
std::unordered_set<const tvm::tir::VarNode*> loop_vars_; std::unordered_set<const tvm::tir::VarNode*> loop_vars_;
std::vector<const tvm::tir::StmtNode*> scope_stack_; // 追踪当前遍历的 AST 路径 std::vector<const tvm::tir::StmtNode*> scope_stack_;
bool IsSharedScope(const Buffer& buf) { bool IsSharedScope(const Buffer& buf) {
auto s = buf.scope(); auto s = buf.scope();
return s == "shared" || s == "shared.dyn"; return s == "shared" || s == "shared.dyn";
...@@ -107,7 +94,6 @@ CollectResult CollectResources(const Stmt& body) { ...@@ -107,7 +94,6 @@ CollectResult CollectResources(const Stmt& body) {
void VisitStmt_(const AttrStmtNode* attr) override { void VisitStmt_(const AttrStmtNode* attr) override {
scope_stack_.push_back(attr); scope_stack_.push_back(attr);
if (attr->attr_key == tir::attr::thread_extent) { if (attr->attr_key == tir::attr::thread_extent) {
// 1. 获取 IterVar
auto iv = attr->node.as<tvm::tir::IterVarNode>(); auto iv = attr->node.as<tvm::tir::IterVarNode>();
const std::string& tag = iv->thread_tag; const std::string& tag = iv->thread_tag;
...@@ -119,7 +105,6 @@ CollectResult CollectResources(const Stmt& body) { ...@@ -119,7 +105,6 @@ CollectResult CollectResources(const Stmt& body) {
loop_vars_.erase(thread_var.get()); loop_vars_.erase(thread_var.get());
} else { } else {
// 如果是 blockIdx 或其他,直接跳过当前层继续往下走
StmtExprVisitor::VisitStmt_(attr); StmtExprVisitor::VisitStmt_(attr);
} }
...@@ -173,7 +158,6 @@ CollectResult CollectResources(const Stmt& body) { ...@@ -173,7 +158,6 @@ CollectResult CollectResources(const Stmt& body) {
if (scope_stack_[i]->IsInstance<AttrStmtNode>()) { if (scope_stack_[i]->IsInstance<AttrStmtNode>()) {
auto attr = static_cast<const AttrStmtNode*>(scope_stack_[i]); auto attr = static_cast<const AttrStmtNode*>(scope_stack_[i]);
if (attr->attr_key == tvm::tir::attr::thread_extent) { if (attr->attr_key == tvm::tir::attr::thread_extent) {
// 找到了最内层的线程绑定。它里面的下一个节点(i+1)就是我们应该包裹的节点
if (i + 1 < scope_stack_.size()) { if (i + 1 < scope_stack_.size()) {
result.inject_target = scope_stack_[i + 1]; result.inject_target = scope_stack_[i + 1];
} }
...@@ -189,12 +173,10 @@ CollectResult CollectResources(const Stmt& body) { ...@@ -189,12 +173,10 @@ CollectResult CollectResources(const Stmt& body) {
} }
} }
} }
// 如果还是空,直接 fallback 到当前操作
if (result.inject_target == nullptr) result.inject_target = op; if (result.inject_target == nullptr) result.inject_target = op;
} }
// 1. 记录拷贝
VariableKeeper keeper(loop_vars_); VariableKeeper keeper(loop_vars_);
tvm::arith::Analyzer analyzer; tvm::arith::Analyzer analyzer;
Array<PrimExpr> for_var_only_indices; Array<PrimExpr> for_var_only_indices;
...@@ -206,7 +188,6 @@ CollectResult CollectResources(const Stmt& body) { ...@@ -206,7 +188,6 @@ CollectResult CollectResources(const Stmt& body) {
CopyInfo info{dst, src, op->indices, for_var_only_indices, GetRef<Stmt>(op)}; CopyInfo info{dst, src, op->indices, for_var_only_indices, GetRef<Stmt>(op)};
result.copies.push_back(info); result.copies.push_back(info);
// 2. 只有当没处理过这个 Global Buffer 时才生成 Binding
if (result.global_to_res_var.find(src->name) == result.global_to_res_var.end()) { if (result.global_to_res_var.find(src->name) == result.global_to_res_var.end()) {
Var var(src->name + "_dcu_res", DataType::Int(32, 4)); Var var(src->name + "_dcu_res", DataType::Int(32, 4));
...@@ -214,17 +195,13 @@ CollectResult CollectResources(const Stmt& body) { ...@@ -214,17 +195,13 @@ CollectResult CollectResources(const Stmt& body) {
tvm::arith::Analyzer analyzer; tvm::arith::Analyzer analyzer;
Array<PrimExpr> base_indices; Array<PrimExpr> base_indices;
for (const auto& idx : load->indices) { for (const auto& idx : load->indices) {
// 将所有外层循环变量 (k, i 等) 全部替换为 0
PrimExpr no_loops = eliminator(idx); PrimExpr no_loops = eliminator(idx);
// 化简出最终的基地址表达式
base_indices.push_back(analyzer.Simplify(no_loops)); base_indices.push_back(analyzer.Simplify(no_loops));
} }
// ✅ 关键点:填充真实的地址信息 src->data (即 A.data)
Array<PrimExpr> args; Array<PrimExpr> args;
args.push_back(src->data); // 先加 data args.push_back(src->data);
// 如果需要把 indices 的每个元素作为独立参数展开:
for (const auto& idx : base_indices) { for (const auto& idx : base_indices) {
args.push_back(idx); args.push_back(idx);
} }
...@@ -232,7 +209,6 @@ CollectResult CollectResources(const Stmt& body) { ...@@ -232,7 +209,6 @@ CollectResult CollectResources(const Stmt& body) {
Op::Get("tl.make_dcu_resource"), args); Op::Get("tl.make_dcu_resource"), args);
result.global_to_res_var[src->name] = var; result.global_to_res_var[src->name] = var;
// 将这个绑定关系和 destination 的 shared buffer 绑死
result.shared_alloc_to_binding[src->name] = {var, val}; result.shared_alloc_to_binding[src->name] = {var, val};
} }
} }
...@@ -247,9 +223,6 @@ CollectResult CollectResources(const Stmt& body) { ...@@ -247,9 +223,6 @@ CollectResult CollectResources(const Stmt& body) {
return col.result; return col.result;
} }
// ============================================================================
// Phase 2: 替换 BufferStore -> dcu_async_copy
// ============================================================================
class StoreReplacer : public StmtExprMutator { class StoreReplacer : public StmtExprMutator {
public: public:
static Stmt Run(Stmt body, const std::vector<CopyInfo>& copies, static Stmt Run(Stmt body, const std::vector<CopyInfo>& copies,
...@@ -268,16 +241,14 @@ private: ...@@ -268,16 +241,14 @@ private:
auto body = this->VisitStmt(attr->body); auto body = this->VisitStmt(attr->body);
return body; return body;
} }
return StmtMutator::VisitStmt_(attr); // ③ 其他属性:默认保留 return StmtMutator::VisitStmt_(attr);
} }
Stmt VisitStmt_(const BufferStoreNode* op) final { Stmt VisitStmt_(const BufferStoreNode* op) final {
for (const auto& copy : copies_) { for (const auto& copy : copies_) {
if (copy.store_stmt.same_as(GetRef<Stmt>(op))) { if (copy.store_stmt.same_as(GetRef<Stmt>(op))) {
// Global 取 resource var (A_dcu_res)
Var src_res = global_to_var_.at(copy.src_buffer->name); Var src_res = global_to_var_.at(copy.src_buffer->name);
// Shared 取 data pointer (A_shared.data)
PrimExpr dst_res = copy.dst_buffer->data; PrimExpr dst_res = copy.dst_buffer->data;
PrimExpr copy_size = IntImm(DataType::Int(32), 1); PrimExpr copy_size = IntImm(DataType::Int(32), 1);
...@@ -305,9 +276,6 @@ private: ...@@ -305,9 +276,6 @@ private:
const std::unordered_map<String, Var>& global_to_var_; const std::unordered_map<String, Var>& global_to_var_;
}; };
// ============================================================================
// Phase 3: 根据 Shared Alloc 位置进行精准注入
// ============================================================================
class ResourceInjector : public tvm::tir::StmtExprMutator { class ResourceInjector : public tvm::tir::StmtExprMutator {
public: public:
static Stmt Run(Stmt body, static Stmt Run(Stmt body,
...@@ -324,18 +292,15 @@ private: ...@@ -324,18 +292,15 @@ private:
: bindings_(bindings), target_(target) {} : bindings_(bindings), target_(target) {}
Stmt VisitStmt(const Stmt& stmt) override { Stmt VisitStmt(const Stmt& stmt) override {
// 当我们遍历到刚才标记的那个 AST 节点时
if (stmt.get() == target_) { if (stmt.get() == target_) {
// 先向下遍历(保持 TVM Mutator 的习惯)
Stmt new_stmt = StmtExprMutator::VisitStmt(stmt); Stmt new_stmt = StmtExprMutator::VisitStmt(stmt);
// 在这个节点的外面套上所有的 LetStmt
for (const auto& item : bindings_) { for (const auto& item : bindings_) {
Var res_var = item.second.first; Var res_var = item.second.first;
PrimExpr init_expr = item.second.second; PrimExpr init_expr = item.second.second;
new_stmt = tvm::tir::LetStmt(res_var, init_expr, new_stmt); new_stmt = tvm::tir::LetStmt(res_var, init_expr, new_stmt);
} }
return new_stmt; // 返回包裹好的新节点 return new_stmt;
} }
return StmtExprMutator::VisitStmt(stmt); return StmtExprMutator::VisitStmt(stmt);
} }
...@@ -344,26 +309,18 @@ private: ...@@ -344,26 +309,18 @@ private:
const tvm::tir::StmtNode* target_; const tvm::tir::StmtNode* target_;
}; };
// ============================================================================
// Pass 入口
// ============================================================================
PrimFunc LowerSharedGlobalCopy(PrimFunc f) { PrimFunc LowerSharedGlobalCopy(PrimFunc f) {
auto* n = f.CopyOnWrite(); auto* n = f.CopyOnWrite();
// 收集信息
auto res = CollectResources(n->body); auto res = CollectResources(n->body);
if (res.copies.empty()){ if (res.copies.empty()){
return f; return f;
} }
// 注入res声明
Stmt injected = ResourceInjector::Run(n->body, res.shared_alloc_to_binding, res.inject_target); Stmt injected = ResourceInjector::Run(n->body, res.shared_alloc_to_binding, res.inject_target);
// 替换拷贝语句
Stmt replaced = StoreReplacer::Run(injected, res.copies, res.global_to_res_var); Stmt replaced = StoreReplacer::Run(injected, res.copies, res.global_to_res_var);
// 写回
n->body = std::move(replaced); n->body = std::move(replaced);
return GetRef<PrimFunc>(n); return GetRef<PrimFunc>(n);
......
...@@ -31,7 +31,6 @@ private: ...@@ -31,7 +31,6 @@ private:
PrimExpr k_extent_; PrimExpr k_extent_;
bool in_unrolled_i_ = false; bool in_unrolled_i_ = false;
// 通用的步长提取函数:从 expr 中提取指定 var 的步长,并返回剩余的 base
std::pair<PrimExpr, PrimExpr> ExtractStride(PrimExpr expr, Var var) { std::pair<PrimExpr, PrimExpr> ExtractStride(PrimExpr expr, Var var) {
if (!var.defined()) return {expr, make_zero(expr->dtype)}; if (!var.defined()) return {expr, make_zero(expr->dtype)};
...@@ -41,33 +40,26 @@ private: ...@@ -41,33 +40,26 @@ private:
return {analyzer_.Simplify(base), stride}; return {analyzer_.Simplify(base), stride};
} }
// 核心重写逻辑
Stmt RewriteAsyncCopy(const CallNode* call, Var i_var, PrimExpr i_extent) { Stmt RewriteAsyncCopy(const CallNode* call, Var i_var, PrimExpr i_extent) {
// 1. 预处理:剥离 RampNode 获得基础偏移
PrimExpr raw_dst_off = call->args[1]; PrimExpr raw_dst_off = call->args[1];
PrimExpr raw_src_off = call->args[3]; PrimExpr raw_src_off = call->args[3];
if (const RampNode* r = raw_dst_off.as<RampNode>()) raw_dst_off = r->base; if (const RampNode* r = raw_dst_off.as<RampNode>()) raw_dst_off = r->base;
if (const RampNode* r = raw_src_off.as<RampNode>()) raw_src_off = r->base; if (const RampNode* r = raw_src_off.as<RampNode>()) raw_src_off = r->base;
// 2. 提取 i 的步长
auto [base_i_dst, i_stride_dst] = ExtractStride(raw_dst_off, i_var); auto [base_i_dst, i_stride_dst] = ExtractStride(raw_dst_off, i_var);
auto [base_i_src, i_stride_src] = ExtractStride(raw_src_off, i_var); auto [base_i_src, i_stride_src] = ExtractStride(raw_src_off, i_var);
// 3. 提取 k 的步长 (始终尝试从 base_i_src 中提取,无论是否存在 i 循环)
// 只要 k_var_ 在外层循环中定义了,这里就能提取出非 0 的步长
auto [final_src_offset, k_stride_src] = ExtractStride(base_i_src, k_var_); auto [final_src_offset, k_stride_src] = ExtractStride(base_i_src, k_var_);
// 构造最初要求的 8 个参数:
// [dst, dst_off, src, src_off, i_extent, i_stride_dst, i_stride_src, k_stride_src]
Array<PrimExpr> new_args = { Array<PrimExpr> new_args = {
call->args[0], // dst_buf call->args[0],
base_i_dst, // 最终 dst 偏移 base_i_dst,
call->args[2], // src_buf call->args[2],
final_src_offset, // 最终 src 偏移 final_src_offset,
i_extent, // i 循环次数 (无循环时为 0) i_extent,
i_stride_dst, // i 的 dst 步长 (无循环时为 0) i_stride_dst,
i_stride_src, // i 的 src 步长 (无循环时为 0) i_stride_src,
k_stride_src // k 的 src 步长 (即便无 i 循环,这里也能拿到 k 的步长) k_stride_src
}; };
return Evaluate(Call(call->dtype, call->op, new_args)); return Evaluate(Call(call->dtype, call->op, new_args));
...@@ -78,7 +70,6 @@ private: ...@@ -78,7 +70,6 @@ private:
if (!in_unrolled_i_) { if (!in_unrolled_i_) {
if (const CallNode* call = op->value.as<CallNode>()) { if (const CallNode* call = op->value.as<CallNode>()) {
static const Op& dcu_copy_op = Op::Get("tl.dcu_async_copy"); static const Op& dcu_copy_op = Op::Get("tl.dcu_async_copy");
// 只要参数个数不是 8 (我们重写后的目标个数),就进行处理
if (call->op.same_as(dcu_copy_op) && call->args.size() != 8) { if (call->op.same_as(dcu_copy_op) && call->args.size() != 8) {
return RewriteAsyncCopy(call, Var(), make_zero(DataType::Int(32))); return RewriteAsyncCopy(call, Var(), make_zero(DataType::Int(32)));
} }
...@@ -108,7 +99,6 @@ private: ...@@ -108,7 +99,6 @@ private:
if (const CallNode* call = eval->value.as<CallNode>()) { if (const CallNode* call = eval->value.as<CallNode>()) {
static const Op& dcu_copy_op = Op::Get("tl.dcu_async_copy"); static const Op& dcu_copy_op = Op::Get("tl.dcu_async_copy");
if (call->op.same_as(dcu_copy_op)) { if (call->op.same_as(dcu_copy_op)) {
// 还原 k 并在返回前处理重写
Stmt result = RewriteAsyncCopy(call, op->loop_var, op->extent); Stmt result = RewriteAsyncCopy(call, op->loop_var, op->extent);
return result; return result;
} }
...@@ -116,7 +106,6 @@ private: ...@@ -116,7 +106,6 @@ private:
} }
} }
// 退出循环时清理 k 信息
if (is_k) { if (is_k) {
k_var_ = Var(); k_var_ = Var();
k_extent_ = PrimExpr(); k_extent_ = PrimExpr();
...@@ -128,9 +117,7 @@ private: ...@@ -128,9 +117,7 @@ private:
return Stmt(n); return Stmt(n);
} }
}; };
// ============================================================================
// Pass 入口
// ============================================================================
PrimFunc SimplifyDCUAsyncCopy(PrimFunc f) { PrimFunc SimplifyDCUAsyncCopy(PrimFunc f) {
auto* n = f.CopyOnWrite(); auto* n = f.CopyOnWrite();
n->body = AsyncCopySimplifier::Run(std::move(n->body)); n->body = AsyncCopySimplifier::Run(std::move(n->body));
......
...@@ -222,25 +222,16 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -222,25 +222,16 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.PipelinePlanning()(mod) mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.RegisterPipelinePlanning()(mod) mod = tilelang.transform.RegisterPipelinePlanning()(mod)
print("OptimizeForTarget")
print(mod)
mod = tilelang.transform.InjectRegisterSoftwarePipeline()(mod)
print("OptimizeForTarget2")
print(mod) mod = tilelang.transform.InjectRegisterSoftwarePipeline()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tilelang.transform.InjectSoftwarePipeline()(mod)
print("OptimizeForTarget2")
print(mod)
mod = tilelang.transform.MergeIfStmt()(mod) mod = tilelang.transform.MergeIfStmt()(mod)
if allow_fence_proxy(target=target): if allow_fence_proxy(target=target):
# in hopper device, wgmma is an async proxy # in hopper device, wgmma is an async proxy
# so we need to inject a fence proxy before it # so we need to inject a fence proxy before it
mod = tilelang.transform.InjectFenceProxy()(mod) mod = tilelang.transform.InjectFenceProxy()(mod)
print("OptimizeForTarget2.5")
print(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod) mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.Simplify()(mod) mod = tilelang.transform.Simplify()(mod)
mod = tir.transform.NarrowDataType(32)(mod) mod = tir.transform.NarrowDataType(32)(mod)
...@@ -311,8 +302,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -311,8 +302,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
if dcu_async_copy_supported(target): if dcu_async_copy_supported(target):
print("--------------support dcu async copy------------------") print("--------------support dcu async copy------------------")
mod = tilelang.transform.LowerSharedGlobalCopy()(mod) mod = tilelang.transform.LowerSharedGlobalCopy()(mod)
print("222222222")
print(mod)
mod = tilelang.transform.FixDCUWaitCount()(mod) mod = tilelang.transform.FixDCUWaitCount()(mod)
mod = tilelang.transform.InjectBLocalLayoutTransform()(mod) mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
print("InjectBLocalLayoutTransform ............") print("InjectBLocalLayoutTransform ............")
...@@ -321,8 +310,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -321,8 +310,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
print("InjectDSRead ............") print("InjectDSRead ............")
print(mod) print(mod)
mod = tilelang.transform.InsertAsyncMMAFence()(mod) mod = tilelang.transform.InsertAsyncMMAFence()(mod)
print("333333333")
print(mod)
# Register pipeline planning only writes software_pipeline annotations. # Register pipeline planning only writes software_pipeline annotations.
# We must inject after planning so prologue/body/epilogue are materialized. # We must inject after planning so prologue/body/epilogue are materialized.
# mod = tilelang.transform.RegisterPipelinePlanning()(mod) # mod = tilelang.transform.RegisterPipelinePlanning()(mod)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment