"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "54ec5c4cbec31473935a899fa3c03e732d393866"
Commit dd91b1e0 authored by qisan's avatar qisan
Browse files

Feats: vectorize async copy

parent 41887aed
@@ -10,7 +10,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.fl
         B: T.Tensor((K, N), dtype),
         C: T.Tensor((M, N), dtype),
     ):
-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=512) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
             B_shared = T.alloc_shared((block_K, block_N), dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......
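The functional change here is just the launch width. As a sanity check on what that buys (assuming, as on AMD GCN-class hardware, that the DCU schedules 64-lane wavefronts; the commit does not state this):

```cpp
#include <cassert>

// Raising threads per block from 256 to 512 doubles the wavefronts the
// scheduler can interleave to hide async-copy latency (64-lane waves assumed).
int WavefrontsPerBlock(int threads, int wave_size = 64) {
  return threads / wave_size;
}

int main() {
  assert(WavefrontsPerBlock(256) == 4);
  assert(WavefrontsPerBlock(512) == 8);
  return 0;
}
```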
@@ -389,7 +389,7 @@ TIR_DEFINE_TL_BUILTIN(__ldg).set_num_inputs(-1).set_attr<TCallEffectKind>(
 //
 TIR_DEFINE_TL_BUILTIN(dcu_async_copy)
-    .set_num_inputs(6)
+    .set_num_inputs(4)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
......
@@ -793,8 +793,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
        << ", " << condition << ");\n";
     }
   } else if (op->op.same_as(builtin::ptx_commit_group())) {
-    printf("[DEBUG VisitExpr_] Branch: ptx_commit_group\n");
-    print_extern_call_stmt("tl::cp_async_commit");
+    ;
+    // print_extern_call_stmt("tl::cp_async_commit");
   } else if (op->op.same_as(builtin::ptx_wait_group())) {
     printf("[DEBUG VisitExpr_] Branch: ptx_wait_group\n");
     int n = Downcast<IntImm>(op->args[0])->value;
@@ -1103,42 +1103,51 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
   }
   else if (op->op.same_as(Op::Get("tl.dcu_async_copy"))) {
-    // 1. Extract the template parameters (take IntImm values directly)
+    auto get_base_expr = [this](const PrimExpr& e) -> std::string {
+      if (const auto* ramp = e.as<tvm::tir::RampNode>()) {
+        // For a Ramp, print only its starting position (base)
+        return this->PrintExpr(ramp->base);
+      }
+      // Otherwise print the expression as-is
+      return this->PrintExpr(e);
+    };
+    // Helper: try to extract an integer constant
     auto get_int_const = [](const PrimExpr& e) -> int {
       if (const auto* val = e.as<IntImmNode>()) return static_cast<int>(val->value);
       return 0;
     };
-    int N = 16;
-    int smem_offset = 0;
-    int load_count = get_int_const(op->args[4]);
-    int i_sstride = get_int_const(op->args[5]);
-    int i_gstride = get_int_const(op->args[6]);
-    int k_gstride = get_int_const(op->args[7]);
-    // 2. Print the runtime arguments into strings (avoids clobbering the stream directly)
-    std::string dst_ptr = this->PrintExpr(op->args[0]);
-    std::string dst_off = this->PrintExpr(op->args[1]);
-    std::string src_res = this->PrintExpr(op->args[2]);
-    std::string src_off = this->PrintExpr(op->args[3]);
+    // 1. Static template parameters (keep only N and smem_offset, as requested)
+    int N = 16;
+    // 2. Parse the IR arguments
+    // args[0]: dst_ptr (buf_dyn_shmem)
+    // args[1]: dst_ramp (T.Ramp...)
+    // args[2]: src_res (A_dcu_res)
+    // args[3]: src_ramp (T.Ramp...)
+    // args[4]: load_count (1)
+    std::string dst_ptr = this->PrintExpr(op->args[0]);
+    // Use the newly defined get_base_expr to sidestep the lanes > 4 check
+    std::string dst_off = get_base_expr(op->args[1]);
+    std::string src_res = this->PrintExpr(op->args[2]);
+    std::string src_off = get_base_expr(op->args[3]);
-    // 3. Stream the output, following the existing examples
+    // 3. Emit the output
     this->PrintIndent();
-    // Template parameters keep only N, smem_offset, and the dynamically extracted load_count
     this->stream << "tl::cp_async_gs<"
-                 << N << ", "
-                 << smem_offset << ", "
-                 << load_count << ", "
-                 << i_sstride << ", "
-                 << i_gstride << ", "
-                 << k_gstride << ">(";
+                 << N << ">(";
-    // Assemble the first argument: (char*)dst + dst_off
+    // Print the call arguments
+    // Destination address: ((char*)ptr + offset)
     this->stream << "((char*)" << dst_ptr << " + " << dst_off << "), ";
-    // Assemble the second argument: src_res
+    // Source resource pointer
     this->stream << src_res << ", ";
-    // Assemble the third argument: src_off
+    // Source offset
     this->stream << src_off << ");\n";
   }
   else {
......
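To make the emitted code concrete, here is a minimal standalone sketch that assembles the same string this branch writes to `this->stream`. The identifier values (`buf_dyn_shmem`, `A_sh_off`, `A_dcu_res`, `A_g_off`) are illustrative placeholders, not output from a real kernel:

```cpp
#include <iostream>
#include <sstream>
#include <string>

int main() {
  // Stand-ins for the PrintExpr()/get_base_expr() results in the codegen.
  int N = 16;
  std::string dst_ptr = "buf_dyn_shmem", dst_off = "A_sh_off";
  std::string src_res = "A_dcu_res", src_off = "A_g_off";

  std::ostringstream os;
  os << "tl::cp_async_gs<" << N << ">(";
  os << "((char*)" << dst_ptr << " + " << dst_off << "), ";
  os << src_res << ", ";
  os << src_off << ");\n";

  // Prints: tl::cp_async_gs<16>(((char*)buf_dyn_shmem + A_sh_off), A_dcu_res, A_g_off);
  std::cout << os.str();
  return 0;
}
```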
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/analysis.h>

using namespace tvm::tir;
using tvm::ffi::GetRef;
using tvm::ffi::make_object;

namespace tvm {
namespace tl {

using namespace tir;
using ffi::Array;
using ffi::String;

class ROCmWaitCountRewriter : public StmtMutator {
 public:
  static Stmt Substitute(Stmt stmt) {
    return ROCmWaitCountRewriter()(stmt);
  }

 private:
  // Helper: count the total number of async instructions in a statement
  int CountAsyncOps(const Stmt& stmt) {
    struct Visitor : public StmtExprVisitor {
      int count = 0;
      void VisitStmt_(const ForNode* op) override {
        // If there is a nested loop (e.g. T.unroll), multiply by its trip count
        int current_count = count;
        count = 0;
        StmtExprVisitor::VisitStmt_(op);
        int loop_count = 0;
        if (const auto* extent = op->extent.as<IntImmNode>()) {
          loop_count = static_cast<int>(extent->value);
        } else {
          // Non-constant extents are rare in pipelined code; default to 1 (or warn)
          loop_count = 1;
        }
        int body_count = count;
        count = current_count + (body_count * loop_count);
      }
      void VisitExpr_(const CallNode* op) override {
        // Recognize ptx_cp_async or the corresponding async memory-access op
        if (op->op.same_as(builtin::ptx_cp_async()) ||
            op->op.same_as(Op::Get("tl.dcu_async_copy"))) {
          LOG(INFO) << "Found async copy: " << GetRef<Call>(op);
          count++;
        }
        StmtExprVisitor::VisitExpr_(op);
      }
      // Compatibility with implementations that wrap cp_async in an Evaluate
      void VisitStmt_(const EvaluateNode* op) override {
        StmtExprVisitor::VisitStmt_(op);
      }
    } visitor;
    visitor(stmt);
    return visitor.count;
  }

  Stmt VisitStmt_(const ForNode* op) override {
    // 1. Assume the pipeline's main loop is the core scope: first scan how
    // many async ops each iteration of this loop body issues
    int ops_per_iter = CountAsyncOps(op->body);
    // If there are no async ops, skip straight through
    if (ops_per_iter == 0) return StmtMutator::VisitStmt_(op);
    // 2. Descend into the loop body for rewriting, remembering the multiplier
    int old_multiplier = multiplier_;
    multiplier_ = ops_per_iter;
    Stmt new_body = this->VisitStmt(op->body);
    multiplier_ = old_multiplier;
    if (new_body.same_as(op->body)) return GetRef<Stmt>(op);
    auto n = CopyOnWrite(op);
    n->body = std::move(new_body);
    return Stmt(n);
  }

  Stmt VisitStmt_(const AttrStmtNode* op) override {
    if (op->attr_key == "async_wait_inflight_count" && multiplier_ > 0) {
      // Read the original number of wait groups (e.g. 1)
      if (auto int_imm = op->value.as<IntImmNode>()) {
        // Compute the ROCm instruction count: N_groups * Ops_per_group
        int64_t new_count = int_imm->value * multiplier_;
        LOG(INFO) << "Original wait count: " << int_imm->value
                  << ", async ops per iter: " << multiplier_
                  << ", new wait count: " << new_count;
        // Return the rewritten node
        return AttrStmt(op->node, op->attr_key, make_const(DataType::Int(32), new_count), op->body);
      }
    }
    return StmtMutator::VisitStmt_(op);
  }

  int multiplier_ = 0;  // instruction multiplier for the current scope
};

// Wrap it up as a standard TVM pass
namespace transform {

using namespace tir::transform;

tvm::transform::Pass FixDCUWaitCount() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = ROCmWaitCountRewriter::Substitute(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "FixDCUWaitCount", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tl.transform.FixDCUWaitCount", FixDCUWaitCount);
}

}  // namespace transform
}  // namespace tl
}  // namespace tvm
\ No newline at end of file
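The arithmetic the pass performs, extracted into a standalone sketch. My reading (stated as an assumption, since the commit does not spell it out): PTX-style waits count commit groups, while the ROCm-side wait counts individual outstanding async ops, so the group count is scaled by the ops issued per pipeline iteration:

```cpp
#include <cassert>

// Sketch of the scaling ROCmWaitCountRewriter applies: a group-level wait
// count becomes an op-level count by multiplying with the number of async
// copies each pipeline iteration issues (new_count = value * multiplier_).
int RocmWaitCount(int group_count, int async_ops_per_iter) {
  return group_count * async_ops_per_iter;
}

int main() {
  assert(RocmWaitCount(1, 2) == 2);   // wait on 1 group, 2 copies per group
  assert(RocmWaitCount(3, 4) == 12);
  return 0;
}
```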
@@ -148,6 +148,7 @@ private:
         new_args.Set(i + 1, new_index);
       }
     }
+    LOG(INFO) << "Rewriting buffer access " << call << " to " << Call(call->dtype, call->op, new_args, call->span);
     return Call(call->dtype, call->op, new_args, call->span);
   }
@@ -166,6 +167,7 @@ private:
     for (const Buffer &alloc_buffer : op->alloc_buffers) {
       buffer_data_to_buffer_.erase(alloc_buffer->data);
     }
+    LOG(INFO) << "Rewriting block " << GetRef<Block>(op) << " to " << GetRef<Block>(n);
     return block;
   }
@@ -309,6 +311,7 @@ public:
     }
     Block block = MakeBlock(stmt, buffer_data_to_buffer_);
     block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
+    LOG(INFO) << "Final rewritten pipeline block: " << block;
     return BlockRealize({}, Bool(true), block);
   }
@@ -631,6 +634,9 @@ private:
       n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                          AttrStmt(zero, tir::attr::async_wait_inflight_count,
                                   pw.wait_count, n->body));
+      LOG(INFO) << "Inserting async_wait with count " << pw.wait_count
+                << " before block with order " << new_blocks[pw.insert_before].order
+                << " for async stage " << stage_id;
     }
   }
@@ -1102,6 +1108,7 @@ private:
       buffer_data_to_buffer_.erase(buffer->data);
     }
   }
+  LOG(INFO) << "Finished rewriting the pipeline loop with body:\n" << pipeline;
   return pipeline;
 }
@@ -1121,6 +1128,7 @@ private:
   for (const auto &buffer : op->alloc_buffers) {
     buffer_data_to_buffer_.erase(buffer->data);
   }
+  LOG(INFO) << "Rewriting block " << block;
   return block;
 }
@@ -1158,6 +1166,8 @@ tir::transform::Pass InjectSoftwarePipeline() {
     auto *fptr = f.CopyOnWrite();
     fptr->body = software_pipeline::PipelineInjector::Inject(f);
     fptr->body = ConvertSSA(std::move(fptr->body));
+    LOG(INFO) << "Finished injecting software pipeline for PrimFunc " << f->GetAttr<String>(tvm::attr::kGlobalSymbol).value_or("<unknown>")
+              << ", the transformed body is:\n" << fptr->body;
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
......
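For orientation, the `async_wait_inflight_count` the logging above reports encodes the usual multi-stage pipeline invariant (a general property of cp.async-style pipelining, not something this diff states): before consuming a stage, all but the most recent prefetches must have landed.

```cpp
#include <cassert>

// In an N-stage software pipeline, the consumer typically waits until at
// most N - 1 prefetch groups remain in flight.
int WaitInflightCount(int num_stages) { return num_stages - 1; }

int main() {
  assert(WaitInflightCount(2) == 1);  // double buffering: wait on all but one
  assert(WaitInflightCount(3) == 2);
  return 0;
}
```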
@@ -94,6 +94,7 @@ CollectResult CollectResources(const Stmt& body) {
   CollectResult result;

  private:
+  bool in_async{false};
   std::unordered_set<const tvm::tir::VarNode*> loop_vars_;
   std::vector<const tvm::tir::StmtNode*> scope_stack_;  // tracks the AST path of the current traversal
   bool IsSharedScope(const Buffer& buf) {
@@ -105,27 +106,36 @@ CollectResult CollectResources(const Stmt& body) {
     return s == "global" || s == "";
   }
-  void VisitStmt_(const AttrStmtNode* op) override {
-    scope_stack_.push_back(op);
-    if (op->attr_key == tvm::tir::attr::thread_extent) {
+  void VisitStmt_(const AttrStmtNode* attr) override {
+    scope_stack_.push_back(attr);
+    if (attr->attr_key == tir::attr::thread_extent) {
       // 1. Get the IterVar
-      auto iv = op->node.as<tvm::tir::IterVarNode>();
+      auto iv = attr->node.as<tvm::tir::IterVarNode>();
       const std::string& tag = iv->thread_tag;
-      // 2. Only record the var when the tag contains "threadIdx" (filters out blockIdx)
-      // e.g. "threadIdx.x", "threadIdx.y", "threadIdx.z"
       if (tag.find("threadIdx") != std::string::npos) {
         tvm::tir::Var thread_var = iv->var;
+        LOG(INFO) << "Entering thread scope: " << tag << " with var " << thread_var->name_hint;
         loop_vars_.insert(thread_var.get());
-        StmtExprVisitor::VisitStmt_(op);
+        StmtExprVisitor::VisitStmt_(attr);
         loop_vars_.erase(thread_var.get());
       } else {
         // For blockIdx and anything else, skip this level and keep descending
-        StmtExprVisitor::VisitStmt_(op);
+        StmtExprVisitor::VisitStmt_(attr);
       }
+    } else if (attr->attr_key == tir::attr::async_scope) {
+      ICHECK(in_async == false) << "Nested async scopes not supported";
+      in_async = true;
+      StmtExprVisitor::VisitStmt_(attr);
+      in_async = false;
+    }
+    else {
+      StmtExprVisitor::VisitStmt_(attr);
     }
     scope_stack_.pop_back();
   }
@@ -145,14 +155,15 @@ CollectResult CollectResources(const Stmt& body) {
   }
   void VisitStmt_(const BufferStoreNode* op) final {
+    LOG(INFO) << "Visiting BufferStore: " << op->buffer->name;
     Buffer dst = op->buffer;
-    if (IsSharedScope(dst) && op->value.defined()) {
+    if (IsSharedScope(dst) && op->value.defined() && in_async) {
       if (const auto* load = op->value.as<BufferLoadNode>()) {
         Buffer src = load->buffer;
         if (IsGlobalScope(src)) {
-          const StmtNode* target = op;
           if (result.inject_target == nullptr) {
-            // Walk the stack bottom-up to find the innermost thread_extent
            for (int i = scope_stack_.size() - 1; i >= 0; --i) {
              if (scope_stack_[i]->IsInstance<AttrStmtNode>()) {
                auto attr = static_cast<const AttrStmtNode*>(scope_stack_[i]);
@@ -198,10 +209,10 @@ CollectResult CollectResources(const Stmt& body) {
          VariableEliminator eliminator(loop_vars_);
          tvm::arith::Analyzer analyzer;
          Array<PrimExpr> base_indices;
          LOG(INFO) << loop_vars_.size() << " loop vars in context.";
          for (const auto* var : loop_vars_) {
            LOG(INFO) << "Loop Var: " << var->name_hint;
          }
          for (const auto& idx : load->indices) {
            // Replace every outer loop variable (k, i, etc.) with 0
            PrimExpr no_loops = eliminator(idx);
@@ -225,15 +236,18 @@ CollectResult CollectResources(const Stmt& body) {
            // Pin this binding to the destination's shared buffer
            result.shared_alloc_to_binding[src->name] = {var, val};
          }
+          LOG(INFO) << "result.copies.size() = " << result.copies.size();
        }
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }
};
+LOG(INFO) << "Starting resource collection...";
 Collector col;
 col(body);
+LOG(INFO) << "Finished resource collection. Found " << col.result.copies.size() << " copy(s).";
 return col.result;
}
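A TVM-free restatement of the predicate the collector now applies, so the `in_async` change is easy to test in isolation (the exact scope strings checked by `IsSharedScope` are not shown in this hunk, so the shared-scope check below is an assumption):

```cpp
#include <cassert>
#include <string>

// Only stores into shared memory that load from global memory and sit
// inside an async_scope attribute are recorded as async-copy candidates.
struct StoreInfo {
  std::string dst_scope;  // scope of the stored-to buffer
  std::string src_scope;  // scope of the loaded-from buffer
  bool in_async;          // true while visiting an async_scope AttrStmt
};

bool IsAsyncCopyCandidate(const StoreInfo& s) {
  bool dst_shared = s.dst_scope == "shared" || s.dst_scope == "shared.dyn";
  bool src_global = s.src_scope == "global" || s.src_scope.empty();
  return s.in_async && dst_shared && src_global;
}

int main() {
  assert(IsAsyncCopyCandidate({"shared", "global", true}));
  assert(!IsAsyncCopyCandidate({"shared", "global", false}));  // outside async_scope
  assert(!IsAsyncCopyCandidate({"local", "global", true}));    // destination not shared
  return 0;
}
```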
@@ -253,6 +267,15 @@ private:
                 const std::unordered_map<String, Var>& global_to_var)
       : copies_(copies), global_to_var_(global_to_var) {}

+  Stmt VisitStmt_(const AttrStmtNode *attr) {
+    if (attr->attr_key == tir::attr::async_scope) {
+      auto body = this->VisitStmt(attr->body);
+      return body;
+    }
+    return StmtMutator::VisitStmt_(attr);  // other attributes: keep by default
+  }
+
   Stmt VisitStmt_(const BufferStoreNode* op) final {
     for (const auto& copy : copies_) {
       if (copy.store_stmt.same_as(GetRef<Stmt>(op))) {
@@ -331,19 +354,23 @@ private:
 PrimFunc LowerSharedGlobalCopy(PrimFunc f) {
   auto* n = f.CopyOnWrite();
-  // 1. Collect the info and locate the target injection point
+  // Collect the info
+  LOG(INFO) << "Starting LowerSharedGlobalCopy transformation...";
   auto res = CollectResources(n->body);
-  if (res.copies.empty()) return f;
+  if (res.copies.empty()){
+    LOG(INFO) << "No shared-global copy patterns detected. Skipping transformation.";
+    return f;
+  }
-  // [Core change]: 2. inject the LetStmt first!
-  // At this point n->body is still the original AST, so the res.inject_target pointer matches exactly.
+  LOG(INFO) << "Replaced " << res.copies.size() << " copy(s) with dcu_async_copy.";
+  // Inject the res declarations
   Stmt injected = ResourceInjector::Run(n->body, res.shared_alloc_to_binding, res.inject_target);
-  // 3. Replace the copy statements
-  // injected is a new AST wrapped in LetStmt, but the underlying BufferStore nodes are unchanged and can be replaced normally.
+  // Replace the copy statements
   Stmt replaced = StoreReplacer::Run(injected, res.copies, res.global_to_res_var);
-  // 4. Write back to the PrimFunc
+  // Write back
   n->body = std::move(replaced);
   return GetRef<PrimFunc>(n);
......
@@ -52,6 +52,7 @@ public:
   // The syncs inserted before each statement
   std::unordered_set<const Object *> syncs_inserted_;
+  std::unordered_set<const Object *> barrier_inserted_;

 protected:
   bool Enabled(const VarNode *buf, const StorageScope &scope) const final {
@@ -815,9 +816,9 @@ PrimFunc TileLangThreadSync(PrimFunc func, const std::string &storage_scope) {
   StorageScope sync_scope = StorageScope::Create(storage_scope);
   auto *n = func.CopyOnWrite();
   auto stmt = n->body;
-  if (sync_scope.rank == StorageRank::kShared && sync_scope.tag.empty()) {
-    stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
-  }
+  // if (sync_scope.rank == StorageRank::kShared && sync_scope.tag.empty()) {
+  //   stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
+  // }
   TileLangThreadSyncPlanner planner(sync_scope);
   for (const auto &[_, buffer] : func->buffer_map) {
     planner.SetBufferDataToBuffer(buffer->data, buffer);
......
@@ -3,6 +3,7 @@
 #include <tvm/tir/op.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
 #include <vector>
@@ -27,72 +28,95 @@ public:
 private:
   arith::Analyzer analyzer_;
   Var k_var_;
-  PrimExpr k_extent_;  // New: records the k-loop trip count
+  PrimExpr k_extent_;
+  bool in_unrolled_i_ = false;

-  // Generic stride extraction: pull the stride of var out of expr, returning the remaining base
   std::pair<PrimExpr, PrimExpr> ExtractStride(PrimExpr expr, Var var) {
-    if (!var.defined()) return {expr, make_zero(expr.dtype())};
+    if (!var.defined()) return {expr, make_zero(expr->dtype)};
     PrimExpr base = tvm::tir::Substitute(expr, {{var, make_zero(var.dtype())}});
     PrimExpr plus_one = tvm::tir::Substitute(expr, {{var, make_const(var.dtype(), 1)}});
     PrimExpr stride = analyzer_.Simplify(plus_one - base);
     return {analyzer_.Simplify(base), stride};
   }

+  // Core rewrite logic
+  Stmt RewriteAsyncCopy(const CallNode* call, Var i_var, PrimExpr i_extent) {
+    // 1. Preprocess: strip any RampNode to get the base offset
+    PrimExpr raw_dst_off = call->args[1];
+    PrimExpr raw_src_off = call->args[3];
+    if (const RampNode* r = raw_dst_off.as<RampNode>()) raw_dst_off = r->base;
+    if (const RampNode* r = raw_src_off.as<RampNode>()) raw_src_off = r->base;
+    // 2. Extract the stride of i
+    auto [base_i_dst, i_stride_dst] = ExtractStride(raw_dst_off, i_var);
+    auto [base_i_src, i_stride_src] = ExtractStride(raw_src_off, i_var);
+    // 3. Extract the stride of k (always attempted on base_i_src, with or without an i loop)
+    // As long as k_var_ is bound by an outer loop, this yields a non-zero stride
+    auto [final_src_offset, k_stride_src] = ExtractStride(base_i_src, k_var_);
+    // Build the originally requested 8 arguments:
+    // [dst, dst_off, src, src_off, i_extent, i_stride_dst, i_stride_src, k_stride_src]
+    Array<PrimExpr> new_args = {
+        call->args[0],     // dst_buf
+        base_i_dst,        // final dst offset
+        call->args[2],     // src_buf
+        final_src_offset,  // final src offset
+        i_extent,          // i trip count (0 when there is no loop)
+        i_stride_dst,      // dst stride of i (0 when there is no loop)
+        i_stride_src,      // src stride of i (0 when there is no loop)
+        k_stride_src       // src stride of k (available even without an i loop)
+    };
+    return Evaluate(Call(call->dtype, call->op, new_args));
+  }

+  // Handle calls not wrapped in a loop
+  Stmt VisitStmt_(const EvaluateNode* op) final {
+    if (!in_unrolled_i_) {
+      if (const CallNode* call = op->value.as<CallNode>()) {
+        static const Op& dcu_copy_op = Op::Get("tl.dcu_async_copy");
+        // Rewrite whenever the argument count is not 8 (our post-rewrite target count)
+        if (call->op.same_as(dcu_copy_op) && call->args.size() != 8) {
+          return RewriteAsyncCopy(call, Var(), make_zero(DataType::Int(32)));
+        }
+      }
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }

   Stmt VisitStmt_(const ForNode* op) final {
-    // 1. Record the k info
+    // Record the k info (assume k is the outer loop)
     bool is_k = (op->loop_var->name_hint == "k");
     if (is_k) {
       k_var_ = op->loop_var;
-      k_extent_ = op->extent;  // get the k trip count (e.g. 64)
+      k_extent_ = op->extent;
     }
-    // 2. Recursively visit the children
+    bool is_unrolled = (op->kind == ForKind::kUnrolled);
+    bool prev_in_unrolled = in_unrolled_i_;
+    if (is_unrolled) in_unrolled_i_ = true;
     Stmt body = this->VisitStmt(op->body);
+    in_unrolled_i_ = prev_in_unrolled;

-    // 3. Simplify the async copy
-    if (op->kind == ForKind::kUnrolled) {
+    if (is_unrolled) {
       if (const EvaluateNode* eval = body.as<EvaluateNode>()) {
         if (const CallNode* call = eval->value.as<CallNode>()) {
           static const Op& dcu_copy_op = Op::Get("tl.dcu_async_copy");
           if (call->op.same_as(dcu_copy_op)) {
-            // Restore k and perform the rewrite before returning
-            Var i_var = op->loop_var;
-            PrimExpr i_extent = op->extent;  // get the i trip count (e.g. 2)
-            auto get_i_info = [&](PrimExpr offset) {
-              if (const RampNode* ramp = offset.as<RampNode>()) {
-                auto [base, stride] = ExtractStride(ramp->base, i_var);
-                return std::make_pair(base, stride);
-              }
-              return ExtractStride(offset, i_var);
-            };
-            // Extract the stride of i
-            auto [base_dst, i_stride_dst] = get_i_info(call->args[1]);
-            auto [base_src, i_stride_src] = get_i_info(call->args[3]);
-            // Extract the stride of k (destructuring base_src further)
-            auto [final_src_offset, k_stride_src] = ExtractStride(base_src, k_var_);
-            // Build the new argument list, including the trip count
-            // Suggested order: [dst, dst_off, src, src_off, size, i_extent, i_stride_dst, i_stride_src, k_stride_src]
-            // size stays as-is (e.g. 8); i_extent is passed as 2
-            Array<PrimExpr> new_args = {
-                call->args[0],     // dst_buf
-                base_dst,          // base dst offset
-                call->args[2],     // src_buf
-                final_src_offset,  // base src offset
-                i_extent,          // i trip count
-                i_stride_dst,      // dst stride of i
-                i_stride_src,      // src stride of i
-                k_stride_src       // src stride of k
-            };
-            return Evaluate(Call(call->dtype, call->op, new_args));
+            Stmt result = RewriteAsyncCopy(call, op->loop_var, op->extent);
+            return result;
           }
         }
       }
     }
-    // Clear the k info when leaving the loop
     if (is_k) {
       k_var_ = Var();
       k_extent_ = PrimExpr();
......
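The substitute-and-subtract trick in `ExtractStride`, shown on a concrete affine index as a plain C++ sketch (the pass does the same thing symbolically via `tir::Substitute` and `Analyzer::Simplify`; the sample coefficients are mine):

```cpp
#include <cassert>
#include <utility>

// For offset(v) = base + stride * v, evaluating at v = 0 and v = 1 and
// subtracting recovers (base, stride), mirroring ExtractStride.
template <typename F>
std::pair<long, long> ExtractStride(F offset) {
  long base = offset(0);           // Substitute(expr, {var -> 0})
  long stride = offset(1) - base;  // Simplify(plus_one - base)
  return {base, stride};
}

int main() {
  // e.g. a source index linear in k: 64 * k + 128
  auto wrt_k = [](long k) { return 64 * k + 128; };
  auto [base, k_stride] = ExtractStride(wrt_k);
  assert(base == 128 && k_stride == 64);
  return 0;
}
```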
@@ -213,8 +213,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.IfStmtBinding()(mod)
     mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod)
+    print("OptimizeForTarget")
+    print(mod)
     mod = tilelang.transform.PipelinePlanning()(mod)
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
     mod = tilelang.transform.MergeIfStmt()(mod)
     if allow_fence_proxy(target=target):
         # in hopper device, wgmma is an async proxy
@@ -270,15 +273,15 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod)
     mod = tilelang.transform.ThreadSync("shared")(mod)
     mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
+    print("OptimizeForTarget2")
+    print(mod)
     # Inject PTX async copy must behind the thread sync pass
     # as ptx async copy won't be recognized as a valid buffer load
-    mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
+    if not dcu_async_copy_supported(target):
+        mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
     # Inject ds_read for shared to register memory copy on DCU
-    mod = tilelang.transform.InjectDSRead()(mod)
-    print("222222222")
-    print(mod)
+    # mod = tilelang.transform.InjectDSRead()(mod)
     if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
         mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
     mod = tilelang.transform.MakePackedAPI()(mod)
@@ -287,13 +290,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     # Transform threadblock to persistent threadblock
     mod = tilelang.transform.PersistThreadblock()(mod)
-    print("OptimizeForTarget")
-    print(mod)
     if dcu_async_copy_supported(target):
         mod = tilelang.transform.LowerSharedGlobalCopy()(mod)
-        print("OptimizeForTarget2")
-        print(mod)
-        mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
+        mod = tilelang.transform.FixDCUWaitCount()(mod)
+        #
+        # mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
     print("OptimizeForTarget3")
     print(mod)
     return mod
@@ -11,7 +11,7 @@ from tilelang.utils import is_fragment
 from .mfma_layout import (
     shared_16x4_to_local_64x1_layout_A,
     shared_4x16_to_local_64x1_layout_B,
-    shared_16x16_to_local_64x4_layout_A,
+    # shared_16x16_to_local_64x4_layout_A,
     shared_16x16_to_local_64x4_layout_B,
     shared_16x32_to_local_64x8_layout_A,
     shared_16x32_to_local_64x8_layout_B,
@@ -19,7 +19,7 @@ from .mfma_layout import (
     shared_16x64_to_local_64x16_layout_B,
     thread_id_shared_access_64x1_to_16x4_layout_A,
     thread_id_shared_access_64x1_to_4x16_layout_B,
-    thread_id_shared_access_64x4_to_16x16_layout_A,
+    # thread_id_shared_access_64x4_to_16x16_layout_A,
     thread_id_shared_access_64x4_to_16x16_layout_B,
     thread_id_shared_access_64x8_to_16x32_layout_A,
     thread_id_shared_access_64x8_to_16x32_layout_B,
@@ -27,10 +27,10 @@ from .mfma_layout import (
     thread_id_shared_access_64x16_to_16x64_layout_B,
 )

-# from .mmac_layout import (
-#     shared_16x16_to_local_64x4_layout_A,
-#     thread_id_shared_access_64x4_to_16x16_layout_A,
-# )
+from .mmac_layout import (
+    shared_16x16_to_local_64x4_layout_A,
+    thread_id_shared_access_64x4_to_16x16_layout_A,
+)

 lift = convert
@@ -251,6 +251,21 @@ class MatrixCoreIntrinEmitter:
         )
         return lane_id, warp_n, warp_m

+    def map_64x16(self, row, col, idx, warp_rows, tx):
+        new_col = col
+        if warp_rows > 1:
+            inter_idx_padding = 2
+        else:
+            inter_idx_padding = 1
+        paddings = inter_idx_padding * self.block_row_warps * 4
+        print("paddings:", paddings)
+        new_row = row + paddings * ((tx & 15) // 4)
+        new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * self.block_row_warps
+        return new_row, new_col
+
     def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
         # share mem a needs warp number
         warp_num = self.block_row_warps
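To make the new remapping checkable in isolation, here is `map_64x16` restated as standalone C++ (`block_row_warps` becomes an explicit parameter; the sample values in `main` are mine, not from the commit):

```cpp
#include <cassert>
#include <utility>

// Direct restatement of map_64x16 above: pad rows between the four
// tx-intervals of a wave and interleave the idx bits into the row.
std::pair<int, int> Map64x16(int row, int col, int idx, int warp_rows,
                             int tx, int block_row_warps) {
  int inter_idx_padding = (warp_rows > 1) ? 2 : 1;
  int paddings = inter_idx_padding * block_row_warps * 4;
  int new_row = row + paddings * ((tx & 15) / 4);
  new_row += (idx & 1) * (paddings / 2) + (idx / 2) * 16 * 2 * block_row_warps;
  return {new_row, col};  // the column is left unchanged
}

int main() {
  // warp_rows = 2, block_row_warps = 2 -> paddings = 16; tx = 5 falls in
  // interval 1, so row 0 maps to row 16 for idx = 0.
  auto [r, c] = Map64x16(0, 3, 0, 2, 5, 2);
  assert(r == 16 && c == 3);
  return 0;
}
```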
@@ -272,6 +287,7 @@ class MatrixCoreIntrinEmitter:
         A_buf = A_region.buffer
         A_base0 = A_region.region[-2].min
         A_base1 = A_region.region[-1].min
+        print("A_base0, A_base1:", A_base0, A_base1)

         @T.macro
         def _warp_ldmatrix_a(
@@ -281,31 +297,42 @@ class MatrixCoreIntrinEmitter:
             thread_binding,
             rk=0,
         ):
-            tx, _, warp_m = self.extract_thread_binding(thread_binding)
+            # warp_n[0-256] -> {0,1,2,3}
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
             # {0..3,16..19,32..35,48..51} -> 0
             # {4..7,20..23,36..39,52..55} -> 1
             # {8..11,24..27,40..43,56..59} -> 2
             # {12..15,28..31,44..47,60..63} -> 3
             warp_interval_idx = (tx & 15) >> 2
-            warp_group_idx = (tx // 32)
+            warp_group_idx = warp_n
             # warp_rows rounds: the warp is split across multiple rounds to cover the full row block
             if is_transposed:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         # row offset at the start of each round
-                        row += i * warp_row_init
-                        # row spacing between warp groups
-                        row += warp_group_idx * 4
-                        # row spacing within a warp
-                        row += warp_interval_idx * warp_row_interval
+                        # row += i * warp_row_init
+                        # # row spacing between warp groups
+                        # row += warp_group_idx * 4
+                        # # row spacing within a warp
+                        # row += warp_interval_idx * warp_row_interval
+                        raise NotImplementedError("Transposed A with preshuffle is not implemented yet")
+                        row, col = self.map_64x16(row, col, i, warp_rows, tx)
                         l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
                         A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
             else:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
-                        l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k))
+                        # # row offset at the start of each round
+                        # row += i * warp_row_init
+                        # # row spacing between warp groups
+                        # row += warp_group_idx * 4
+                        # # row spacing within a warp
+                        # row += warp_interval_idx * warp_row_interval
+                        row, col = self.map_64x16(row, col, i, warp_rows, tx)
+                        print("row, col:", row, col)
+                        l, r = (warp_m * 4, rk * chunk + ki * (k_pack * micro_size_k))
                         A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]

         return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
......
@@ -20,6 +20,10 @@ print("tileop gemm init...")
 def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range):
     print("tileop gemm infer_layout")
     thread_nums = thread_bounds.extent
+    print(f"gemm_py_infer_layout Target: {target}, thread_nums: {thread_nums}")
+    print(f"gemm_py_infer_layout gemm_py: {gemm_py}")
+    t = gemm_py.infer_layout(target, thread_nums)
+    print(f"gemm_py_infer_layout gemm_py.A: {gemm_py.A}, gemm_py.B: {gemm_py.B}, gemm_py.C: {gemm_py.C}")
     return gemm_py.infer_layout(target, thread_nums)
......
@@ -32,8 +32,10 @@ class GemmMMAC(GemmBase):
         if self.is_gemm_ss():
             return {
-                self.A: make_swizzled_layout(self.A),
-                self.B: make_swizzled_layout(self.B),
+                # self.A: make_swizzled_layout(self.A, allow_pad=False),
+                # self.B: make_swizzled_layout(self.B, allow_pad=False),
+                self.A: make_linear_layout(self.A),
+                self.B: make_linear_layout(self.B),
                 self.C: mmac_emitter.make_mmac_store_layout(self.C),
             }
         elif self.is_gemm_sr():
......
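For contrast, a generic illustration of the two layout families being swapped here. The XOR pattern below is one common bank-conflict-avoiding swizzle, not the actual `make_swizzled_layout` implementation:

```cpp
#include <cassert>

// A linear (row-major) shared-memory layout maps (i, j) -> i * cols + j.
int LinearIndex(int i, int j, int cols) { return i * cols + j; }

// A typical swizzle XORs low row bits into the column to spread accesses
// across banks; the real make_swizzled_layout may differ.
int SwizzledIndex(int i, int j, int cols) { return i * cols + (j ^ (i & 7)); }

int main() {
  assert(LinearIndex(2, 3, 16) == 35);
  assert(SwizzledIndex(2, 3, 16) == 33);  // column 3 ^ 2 = 1 -> 2 * 16 + 1
  return 0;
}
```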
@@ -559,4 +559,8 @@ def LowerSharedGlobalCopy():
 def SimplifyDCUAsyncCopy():
     """SimplifyDCUAsyncCopy"""
-    return _ffi_api.SimplifyDCUAsyncCopy()  # type: ignore
\ No newline at end of file
+    return _ffi_api.SimplifyDCUAsyncCopy()  # type: ignore
+
+
+def FixDCUWaitCount():
+    """FixDCUWaitCount"""
+    return _ffi_api.FixDCUWaitCount()  # type: ignore
\ No newline at end of file