Commit b14f201e authored by qisan's avatar qisan
Browse files

Feat: Add async copy, pipeline and ds_read support

parent 44cc93c7
......@@ -827,14 +827,12 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
}else if(op->op.same_as(tl::ds_read_vector())){
//ds_read_b64 %1, %2 offset:%3
// ds_read_m32x16_b16 %0, %1 offset:%2
std::string dst = this->PrintExpr(op->args[0]);
std::string local_offset = this->PrintExpr(op->args[1]);
std::string lds_offset = this->PrintExpr(op->args[2]);
os << "tl::ds_read_vector("
os << "tl::ds_read_vector(*(float4_ *)("
<< dst << " + " << local_offset
<< ", "
<< "), "
<< lds_offset
<< ")";
}else if (op->op.same_as(tl::wait_wgmma())) {
......
......@@ -133,12 +133,11 @@ TL_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr,
// }
// }
// Read a 128-bit vector from LDS (shared memory) into `dst` using the
// ds_read_m32x16_b16 instruction with a fixed immediate offset of 0.
// `lds_base_ptr` is the byte address within LDS to read from.
// NOTE(review): "+v" (read-write) is kept rather than "=v" — confirm
// whether the surrounding code relies on dst's prior value before the
// implicit lgkmcnt wait completes.
TL_DEVICE void ds_read_vector(float4_& dst, uint32_t lds_base_ptr)
{
  // The "memory" clobber is required: without it the compiler is free to
  // reorder this LDS read past the stores that populated shared memory.
  asm volatile("ds_read_m32x16_b16 %0, %1 offset:0\n\t"
               : "+v"(dst)
               : "v"(lds_base_ptr)
               : "memory");
}
// template <int M, int N, int offset>
......
......@@ -57,13 +57,15 @@ class BLocalLayoutTransformer : public StmtExprMutator {
int expand_;
Stmt VisitStmt_(const ForNode* op) final {
// 只处理 serial 外层循环
// 1. 先递归处理子节点(重要:确保处理了嵌套的 For 或 Attr)
Stmt new_body = this->VisitStmt(op->body);
// 2. 检查当前循环是否是目标循环
// 即使 body 变了,我们也尝试看看能不能在这个 loop 层级做变换
auto store = new_body.as<BufferStoreNode>();
if (op->kind != ForKind::kSerial) {
return StmtExprMutator::VisitStmt_(op);
}
// 判断是否是 B_local 写循环
auto store = op->body.as<BufferStoreNode>();
if (!store) {
return StmtExprMutator::VisitStmt_(op);
}
......@@ -79,7 +81,6 @@ class BLocalLayoutTransformer : public StmtExprMutator {
int64_t new_extent = old_extent / expand_;
// 修改循环范围
For new_for =
For(op->loop_var,
op->min,
......@@ -94,100 +95,42 @@ class BLocalLayoutTransformer : public StmtExprMutator {
std::string name = buffer->name;
return name.find("B_local") != std::string::npos;
}
Stmt MutateStore(const BufferStoreNode* store,
const Var& loop_var) {
Array<PrimExpr> new_indices = store->indices;
PrimExpr new_value = store->value;
// 修改切片跨度:
// 原来 j*vec : j*vec+vec
// 改为 j*vec : j*vec*expand + vec
PrimExpr idx = store->indices[0];
//T.Ramp(j * 4, 1, 4) -> Ramp(j*8, 1, 4)
std::cout << idx << std::endl;
// 解析 j*vec 结构
// 假设结构为 j * vec + const
// 不改 RHS
// PrimExpr value = store->value;
// 修改写入向量宽度
// 原 value 是 Ramp(base=j*4, stride=1, lanes=4)
// 匹配 j * stride
// Ramp(base=j*8, stride=1, lanes=8)
if (const auto* ramp = idx.as<RampNode>()) {
PrimExpr base = ramp->base;
PrimExpr stride = ramp->stride;
int old_lanes = ramp->lanes.as<IntImmNode>()->value;
int new_lanes = old_lanes * expand_;
// 匹配 base = j * stride_val
if (const auto* mul = base.as<MulNode>()) {
if (mul->a.same_as(loop_var)) {
int64_t old_stride =
mul->b.as<IntImmNode>()->value;
int64_t new_stride =
old_stride * expand_;
PrimExpr new_base =
loop_var *
make_const(DataType::Int(32), new_stride);
new_indices.Set(
0,
Ramp(new_base, stride, new_lanes));
}
else if (mul->b.same_as(loop_var)) {
int64_t old_stride =
mul->a.as<IntImmNode>()->value;
int64_t new_stride =
old_stride * expand_;
PrimExpr new_base =
make_const(DataType::Int(32), new_stride) *
loop_var;
new_indices.Set(
0,
Ramp(new_base, stride, new_lanes));
}
// Rewrite an index-base expression so that every `loop_var * stride`
// factor has its stride scaled by `expand` (used when widening the
// vector lanes of a store so the per-iteration stride must grow too).
// Terms that do not directly multiply loop_var are returned unchanged.
PrimExpr UpdateIndexBase(PrimExpr base, const Var& loop_var, int expand) {
  // Sums: rewrite each addend independently, then re-add.
  const auto* sum = base.as<AddNode>();
  if (sum != nullptr) {
    PrimExpr lhs = UpdateIndexBase(sum->a, loop_var, expand);
    PrimExpr rhs = UpdateIndexBase(sum->b, loop_var, expand);
    return lhs + rhs;
  }
  // Products: scale the stride operand when the other operand is
  // loop_var itself.
  const auto* prod = base.as<MulNode>();
  if (prod != nullptr) {
    if (prod->a.same_as(loop_var)) {
      return prod->a * (prod->b * expand);
    }
    if (prod->b.same_as(loop_var)) {
      return (prod->a * expand) * prod->b;
    }
  }
  // Constants, other vars, and anything else pass through untouched.
  // NOTE(review): a nested product like (loop_var * c1) * c2 is not
  // rewritten — assumes indices are in canonical `loop_var * const` form;
  // confirm against the producer of these indices.
  return base;
}
// Rewrite a vectorized B_local store so both the store index and the
// loaded value cover `expand_` times as many lanes, with the loop-var
// stride in the store index scaled accordingly, e.g.
//   B_local[Ramp(j*4, 1, 4)] = B_shared[Ramp(b, 1, 4)]
//    -> B_local[Ramp(j*8, 1, 8)] = B_shared[Ramp(b, 1, 8)]
Stmt MutateStore(const BufferStoreNode* store, const Var& loop_var) {
  // Fix: dropped the dead `make_object<BufferStoreNode>(*store)` local —
  // it was never used and allocated a throwaway copy of the node.
  // Widen the destination index: scale the loop-var stride inside the
  // ramp base and multiply the lane count.
  Array<PrimExpr> new_indices = store->indices;
  if (const auto* ramp = store->indices[0].as<RampNode>()) {
    PrimExpr new_base = UpdateIndexBase(ramp->base, loop_var, expand_);
    int new_lanes = ramp->lanes.as<IntImmNode>()->value * expand_;
    new_indices.Set(0, Ramp(new_base, ramp->stride, new_lanes));
  }
  // Widen the RHS load to the same lane count; its (scalar) base and
  // stride are kept as-is.
  PrimExpr new_value = store->value;
  if (const auto* load = store->value.as<BufferLoadNode>()) {
    if (const auto* l_ramp = load->indices[0].as<RampNode>()) {
      Array<PrimExpr> v_indices = load->indices;
      int v_new_lanes = l_ramp->lanes.as<IntImmNode>()->value * expand_;
      v_indices.Set(0, Ramp(l_ramp->base, l_ramp->stride, v_new_lanes));
      new_value = BufferLoad(load->buffer, v_indices);
    }
  }
  return BufferStore(store->buffer, new_value, new_indices);
}
};
......
......@@ -13,7 +13,6 @@ namespace tl {
using ffi::Array;
using namespace tir;
// 1. 辅助类:统计 Shared -> Register 的加载量
class LoadCounter : public StmtExprVisitor {
public:
int total_loads = 0;
......@@ -39,11 +38,147 @@ public:
}
ExprVisitor::VisitExpr_(op);
}
// Count ds_read intrinsic calls as shared->register loads, weighted by
// the current loop multiplier.
void VisitExpr_(const CallNode* op) override {
  std::string callee;
  if (const auto* builtin = op->op.as<OpNode>()) {
    callee = builtin->name;
  } else if (const auto* gv = op->op.as<GlobalVarNode>()) {
    callee = gv->name_hint;
  }
  const bool is_ds_read = callee.find("ds_read") != std::string::npos;
  if (is_ds_read) {
    total_loads += current_multiplier;
  }
  ExprVisitor::VisitExpr_(op);
}
private:
// Heuristic test for shared/LDS buffers: either the storage scope is
// "shared" or the buffer name carries a shared-memory naming convention.
bool IsSharedMem(const Buffer& buf) {
  const std::string scope = buf.scope();
  if (scope == "shared") {
    return true;
  }
  const std::string name = buf->name;
  return name.find("shared") != std::string::npos ||
         name.find("shmem") != std::string::npos ||
         name.find("LDS") != std::string::npos;
}
};
// 2. 核心 Mutator
namespace {
// Return true if `stmt` contains any call whose name matches an MMA
// intrinsic. A single "mma" substring test suffices: every name that
// contains "mmac" also contains "mma", so the old extra check was
// redundant.
bool StmtContainsMMA(const Stmt& stmt) {
  bool found = false;
  PostOrderVisit(stmt, [&found](const ObjectRef& node) {
    if (found) return;  // already decided; skip further string work
    const CallNode* call = node.as<CallNode>();
    if (call == nullptr) return;
    std::string op_name;
    if (const OpNode* op = call->op.as<OpNode>()) {
      op_name = op->name;
    } else if (const GlobalVarNode* gv = call->op.as<GlobalVarNode>()) {
      op_name = gv->name_hint;
    }
    if (op_name.find("mma") != std::string::npos) {
      found = true;
    }
  });
  return found;
}
void ScanStmtDefault(const Stmt& s, std::vector<Stmt>* fence_targets);
void ScanSeqStmt(const SeqStmtNode* op, std::vector<Stmt>* fence_targets) {
int pending = 0;
for (size_t i = 0; i < op->seq.size(); ++i) {
const Stmt& stmt = op->seq[i];
if (StmtContainsMMA(stmt)) {
if (pending > 0) {
fence_targets->push_back(stmt);
pending = 0;
}
ScanStmtDefault(stmt, fence_targets);
} else {
LoadCounter counter;
counter(stmt);
pending += counter.total_loads;
ScanStmtDefault(stmt, fence_targets);
}
}
}
// Recursive dispatcher: descend through every compound statement kind
// and hand each SeqStmt encountered to ScanSeqStmt. Leaf statements
// (no body) fall through and are ignored.
void ScanStmtDefault(const Stmt& s, std::vector<Stmt>* fence_targets) {
  if (const auto* seq = s.as<SeqStmtNode>()) {
    ScanSeqStmt(seq, fence_targets);
  } else if (const auto* attr = s.as<AttrStmtNode>()) {
    ScanStmtDefault(attr->body, fence_targets);
  } else if (const auto* let = s.as<LetStmtNode>()) {
    ScanStmtDefault(let->body, fence_targets);
  } else if (const auto* ite = s.as<IfThenElseNode>()) {
    // Both branches may contain load/MMA sequences.
    ScanStmtDefault(ite->then_case, fence_targets);
    if (ite->else_case) {
      ScanStmtDefault(ite->else_case.value(), fence_targets);
    }
  } else if (const auto* loop = s.as<ForNode>()) {
    ScanStmtDefault(loop->body, fence_targets);
  } else if (const auto* wh = s.as<WhileNode>()) {
    ScanStmtDefault(wh->body, fence_targets);
  } else if (const auto* alloc = s.as<AllocateNode>()) {
    ScanStmtDefault(alloc->body, fence_targets);
  } else if (const auto* alloc_const = s.as<AllocateConstNode>()) {
    ScanStmtDefault(alloc_const->body, fence_targets);
  } else if (const auto* decl = s.as<DeclBufferNode>()) {
    ScanStmtDefault(decl->body, fence_targets);
  } else if (const auto* realize = s.as<BufferRealizeNode>()) {
    ScanStmtDefault(realize->body, fence_targets);
  } else if (const auto* assert_stmt = s.as<AssertStmtNode>()) {
    ScanStmtDefault(assert_stmt->body, fence_targets);
  } else if (const auto* block = s.as<BlockNode>()) {
    // Blocks can carry an init statement in addition to the body.
    if (block->init.defined()) {
      ScanStmtDefault(block->init.value(), fence_targets);
    }
    ScanStmtDefault(block->body, fence_targets);
  } else if (const auto* block_realize = s.as<BlockRealizeNode>()) {
    ScanStmtDefault(block_realize->block, fence_targets);
  }
}
// Scan the whole function body and return the last statement that will
// receive a fence; returns an undefined Stmt when no fence is needed.
Stmt ComputeGlobalLastFenceMMAStmt(const Stmt& root) {
  std::vector<Stmt> fence_targets;
  ScanStmtDefault(root, &fence_targets);
  return fence_targets.empty() ? Stmt() : fence_targets.back();
}
}
class MMABarrierMutator : public StmtExprMutator {
public:
explicit MMABarrierMutator(const Stmt& root_body)
: global_last_fence_mma_(ComputeGlobalLastFenceMMAStmt(root_body)) {}
bool ContainsMMA(const Stmt& stmt) {
bool found = false;
PostOrderVisit(stmt, [&found](const ObjectRef& node) {
......@@ -64,23 +199,6 @@ public:
}
Stmt VisitStmt_(const SeqStmtNode* op) override {
// --- 步骤 1: 预扫描,确定最后一个需要插入 Fence 的位置 ---
int last_fence_idx = -1;
int temp_pending_count = 0;
for (size_t i = 0; i < op->seq.size(); ++i) {
if (ContainsMMA(op->seq[i])) {
if (temp_pending_count > 0) {
last_fence_idx = static_cast<int>(i);
temp_pending_count = 0; // 模拟重置
}
} else {
LoadCounter counter;
counter(op->seq[i]);
temp_pending_count += counter.total_loads;
}
}
// --- 步骤 2: 实际构造新的 Sequence ---
Array<Stmt> new_seq;
int pending_load_count = 0;
......@@ -89,16 +207,16 @@ public:
if (ContainsMMA(stmt)) {
if (pending_load_count > 0) {
// 判断是否是该序列中最后一个 Fence
int fence_val = (static_cast<int>(i) == last_fence_idx) ? 0 : pending_load_count;
int fence_val =
(global_last_fence_mma_.defined() && stmt.same_as(global_last_fence_mma_))
? 0
: pending_load_count;
Array<PrimExpr> args = {Integer(fence_val)};
// 构造 Fence
auto fence_call = Call(DataType::Void(), Op::Get("tl.async_gld_fence"), args);
new_seq.push_back(Evaluate(fence_call));
// 构造 Barrier
auto barrier_call = Call(DataType::Void(), Op::Get("tl.wave_barrier"), {});
new_seq.push_back(Evaluate(barrier_call));
......@@ -114,15 +232,17 @@ public:
}
return SeqStmt(new_seq);
}
private:
Stmt global_last_fence_mma_;
};
// 3. Pass 包装
namespace transform {
using namespace tir::transform;
Pass InsertAsyncMMAFence() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
MMABarrierMutator mutator;
MMABarrierMutator mutator(n->body);
n->body = mutator(n->body);
return f;
};
......
......@@ -20,9 +20,6 @@ using namespace tir;
using ffi::Array;
using ffi::String;
// ============================================================================
// 数据结构
// ============================================================================
struct CopyInfo {
Buffer dst_buffer;
Buffer src_buffer;
......@@ -33,10 +30,7 @@ struct CopyInfo {
struct CollectResult {
std::vector<CopyInfo> copies;
// 映射: Global Buffer Name -> DCU Resource Var (用于替换Store)
std::unordered_map<String, Var> global_to_res_var;
// 映射: Shared Buffer Name -> 要注入的LetStmt绑定 (Var, PrimExpr)
// 这样我们就可以根据 shared buffer 的位置来决定注入点
std::unordered_map<String, std::pair<Var, PrimExpr>> shared_alloc_to_binding;
const StmtNode* inject_target = nullptr;
......@@ -64,7 +58,6 @@ class VariableKeeper : public tvm::tir::ExprMutator {
: keep_vars_(keep_vars) {}
PrimExpr VisitExpr_(const tvm::tir::VarNode* op) override {
// 关键调试:打印每一个遇到的变量及其地址
if (keep_vars_.count(op)) {
return GetRef<PrimExpr>(op);
} else {
......@@ -72,10 +65,7 @@ class VariableKeeper : public tvm::tir::ExprMutator {
}
}
// 额外处理:防止 Load 节点中的变量丢失
PrimExpr VisitExpr_(const tvm::tir::BufferLoadNode* op) override {
// 如果你的索引里嵌套了 BufferLoad,Load 本身不是 Var,
// 但它里面可能含有 Var。Mutator 默认会递归,但我们可以显式打印。
return ExprMutator::VisitExpr_(op);
}
......@@ -83,9 +73,6 @@ class VariableKeeper : public tvm::tir::ExprMutator {
const std::unordered_set<const tvm::tir::VarNode*>& keep_vars_;
};
// ============================================================================
// Phase 1: 收集拷贝信息 & 生成资源绑定
// ============================================================================
CollectResult CollectResources(const Stmt& body) {
class Collector : public StmtExprVisitor {
public:
......@@ -94,7 +81,7 @@ CollectResult CollectResources(const Stmt& body) {
private:
bool in_async{false};
std::unordered_set<const tvm::tir::VarNode*> loop_vars_;
std::vector<const tvm::tir::StmtNode*> scope_stack_; // 追踪当前遍历的 AST 路径
std::vector<const tvm::tir::StmtNode*> scope_stack_;
bool IsSharedScope(const Buffer& buf) {
auto s = buf.scope();
return s == "shared" || s == "shared.dyn";
......@@ -107,7 +94,6 @@ CollectResult CollectResources(const Stmt& body) {
void VisitStmt_(const AttrStmtNode* attr) override {
scope_stack_.push_back(attr);
if (attr->attr_key == tir::attr::thread_extent) {
// 1. 获取 IterVar
auto iv = attr->node.as<tvm::tir::IterVarNode>();
const std::string& tag = iv->thread_tag;
......@@ -119,7 +105,6 @@ CollectResult CollectResources(const Stmt& body) {
loop_vars_.erase(thread_var.get());
} else {
// 如果是 blockIdx 或其他,直接跳过当前层继续往下走
StmtExprVisitor::VisitStmt_(attr);
}
......@@ -173,7 +158,6 @@ CollectResult CollectResources(const Stmt& body) {
if (scope_stack_[i]->IsInstance<AttrStmtNode>()) {
auto attr = static_cast<const AttrStmtNode*>(scope_stack_[i]);
if (attr->attr_key == tvm::tir::attr::thread_extent) {
// 找到了最内层的线程绑定。它里面的下一个节点(i+1)就是我们应该包裹的节点
if (i + 1 < scope_stack_.size()) {
result.inject_target = scope_stack_[i + 1];
}
......@@ -189,12 +173,10 @@ CollectResult CollectResources(const Stmt& body) {
}
}
}
// 如果还是空,直接 fallback 到当前操作
if (result.inject_target == nullptr) result.inject_target = op;
}
// 1. 记录拷贝
VariableKeeper keeper(loop_vars_);
tvm::arith::Analyzer analyzer;
Array<PrimExpr> for_var_only_indices;
......@@ -206,7 +188,6 @@ CollectResult CollectResources(const Stmt& body) {
CopyInfo info{dst, src, op->indices, for_var_only_indices, GetRef<Stmt>(op)};
result.copies.push_back(info);
// 2. 只有当没处理过这个 Global Buffer 时才生成 Binding
if (result.global_to_res_var.find(src->name) == result.global_to_res_var.end()) {
Var var(src->name + "_dcu_res", DataType::Int(32, 4));
......@@ -214,17 +195,13 @@ CollectResult CollectResources(const Stmt& body) {
tvm::arith::Analyzer analyzer;
Array<PrimExpr> base_indices;
for (const auto& idx : load->indices) {
// 将所有外层循环变量 (k, i 等) 全部替换为 0
PrimExpr no_loops = eliminator(idx);
// 化简出最终的基地址表达式
base_indices.push_back(analyzer.Simplify(no_loops));
}
// ✅ 关键点:填充真实的地址信息 src->data (即 A.data)
Array<PrimExpr> args;
args.push_back(src->data); // 先加 data
args.push_back(src->data);
// 如果需要把 indices 的每个元素作为独立参数展开:
for (const auto& idx : base_indices) {
args.push_back(idx);
}
......@@ -232,7 +209,6 @@ CollectResult CollectResources(const Stmt& body) {
Op::Get("tl.make_dcu_resource"), args);
result.global_to_res_var[src->name] = var;
// 将这个绑定关系和 destination 的 shared buffer 绑死
result.shared_alloc_to_binding[src->name] = {var, val};
}
}
......@@ -247,9 +223,6 @@ CollectResult CollectResources(const Stmt& body) {
return col.result;
}
// ============================================================================
// Phase 2: 替换 BufferStore -> dcu_async_copy
// ============================================================================
class StoreReplacer : public StmtExprMutator {
public:
static Stmt Run(Stmt body, const std::vector<CopyInfo>& copies,
......@@ -268,16 +241,14 @@ private:
auto body = this->VisitStmt(attr->body);
return body;
}
return StmtMutator::VisitStmt_(attr); // ③ 其他属性:默认保留
return StmtMutator::VisitStmt_(attr);
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
for (const auto& copy : copies_) {
if (copy.store_stmt.same_as(GetRef<Stmt>(op))) {
// Global 取 resource var (A_dcu_res)
Var src_res = global_to_var_.at(copy.src_buffer->name);
// Shared 取 data pointer (A_shared.data)
PrimExpr dst_res = copy.dst_buffer->data;
PrimExpr copy_size = IntImm(DataType::Int(32), 1);
......@@ -305,9 +276,6 @@ private:
const std::unordered_map<String, Var>& global_to_var_;
};
// ============================================================================
// Phase 3: 根据 Shared Alloc 位置进行精准注入
// ============================================================================
class ResourceInjector : public tvm::tir::StmtExprMutator {
public:
static Stmt Run(Stmt body,
......@@ -324,18 +292,15 @@ private:
: bindings_(bindings), target_(target) {}
Stmt VisitStmt(const Stmt& stmt) override {
// 当我们遍历到刚才标记的那个 AST 节点时
if (stmt.get() == target_) {
// 先向下遍历(保持 TVM Mutator 的习惯)
Stmt new_stmt = StmtExprMutator::VisitStmt(stmt);
// 在这个节点的外面套上所有的 LetStmt
for (const auto& item : bindings_) {
Var res_var = item.second.first;
PrimExpr init_expr = item.second.second;
new_stmt = tvm::tir::LetStmt(res_var, init_expr, new_stmt);
}
return new_stmt; // 返回包裹好的新节点
return new_stmt;
}
return StmtExprMutator::VisitStmt(stmt);
}
......@@ -344,26 +309,18 @@ private:
const tvm::tir::StmtNode* target_;
};
// ============================================================================
// Pass 入口
// ============================================================================
PrimFunc LowerSharedGlobalCopy(PrimFunc f) {
auto* n = f.CopyOnWrite();
// 收集信息
auto res = CollectResources(n->body);
if (res.copies.empty()){
return f;
}
// 注入res声明
Stmt injected = ResourceInjector::Run(n->body, res.shared_alloc_to_binding, res.inject_target);
// 替换拷贝语句
Stmt replaced = StoreReplacer::Run(injected, res.copies, res.global_to_res_var);
// 写回
n->body = std::move(replaced);
return GetRef<PrimFunc>(n);
......
......@@ -31,7 +31,6 @@ private:
PrimExpr k_extent_;
bool in_unrolled_i_ = false;
// 通用的步长提取函数:从 expr 中提取指定 var 的步长,并返回剩余的 base
std::pair<PrimExpr, PrimExpr> ExtractStride(PrimExpr expr, Var var) {
if (!var.defined()) return {expr, make_zero(expr->dtype)};
......@@ -41,33 +40,26 @@ private:
return {analyzer_.Simplify(base), stride};
}
// 核心重写逻辑
Stmt RewriteAsyncCopy(const CallNode* call, Var i_var, PrimExpr i_extent) {
// 1. 预处理:剥离 RampNode 获得基础偏移
PrimExpr raw_dst_off = call->args[1];
PrimExpr raw_src_off = call->args[3];
if (const RampNode* r = raw_dst_off.as<RampNode>()) raw_dst_off = r->base;
if (const RampNode* r = raw_src_off.as<RampNode>()) raw_src_off = r->base;
// 2. 提取 i 的步长
auto [base_i_dst, i_stride_dst] = ExtractStride(raw_dst_off, i_var);
auto [base_i_src, i_stride_src] = ExtractStride(raw_src_off, i_var);
// 3. 提取 k 的步长 (始终尝试从 base_i_src 中提取,无论是否存在 i 循环)
// 只要 k_var_ 在外层循环中定义了,这里就能提取出非 0 的步长
auto [final_src_offset, k_stride_src] = ExtractStride(base_i_src, k_var_);
// 构造最初要求的 8 个参数:
// [dst, dst_off, src, src_off, i_extent, i_stride_dst, i_stride_src, k_stride_src]
Array<PrimExpr> new_args = {
call->args[0], // dst_buf
base_i_dst, // 最终 dst 偏移
call->args[2], // src_buf
final_src_offset, // 最终 src 偏移
i_extent, // i 循环次数 (无循环时为 0)
i_stride_dst, // i 的 dst 步长 (无循环时为 0)
i_stride_src, // i 的 src 步长 (无循环时为 0)
k_stride_src // k 的 src 步长 (即便无 i 循环,这里也能拿到 k 的步长)
call->args[0],
base_i_dst,
call->args[2],
final_src_offset,
i_extent,
i_stride_dst,
i_stride_src,
k_stride_src
};
return Evaluate(Call(call->dtype, call->op, new_args));
......@@ -78,7 +70,6 @@ private:
if (!in_unrolled_i_) {
if (const CallNode* call = op->value.as<CallNode>()) {
static const Op& dcu_copy_op = Op::Get("tl.dcu_async_copy");
// 只要参数个数不是 8 (我们重写后的目标个数),就进行处理
if (call->op.same_as(dcu_copy_op) && call->args.size() != 8) {
return RewriteAsyncCopy(call, Var(), make_zero(DataType::Int(32)));
}
......@@ -108,7 +99,6 @@ private:
if (const CallNode* call = eval->value.as<CallNode>()) {
static const Op& dcu_copy_op = Op::Get("tl.dcu_async_copy");
if (call->op.same_as(dcu_copy_op)) {
// 还原 k 并在返回前处理重写
Stmt result = RewriteAsyncCopy(call, op->loop_var, op->extent);
return result;
}
......@@ -116,7 +106,6 @@ private:
}
}
// 退出循环时清理 k 信息
if (is_k) {
k_var_ = Var();
k_extent_ = PrimExpr();
......@@ -128,9 +117,7 @@ private:
return Stmt(n);
}
};
// ============================================================================
// Pass 入口
// ============================================================================
PrimFunc SimplifyDCUAsyncCopy(PrimFunc f) {
auto* n = f.CopyOnWrite();
n->body = AsyncCopySimplifier::Run(std::move(n->body));
......
......@@ -222,25 +222,16 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.RegisterPipelinePlanning()(mod)
print("OptimizeForTarget")
print(mod)
mod = tilelang.transform.InjectRegisterSoftwarePipeline()(mod)
print("OptimizeForTarget2")
print(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
print("OptimizeForTarget2")
print(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
if allow_fence_proxy(target=target):
# in hopper device, wgmma is an async proxy
# so we need to inject a fence proxy before it
mod = tilelang.transform.InjectFenceProxy()(mod)
print("OptimizeForTarget2.5")
print(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.Simplify()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
......@@ -311,8 +302,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
if dcu_async_copy_supported(target):
print("--------------support dcu async copy------------------")
mod = tilelang.transform.LowerSharedGlobalCopy()(mod)
print("222222222")
print(mod)
mod = tilelang.transform.FixDCUWaitCount()(mod)
mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
print("InjectBLocalLayoutTransform ............")
......@@ -321,8 +310,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
print("InjectDSRead ............")
print(mod)
mod = tilelang.transform.InsertAsyncMMAFence()(mod)
print("333333333")
print(mod)
# Register pipeline planning only writes software_pipeline annotations.
# We must inject after planning so prologue/body/epilogue are materialized.
# mod = tilelang.transform.RegisterPipelinePlanning()(mod)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment