Commit dd91b1e0 authored by qisan's avatar qisan
Browse files

Feat: vectorize async copy

parent 41887aed
......@@ -10,7 +10,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.fl
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=512) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......
......@@ -389,7 +389,7 @@ TIR_DEFINE_TL_BUILTIN(__ldg).set_num_inputs(-1).set_attr<TCallEffectKind>(
//
TIR_DEFINE_TL_BUILTIN(dcu_async_copy)
.set_num_inputs(6)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
......
......@@ -793,8 +793,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
<< ", " << condition << ");\n";
}
} else if (op->op.same_as(builtin::ptx_commit_group())) {
printf("[DEBUG VisitExpr_] Branch: ptx_commit_group\n");
print_extern_call_stmt("tl::cp_async_commit");
;
// print_extern_call_stmt("tl::cp_async_commit");
} else if (op->op.same_as(builtin::ptx_wait_group())) {
printf("[DEBUG VisitExpr_] Branch: ptx_wait_group\n");
int n = Downcast<IntImm>(op->args[0])->value;
......@@ -1103,42 +1103,51 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
}
else if (op->op.same_as(Op::Get("tl.dcu_async_copy"))) {
// 1. 提取模板参数 (IntImm 直接取值)
auto get_base_expr = [this](const PrimExpr& e) -> std::string {
if (const auto* ramp = e.as<tvm::tir::RampNode>()) {
// 如果是 Ramp,只打印它的起始位置 (base)
return this->PrintExpr(ramp->base);
}
// 否则正常打印
return this->PrintExpr(e);
};
// 辅助函数:尝试获取整数常量
auto get_int_const = [](const PrimExpr& e) -> int {
if (const auto* val = e.as<IntImmNode>()) return static_cast<int>(val->value);
return 0;
};
int N = 16;
int smem_offset = 0;
int load_count = get_int_const(op->args[4]);
int i_sstride = get_int_const(op->args[5]);
int i_gstride = get_int_const(op->args[6]);
int k_gstride = get_int_const(op->args[7]);
// 1. 静态模板参数 (按要求仅保留 N 和 smem_offset)
int N = 16;
// 2. 将运行时参数打印到字符串中 (防止直接操作 stream 导致冲突)
std::string dst_ptr = this->PrintExpr(op->args[0]);
std::string dst_off = this->PrintExpr(op->args[1]);
std::string src_res = this->PrintExpr(op->args[2]);
std::string src_off = this->PrintExpr(op->args[3]);
// 2. 解析 IR 参数
// args[0]: dst_ptr (buf_dyn_shmem)
// args[1]: dst_ramp (T.Ramp...)
// args[2]: src_res (A_dcu_res)
// args[3]: src_ramp (T.Ramp...)
// args[4]: load_count (1)
std::string dst_ptr = this->PrintExpr(op->args[0]);
// 使用新定义的 get_base_expr 避开 lanes > 4 的检查
std::string dst_off = get_base_expr(op->args[1]);
std::string src_res = this->PrintExpr(op->args[2]);
std::string src_off = get_base_expr(op->args[3]);
// 3. 仿照范例进行流输出
// 3. 生成输出
this->PrintIndent();
// 模板参数仅保留 N, smem_offset 和动态提取的 load_count
this->stream << "tl::cp_async_gs<"
<< N << ", "
<< smem_offset << ", "
<< load_count << ", "
<< i_sstride << ", "
<< i_gstride << ", "
<< k_gstride << ">(";
<< N << ">(";
// 拼接第一个参数:(char*)dst + dst_off
// 打印函数参数
// 处理目标地址: ((char*)ptr + offset)
this->stream << "((char*)" << dst_ptr << " + " << dst_off << "), ";
// 拼接第二个参数:src_res
// 打印源资源指针
this->stream << src_res << ", ";
// 拼接第三个参数:src_off
// 打印源偏移
this->stream << src_off << ");\n";
}
else {
......
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/analysis.h>
using namespace tvm::tir;
using tvm::ffi::GetRef;
using tvm::ffi::make_object;
namespace tvm {
namespace tl {
using namespace tir;
using ffi::Array;
using ffi::String;
class ROCmWaitCountRewriter : public StmtMutator {
public:
  // Entry point: rewrite every async-wait inflight count inside `stmt`.
  static Stmt Substitute(Stmt stmt) {
    return ROCmWaitCountRewriter()(stmt);
  }

private:
  // Counts how many async copy instructions one execution of `stmt` issues.
  // A nested fixed-extent loop multiplies the count of its body; a loop with
  // a non-constant extent (rare inside a pipeline) is conservatively treated
  // as a single iteration.
  int CountAsyncOps(const Stmt& stmt) {
    struct Visitor : public StmtExprVisitor {
      int count = 0;
      void VisitStmt_(const ForNode* op) override {
        // Count the loop body in isolation, then scale it by the trip count
        // and add it back to whatever was counted outside this loop.
        int outer_count = count;
        count = 0;
        StmtExprVisitor::VisitStmt_(op);
        int trip_count = 1;  // fallback for non-constant extents
        if (const auto* extent = op->extent.as<IntImmNode>()) {
          trip_count = static_cast<int>(extent->value);
        }
        count = outer_count + count * trip_count;
      }
      void VisitExpr_(const CallNode* op) override {
        // Recognize both the generic ptx_cp_async builtin and the
        // DCU-specific async copy intrinsic.
        if (op->op.same_as(builtin::ptx_cp_async()) ||
            op->op.same_as(Op::Get("tl.dcu_async_copy"))) {
          count++;
        }
        StmtExprVisitor::VisitExpr_(op);
      }
    } visitor;
    visitor(stmt);
    return visitor.count;
  }

  Stmt VisitStmt_(const ForNode* op) override {
    // Treat each loop that issues async copies as a pipeline scope: the wait
    // counts inside it are expressed in *groups* and must be scaled by the
    // number of async ops issued per iteration.
    int ops_per_iter = CountAsyncOps(op->body);
    // No async operations below this loop: nothing to rewrite here.
    if (ops_per_iter == 0) return StmtMutator::VisitStmt_(op);
    // Save/restore the multiplier so sibling loops are unaffected.
    int old_multiplier = multiplier_;
    multiplier_ = ops_per_iter;
    Stmt new_body = this->VisitStmt(op->body);
    multiplier_ = old_multiplier;
    if (new_body.same_as(op->body)) return GetRef<Stmt>(op);
    auto n = CopyOnWrite(op);
    n->body = std::move(new_body);
    return Stmt(n);
  }

  Stmt VisitStmt_(const AttrStmtNode* op) override {
    if (op->attr_key == "async_wait_inflight_count" && multiplier_ > 0) {
      // The original value is a number of wait *groups* (e.g. 1).
      if (auto int_imm = op->value.as<IntImmNode>()) {
        // ROCm counts individual outstanding instructions, not groups:
        // new_count = N_groups * ops_per_group.
        int64_t new_count = int_imm->value * multiplier_;
        // Keep visiting the body so nested wait attributes (if any) are
        // rewritten as well — the previous version returned op->body as-is.
        Stmt body = this->VisitStmt(op->body);
        return AttrStmt(op->node, op->attr_key,
                        make_const(DataType::Int(32), new_count), body);
      }
    }
    return StmtMutator::VisitStmt_(op);
  }

  int multiplier_ = 0;  // async ops issued per iteration in the current scope
};
// 包装成标准的 TVM Pass
namespace transform {
using namespace tir::transform;
// Wrap the rewriter as a standard TVM PrimFunc pass.
tvm::transform::Pass FixDCUWaitCount() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) -> PrimFunc {
    auto* func_node = f.CopyOnWrite();
    func_node->body = ROCmWaitCountRewriter::Substitute(std::move(func_node->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "FixDCUWaitCount", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
tvm::ffi::reflection::GlobalDef().def("tl.transform.FixDCUWaitCount", FixDCUWaitCount);
}
} // namespace transform
} // namespace tl
} // namespace tvm
\ No newline at end of file
......@@ -148,6 +148,7 @@ private:
new_args.Set(i + 1, new_index);
}
}
LOG(INFO) << "Rewriting buffer access " << call << " to " << Call(call->dtype, call->op, new_args, call->span);
return Call(call->dtype, call->op, new_args, call->span);
}
......@@ -166,6 +167,7 @@ private:
for (const Buffer &alloc_buffer : op->alloc_buffers) {
buffer_data_to_buffer_.erase(alloc_buffer->data);
}
LOG(INFO) << "Rewriting block " << GetRef<Block>(op) << " to " << GetRef<Block>(n);
return block;
}
......@@ -309,6 +311,7 @@ public:
}
Block block = MakeBlock(stmt, buffer_data_to_buffer_);
block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
LOG(INFO) << "Final rewritten pipeline block: " << block;
return BlockRealize({}, Bool(true), block);
}
......@@ -631,6 +634,9 @@ private:
n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
AttrStmt(zero, tir::attr::async_wait_inflight_count,
pw.wait_count, n->body));
LOG(INFO) << "Inserting async_wait with count " << pw.wait_count
<< " before block with order " << new_blocks[pw.insert_before].order
<< " for async stage " << stage_id;
}
}
......@@ -1102,6 +1108,7 @@ private:
buffer_data_to_buffer_.erase(buffer->data);
}
}
LOG(INFO) << "Finished rewriting the pipeline loop with body:\n" << pipeline;
return pipeline;
}
......@@ -1121,6 +1128,7 @@ private:
for (const auto &buffer : op->alloc_buffers) {
buffer_data_to_buffer_.erase(buffer->data);
}
LOG(INFO) << "Rewriting blockddd " << block;
return block;
}
......@@ -1158,6 +1166,8 @@ tir::transform::Pass InjectSoftwarePipeline() {
auto *fptr = f.CopyOnWrite();
fptr->body = software_pipeline::PipelineInjector::Inject(f);
fptr->body = ConvertSSA(std::move(fptr->body));
LOG(INFO) << "Finished injecting software pipeline for PrimFunc " << f->GetAttr<String>(tvm::attr::kGlobalSymbol).value_or("<unknown>")
<< ", the transformed body is:\n" << fptr->body;
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
......
......@@ -94,6 +94,7 @@ CollectResult CollectResources(const Stmt& body) {
CollectResult result;
private:
bool in_async{false};
std::unordered_set<const tvm::tir::VarNode*> loop_vars_;
std::vector<const tvm::tir::StmtNode*> scope_stack_; // 追踪当前遍历的 AST 路径
bool IsSharedScope(const Buffer& buf) {
......@@ -105,27 +106,36 @@ CollectResult CollectResources(const Stmt& body) {
return s == "global" || s == "";
}
void VisitStmt_(const AttrStmtNode* op) override {
scope_stack_.push_back(op);
if (op->attr_key == tvm::tir::attr::thread_extent) {
void VisitStmt_(const AttrStmtNode* attr) override {
scope_stack_.push_back(attr);
if (attr->attr_key == tir::attr::thread_extent) {
// 1. 获取 IterVar
auto iv = op->node.as<tvm::tir::IterVarNode>();
auto iv = attr->node.as<tvm::tir::IterVarNode>();
const std::string& tag = iv->thread_tag;
// 2. 只有当 tag 包含 "threadIdx" 时才加入 (过滤掉 blockIdx)
// 比如: "threadIdx.x", "threadIdx.y", "threadIdx.z"
if (tag.find("threadIdx") != std::string::npos) {
tvm::tir::Var thread_var = iv->var;
LOG(INFO) << "Entering thread scope: " << tag << " with var " << thread_var->name_hint;
loop_vars_.insert(thread_var.get());
StmtExprVisitor::VisitStmt_(op);
StmtExprVisitor::VisitStmt_(attr);
loop_vars_.erase(thread_var.get());
} else {
// 如果是 blockIdx 或其他,直接跳过当前层继续往下走
StmtExprVisitor::VisitStmt_(op);
StmtExprVisitor::VisitStmt_(attr);
}
} else if (attr->attr_key == tir::attr::async_scope) {
ICHECK(in_async == false) << "Nested async scopes not supported";
in_async = true;
StmtExprVisitor::VisitStmt_(attr);
in_async = false;
}
else {
StmtExprVisitor::VisitStmt_(attr);
}
scope_stack_.pop_back();
}
......@@ -145,14 +155,15 @@ CollectResult CollectResources(const Stmt& body) {
}
void VisitStmt_(const BufferStoreNode* op) final {
LOG(INFO) << "Visiting BufferStore: " << op->buffer->name;
Buffer dst = op->buffer;
if (IsSharedScope(dst) && op->value.defined()) {
if (IsSharedScope(dst) && op->value.defined() && in_async) {
if (const auto* load = op->value.as<BufferLoadNode>()) {
Buffer src = load->buffer;
if (IsGlobalScope(src)) {
const StmtNode* target = op;
if (result.inject_target == nullptr) {
// 从下往上回溯栈,寻找最内层的 thread_extent
for (int i = scope_stack_.size() - 1; i >= 0; --i) {
if (scope_stack_[i]->IsInstance<AttrStmtNode>()) {
auto attr = static_cast<const AttrStmtNode*>(scope_stack_[i]);
......@@ -198,10 +209,10 @@ CollectResult CollectResources(const Stmt& body) {
VariableEliminator eliminator(loop_vars_);
tvm::arith::Analyzer analyzer;
Array<PrimExpr> base_indices;
LOG(INFO) << loop_vars_.size() << " loop vars in context.";
for (const auto* var : loop_vars_) {
LOG(INFO) << "Loop Var: " << var->name_hint;
}
LOG(INFO) << loop_vars_.size() << " loop vars in context.";
for (const auto* var : loop_vars_) {
LOG(INFO) << "Loop Var: " << var->name_hint;
}
for (const auto& idx : load->indices) {
// 将所有外层循环变量 (k, i 等) 全部替换为 0
PrimExpr no_loops = eliminator(idx);
......@@ -225,15 +236,18 @@ CollectResult CollectResources(const Stmt& body) {
// 将这个绑定关系和 destination 的 shared buffer 绑死
result.shared_alloc_to_binding[src->name] = {var, val};
}
LOG(INFO) << "result.copies.size() = " << result.copies.size();
}
}
}
StmtExprVisitor::VisitStmt_(op);
}
};
LOG(INFO) << "Starting resource collection...";
Collector col;
col(body);
LOG(INFO) << "Finished resource collection. Found " << col.result.copies.size() << " copy(s).";
return col.result;
}
......@@ -253,6 +267,15 @@ private:
const std::unordered_map<String, Var>& global_to_var)
: copies_(copies), global_to_var_(global_to_var) {}
Stmt VisitStmt_(const AttrStmtNode *attr) override {
  // Drop async_scope annotations, keeping only their (rewritten) body.
  // NOTE(review): this assumes the copies the scope guarded are replaced by
  // explicit dcu_async_copy calls elsewhere in this mutator — confirm.
  if (attr->attr_key == tir::attr::async_scope) {
    return this->VisitStmt(attr->body);
  }
  // Any other attribute is preserved by the default mutation.
  return StmtMutator::VisitStmt_(attr);
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
for (const auto& copy : copies_) {
if (copy.store_stmt.same_as(GetRef<Stmt>(op))) {
......@@ -331,19 +354,23 @@ private:
PrimFunc LowerSharedGlobalCopy(PrimFunc f) {
auto* n = f.CopyOnWrite();
// 1. 收集信息并定位目标注入点
// 收集信息
LOG(INFO) << "Starting LowerSharedGlobalCopy transformation...";
auto res = CollectResources(n->body);
if (res.copies.empty()) return f;
if (res.copies.empty()){
LOG(INFO) << "No shared-global copy patterns detected. Skipping transformation.";
return f;
}
// 【核心修改】:2. 先注入 LetStmt!
// 此时使用的 n->body 是原始 AST,res.inject_target 指针百分之百匹配。
LOG(INFO) << "Replaced " << res.copies.size() << " copy(s) with dcu_async_copy.";
// 注入res声明
Stmt injected = ResourceInjector::Run(n->body, res.shared_alloc_to_binding, res.inject_target);
// 3. 替换拷贝语句
// injected 是套了 LetStmt 的新 AST,但底层的 BufferStore 还是原来的,可以被正常替换。
// 替换拷贝语句
Stmt replaced = StoreReplacer::Run(injected, res.copies, res.global_to_res_var);
// 4. 写回 PrimFunc
// 写回
n->body = std::move(replaced);
return GetRef<PrimFunc>(n);
......
......@@ -52,6 +52,7 @@ public:
// The syncs inserted before each statement
std::unordered_set<const Object *> syncs_inserted_;
std::unordered_set<const Object *> barrier_inserted_;
protected:
bool Enabled(const VarNode *buf, const StorageScope &scope) const final {
......@@ -815,9 +816,9 @@ PrimFunc TileLangThreadSync(PrimFunc func, const std::string &storage_scope) {
StorageScope sync_scope = StorageScope::Create(storage_scope);
auto *n = func.CopyOnWrite();
auto stmt = n->body;
if (sync_scope.rank == StorageRank::kShared && sync_scope.tag.empty()) {
stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
}
// if (sync_scope.rank == StorageRank::kShared && sync_scope.tag.empty()) {
// stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
// }
TileLangThreadSyncPlanner planner(sync_scope);
for (const auto &[_, buffer] : func->buffer_map) {
planner.SetBufferDataToBuffer(buffer->data, buffer);
......
......@@ -3,6 +3,7 @@
#include <tvm/tir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <vector>
......@@ -27,72 +28,95 @@ public:
private:
arith::Analyzer analyzer_;
Var k_var_;
PrimExpr k_extent_; // 新增:记录 k 循环的次数
PrimExpr k_extent_;
bool in_unrolled_i_ = false;
// 通用的步长提取函数:从 expr 中提取指定 var 的步长,并返回剩余的 base
std::pair<PrimExpr, PrimExpr> ExtractStride(PrimExpr expr, Var var) {
if (!var.defined()) return {expr, make_zero(expr.dtype())};
if (!var.defined()) return {expr, make_zero(expr->dtype)};
PrimExpr base = tvm::tir::Substitute(expr, {{var, make_zero(var.dtype())}});
PrimExpr plus_one = tvm::tir::Substitute(expr, {{var, make_const(var.dtype(), 1)}});
PrimExpr stride = analyzer_.Simplify(plus_one - base);
return {analyzer_.Simplify(base), stride};
}
// Core rewrite: normalize a tl.dcu_async_copy call into the 8-argument form
// by factoring the destination/source offsets into a base plus per-loop
// strides for the i (unroll) and k (outer pipeline) loop variables.
Stmt RewriteAsyncCopy(const CallNode* call, Var i_var, PrimExpr i_extent) {
// 1. Preprocessing: peel off RampNode wrappers to obtain scalar base offsets.
PrimExpr raw_dst_off = call->args[1];
PrimExpr raw_src_off = call->args[3];
if (const RampNode* r = raw_dst_off.as<RampNode>()) raw_dst_off = r->base;
if (const RampNode* r = raw_src_off.as<RampNode>()) raw_src_off = r->base;
// 2. Extract the stride of the i loop variable (ExtractStride returns a
// zero stride when i_var is undefined, i.e. there is no i loop).
auto [base_i_dst, i_stride_dst] = ExtractStride(raw_dst_off, i_var);
auto [base_i_src, i_stride_src] = ExtractStride(raw_src_off, i_var);
// 3. Extract the k stride (always attempted on base_i_src, whether or not
// an i loop exists). As long as k_var_ was bound by an enclosing loop this
// yields a non-zero stride.
auto [final_src_offset, k_stride_src] = ExtractStride(base_i_src, k_var_);
// Build the required 8 arguments:
// [dst, dst_off, src, src_off, i_extent, i_stride_dst, i_stride_src, k_stride_src]
Array<PrimExpr> new_args = {
call->args[0], // dst_buf
base_i_dst, // final dst offset
call->args[2], // src_buf
final_src_offset, // final src offset
i_extent, // i trip count (0 when there is no i loop)
i_stride_dst, // dst stride of i (0 when there is no i loop)
i_stride_src, // src stride of i (0 when there is no i loop)
k_stride_src // src stride of k (available even without an i loop)
};
return Evaluate(Call(call->dtype, call->op, new_args));
}
// Handle async copies that are NOT wrapped in an unrolled i loop.
Stmt VisitStmt_(const EvaluateNode* op) final {
if (!in_unrolled_i_) {
if (const CallNode* call = op->value.as<CallNode>()) {
static const Op& dcu_copy_op = Op::Get("tl.dcu_async_copy");
// Only rewrite calls that are not yet in the 8-argument target form,
// so already-rewritten calls are left untouched on revisits.
if (call->op.same_as(dcu_copy_op) && call->args.size() != 8) {
// No i loop here: pass an undefined Var and a zero extent.
return RewriteAsyncCopy(call, Var(), make_zero(DataType::Int(32)));
}
}
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const ForNode* op) final {
// 1. 记录 k 信息
// 记录 k 信息 (假设 k 在外层)
bool is_k = (op->loop_var->name_hint == "k");
if (is_k) {
k_var_ = op->loop_var;
k_extent_ = op->extent; // 获取 k 的循环次数 (如 64)
k_extent_ = op->extent;
}
// 2. 递归访问子节点
bool is_unrolled = (op->kind == ForKind::kUnrolled);
bool prev_in_unrolled = in_unrolled_i_;
if (is_unrolled) in_unrolled_i_ = true;
Stmt body = this->VisitStmt(op->body);
in_unrolled_i_ = prev_in_unrolled;
// 3. 处理 Async Copy 简化
if (op->kind == ForKind::kUnrolled) {
if (is_unrolled) {
if (const EvaluateNode* eval = body.as<EvaluateNode>()) {
if (const CallNode* call = eval->value.as<CallNode>()) {
static const Op& dcu_copy_op = Op::Get("tl.dcu_async_copy");
if (call->op.same_as(dcu_copy_op)) {
Var i_var = op->loop_var;
PrimExpr i_extent = op->extent; // 获取 i 的循环次数 (如 2)
auto get_i_info = [&](PrimExpr offset) {
if (const RampNode* ramp = offset.as<RampNode>()) {
auto [base, stride] = ExtractStride(ramp->base, i_var);
return std::make_pair(base, stride);
}
return ExtractStride(offset, i_var);
};
// 提取 i 的步长
auto [base_dst, i_stride_dst] = get_i_info(call->args[1]);
auto [base_src, i_stride_src] = get_i_info(call->args[3]);
// 提取 k 的步长 (从 base_src 继续解构)
auto [final_src_offset, k_stride_src] = ExtractStride(base_src, k_var_);
// 构造新的参数列表,包含循环次数
// 建议参数顺序:[dst, dst_off, src, src_off, size, i_extent, i_stride_dst, i_stride_src, k_stride_src]
// 这里的 size 保持原样 (如 8),i_extent 传入 2
Array<PrimExpr> new_args = {
call->args[0], // dst_buf
base_dst, // 基础 dst 偏移
call->args[2], // src_buf
final_src_offset, // 基础 src 偏移
i_extent, // i 循环次数
i_stride_dst, // i 的 dst 步长
i_stride_src, // i 的 src 步长
k_stride_src // k 的 src 步长
};
return Evaluate(Call(call->dtype, call->op, new_args));
// 还原 k 并在返回前处理重写
Stmt result = RewriteAsyncCopy(call, op->loop_var, op->extent);
return result;
}
}
}
}
// 退出循环时清理 k 信息
if (is_k) {
k_var_ = Var();
k_extent_ = PrimExpr();
......
......@@ -213,8 +213,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.IfStmtBinding()(mod)
mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod)
print("OptimizeForTarget")
print(mod)
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
if allow_fence_proxy(target=target):
# in hopper device, wgmma is an async proxy
......@@ -270,15 +273,15 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
print("OptimizeForTarget2")
print(mod)
# Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
if not dcu_async_copy_supported(target):
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
# Inject ds_read for shared to register memory copy on DCU
mod = tilelang.transform.InjectDSRead()(mod)
print("222222222")
print(mod)
# mod = tilelang.transform.InjectDSRead()(mod)
if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
mod = tilelang.transform.MakePackedAPI()(mod)
......@@ -287,13 +290,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Transform threadblock to persistent threadblock
mod = tilelang.transform.PersistThreadblock()(mod)
print("OptimizeForTarget")
print(mod)
if dcu_async_copy_supported(target):
mod = tilelang.transform.LowerSharedGlobalCopy()(mod)
print("OptimizeForTarget2")
print(mod)
mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
mod = tilelang.transform.FixDCUWaitCount()(mod)
#
# mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
print("OptimizeForTarget3")
print(mod)
return mod
......@@ -11,7 +11,7 @@ from tilelang.utils import is_fragment
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A,
# shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
......@@ -19,7 +19,7 @@ from .mfma_layout import (
shared_16x64_to_local_64x16_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
# thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
......@@ -27,10 +27,10 @@ from .mfma_layout import (
thread_id_shared_access_64x16_to_16x64_layout_B,
)
# from .mmac_layout import (
# shared_16x16_to_local_64x4_layout_A,
# thread_id_shared_access_64x4_to_16x16_layout_A,
# )
from .mmac_layout import (
shared_16x16_to_local_64x4_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_A,
)
lift = convert
......@@ -251,6 +251,21 @@ class MatrixCoreIntrinEmitter:
)
return lane_id, warp_n, warp_m
def map_64x16(self, row, col, idx, warp_rows, tx):
    """Remap a (row, col) coordinate for the 64x16 shared-memory layout.

    Parameters
    ----------
    row, col : element coordinates produced by the reverse index map.
    idx : index of the warp-rows round.
    warp_rows : number of rounds a warp splits the row block into.
    tx : lane/thread id within the block.

    Returns
    -------
    (new_row, new_col) : remapped coordinates; the column is unchanged.
    """
    # Padding between interleaved rounds doubles when a warp needs more
    # than one round to cover the row block.
    inter_idx_padding = 2 if warp_rows > 1 else 1
    paddings = inter_idx_padding * self.block_row_warps * 4
    # Spread rows by the lane's 4-lane group within the lower 16 lanes.
    # (Removed leftover debug print of `paddings`.)
    new_row = row + paddings * ((tx & 15) // 4)
    # Odd rounds shift by half a padding stride; each pair of rounds
    # advances by a full 16x2 row block per warp row.
    new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * self.block_row_warps
    return new_row, col
def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
# share mem a needs warp number
warp_num = self.block_row_warps
......@@ -272,6 +287,7 @@ class MatrixCoreIntrinEmitter:
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
print("A_base0, A_base1:", A_base0, A_base1)
@T.macro
def _warp_ldmatrix_a(
......@@ -281,31 +297,42 @@ class MatrixCoreIntrinEmitter:
thread_binding,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_binding)
# warp_n[0-256] -> {0,1,2,3}
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
# {0..3,16..19,32..35,48..51} -> 0
# {4..7,20..23,36..39,52..55} -> 1
# {8..11,24..27,40..43,56..59} -> 2
# {12..15,28..31,44..47,60..63} -> 3
warp_interval_idx = (tx & 15)>>2
warp_group_idx = (tx // 32)
warp_group_idx = warp_n
# warp_rows 轮次 warp需要拆分成多轮来访问完整的行块
if is_transposed:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
# 每轮初始位置行偏移
row += i * warp_row_init
# warp 组行间隔
row += warp_group_idx * 4
# warp 内行间隔
row += warp_interval_idx * warp_row_interval
# row += i * warp_row_init
# # warp 组行间隔
# row += warp_group_idx * 4
# # warp 内行间隔
# row += warp_interval_idx * warp_row_interval
raise NotImplementedError("Transposed A with preshuffle is not implemented yet")
row, col = self.map_64x16(row, col, i, warp_rows, tx)
l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
else:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k))
# # 每轮初始位置行偏移
# row += i * warp_row_init
# # warp 组行间隔
# row += warp_group_idx * 4
# # warp 内行间隔
# row += warp_interval_idx * warp_row_interval
row, col = self.map_64x16(row, col, i, warp_rows, tx)
print("row, col:", row, col)
l, r = (warp_m * 4, rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
......
......@@ -20,6 +20,10 @@ print("tileop gemm init...")
def gemm_py_infer_layout(gemm_py: "GemmMMA", target: "Target", thread_bounds: "Range"):
    """Infer the buffer layout for a Python-defined GEMM tile operator.

    Parameters
    ----------
    gemm_py : the GemmMMA tile operator whose layout is inferred.
    target : the compilation target.
    thread_bounds : thread index range; only its extent (thread count) is used.

    Returns
    -------
    The layout mapping produced by ``gemm_py.infer_layout``.
    """
    thread_nums = thread_bounds.extent
    # Infer exactly once and return the result: the previous version called
    # infer_layout twice (discarding the first result) and left debug prints.
    return gemm_py.infer_layout(target, thread_nums)
......
......@@ -32,8 +32,10 @@ class GemmMMAC(GemmBase):
if self.is_gemm_ss():
return {
self.A: make_swizzled_layout(self.A),
self.B: make_swizzled_layout(self.B),
# self.A: make_swizzled_layout(self.A, allow_pad=False),
# self.B: make_swizzled_layout(self.B, allow_pad=False),
self.A: make_linear_layout(self.A),
self.B: make_linear_layout(self.B),
self.C: mmac_emitter.make_mmac_store_layout(self.C),
}
elif self.is_gemm_sr():
......
......@@ -559,4 +559,8 @@ def LowerSharedGlobalCopy():
def SimplifyDCUAsyncCopy():
"""SimplifyDCUAsyncCopy"""
return _ffi_api.SimplifyDCUAsyncCopy() # type: ignore
\ No newline at end of file
return _ffi_api.SimplifyDCUAsyncCopy() # type: ignore
def FixDCUWaitCount():
    """Return the TIR pass that rescales async-wait inflight counts for DCU.

    Thin wrapper over the FFI-registered ``tl.transform.FixDCUWaitCount`` pass.
    """
    return _ffi_api.FixDCUWaitCount() # type: ignore
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment