Commit 32d0b3cb authored by qisan's avatar qisan
Browse files

Feats: support async_copy pass!

parent a0ec0f57
...@@ -10,13 +10,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.fl ...@@ -10,13 +10,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.fl
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local) T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
T.copy(A[by * block_M, k * block_K], A_shared) T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared) T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local) T.gemm(A_shared, B_shared, C_local)
...@@ -27,7 +27,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.fl ...@@ -27,7 +27,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.fl
def main(): def main():
kernel = matmul(1024, 1024, 1024, 128, 128, 32) kernel = matmul(1024, 1024, 1024, 256, 256, 16)
import torch import torch
......
...@@ -387,5 +387,16 @@ TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor) ...@@ -387,5 +387,16 @@ TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor)
TIR_DEFINE_TL_BUILTIN(__ldg).set_num_inputs(-1).set_attr<TCallEffectKind>( TIR_DEFINE_TL_BUILTIN(__ldg).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure)); "TCallEffectKind", Integer(CallEffectKind::kPure));
//
// tl.dcu_async_copy: before SimplifyDCUAsyncCopy runs it carries 6 arguments
// (dst, dst_off, src_res, src_off, size, predicate); after simplification it
// carries 8 (dst, dst_off, src_res, src_off, i_extent, i_sstride, i_gstride,
// k_gstride) which the HIP codegen reads as args[0..7]. The arity therefore
// must be variadic. (BUGFIX: was fixed at 6, contradicting the 8-arg form.)
TIR_DEFINE_TL_BUILTIN(dcu_async_copy)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// tl.make_dcu_resource(base_ptr, offset): builds a wave buffer resource
// descriptor from a global base pointer plus a byte/element offset.
// NOTE(review): the collector pushes one offset per buffer dimension, so
// arity 2 presumably assumes buffers are already flattened to 1-D by the time
// this op is created — confirm against the pass ordering.
TIR_DEFINE_TL_BUILTIN(make_dcu_resource)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -888,7 +888,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower, ...@@ -888,7 +888,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
*/ */
Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target; Target target = T.target;
printf("Lowering CopyNode with target: %s\n", target->str().c_str());
using namespace tvm::transform; using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current(); PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower = bool disable_tma_lower =
...@@ -940,6 +940,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -940,6 +940,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
*/ */
Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
arith::Analyzer *analyzer) const { arith::Analyzer *analyzer) const {
printf("Lowering normal copy for target: %s\n", T.target->str().c_str());
bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU; bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU;
auto simt_loop = MakeSIMTLoop(analyzer); auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop)); auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
......
...@@ -290,6 +290,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, ...@@ -290,6 +290,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
LayoutMap results; LayoutMap results;
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) { if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
printf("GemmPyNode::InferLayout: calling tl.gemm_py.infer_layout\n");
results = Downcast<LayoutMap>( results = Downcast<LayoutMap>(
(*f)(tvm::ffi::GetRef<GemmPy>(this), T.target, T.thread_bounds)); (*f)(tvm::ffi::GetRef<GemmPy>(this), T.target, T.thread_bounds));
// Bind all fragment layouts with the provided thread range // Bind all fragment layouts with the provided thread range
...@@ -303,7 +304,8 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, ...@@ -303,7 +304,8 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
} else { } else {
LOG(FATAL) << "No infer layout function found for gemm_py"; LOG(FATAL) << "No infer layout function found for gemm_py";
} }
LOG(INFO) << "GemmPyNode::InferLayout results:";
LOG(INFO) << results;
completed_ = true; completed_ = true;
return results; return results;
} }
......
...@@ -242,7 +242,8 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -242,7 +242,8 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const { InferLevel level) const {
if (loop_layout_.defined()) if (loop_layout_.defined())
return {}; return {};
LOG(INFO) << "Inferring layout for T.Parallel loop with inference level "
<< static_cast<int>(level) << "...\n";
// Expand let bindings to find fragment buffer accesses // Expand let bindings to find fragment buffer accesses
if (!T.let_var_to_expr.empty()) { if (!T.let_var_to_expr.empty()) {
const_cast<ParallelOpNode *>(this)->ExpandLetBindings(T.let_var_to_expr); const_cast<ParallelOpNode *>(this)->ExpandLetBindings(T.let_var_to_expr);
...@@ -424,7 +425,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -424,7 +425,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
LOG(FATAL) << msg.str(); LOG(FATAL) << msg.str();
} }
} }
DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get " LOG(INFO) << "[compute_loop_layout_from_buffer] ... and get "
<< result->DebugOutput() << '\n'; << result->DebugOutput() << '\n';
return result; return result;
}; };
......
...@@ -381,7 +381,9 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*) ...@@ -381,7 +381,9 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
case 32: { case 32: {
if (t.is_scalar()) { if (t.is_scalar()) {
os << "int"; os << "int";
} else if (t.lanes() <= 4) { } else if (t.lanes() == 4) {
os << "int32x4_t";
} else if (t.lanes() < 4) {
os << "int" << t.lanes(); os << "int" << t.lanes();
} else if (t.lanes() <= 8) { } else if (t.lanes() <= 8) {
// Emit CUDA code to access int32 vector elements for 4 < lanes <= 8. // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
...@@ -1088,7 +1090,58 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -1088,7 +1090,58 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
// HIP doesn't need explicit register management like CUDA // HIP doesn't need explicit register management like CUDA
// This is a no-op for HIP // This is a no-op for HIP
return; return;
} else if (op->op.same_as(Op::Get("tl.make_dcu_resource"))) {
CHECK_EQ(op->args.size(), 2) << "make_dcu_resource expects 2 arguments";
std::string base_ptr = this->PrintExpr(op->args[0]);
std::string offset;
if (const RampNode* ramp = op->args[1].as<RampNode>()) {
offset = this->PrintExpr(ramp->base);
} else { } else {
offset = this->PrintExpr(op->args[1]);
}
os << "make_wave_buffer_resource(" << base_ptr << " + (" << offset << "))";
}
else if (op->op.same_as(Op::Get("tl.dcu_async_copy"))) {
// 1. 提取模板参数 (IntImm 直接取值)
auto get_int_const = [](const PrimExpr& e) -> int {
if (const auto* val = e.as<IntImmNode>()) return static_cast<int>(val->value);
return 0;
};
int N = 16;
int smem_offset = 0;
int load_count = get_int_const(op->args[4]);
int i_sstride = get_int_const(op->args[5]);
int i_gstride = get_int_const(op->args[6]);
int k_gstride = get_int_const(op->args[7]);
// 2. 将运行时参数打印到字符串中 (防止直接操作 stream 导致冲突)
std::string dst_ptr = this->PrintExpr(op->args[0]);
std::string dst_off = this->PrintExpr(op->args[1]);
std::string src_res = this->PrintExpr(op->args[2]);
std::string src_off = this->PrintExpr(op->args[3]);
// 3. 仿照范例进行流输出
this->PrintIndent();
this->stream << "cp_async_gs<"
<< N << ", "
<< smem_offset << ", "
<< load_count << ", "
<< i_sstride << ", "
<< i_gstride << ", "
<< k_gstride << ">(";
// 拼接第一个参数:(char*)dst + dst_off
this->stream << "((char*)" << dst_ptr << " + " << dst_off << "), ";
// 拼接第二个参数:src_res
this->stream << src_res << ", ";
// 拼接第三个参数:src_off
this->stream << src_off << ");\n";
}
else {
printf("[DEBUG VisitExpr_] Branch: CodeGenC::VisitExpr_ (fallback)\n"); printf("[DEBUG VisitExpr_] Branch: CodeGenC::VisitExpr_ (fallback)\n");
CodeGenC::VisitExpr_(op, os); CodeGenC::VisitExpr_(op, os);
} }
......
...@@ -38,6 +38,20 @@ __device__ void inc_m0(uint32_t m0_inc) { ...@@ -38,6 +38,20 @@ __device__ void inc_m0(uint32_t m0_inc) {
asm volatile("s_add_u32 m0, %0, m0" : : "n"(m0_inc) : "memory"); asm volatile("s_add_u32 m0, %0, m0" : : "n"(m0_inc) : "memory");
} }
#define UPDATE_WAVE_BUFFER_RESOURCE(res, stride) \
do { \
/* 1. 提取 64 位基地址,确保低位不进行符号位扩展 */ \
uint64_t __current_addr = (static_cast<uint64_t>((res).y) << 32) | \
(static_cast<uint32_t>((res).x)); \
\
/* 2. 增加步长 (自动处理类型提升) */ \
__current_addr += (stride); \
\
/* 3. 写回分量到 SGPRs */ \
(res).x = static_cast<int32_t>(__current_addr); \
(res).y = static_cast<int32_t>(__current_addr >> 32); \
} while (0)
namespace tl { namespace tl {
// AMDGPU automatically commit memory fence // AMDGPU automatically commit memory fence
...@@ -72,20 +86,96 @@ CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc, ...@@ -72,20 +86,96 @@ CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc,
: "memory"); : "memory");
} }
// Issue `load_count` asynchronous dwordx4 buffer loads from global memory
// (described by the wave buffer resource `res`) into LDS at `lds_base_ptr`.
//
//   N           - bytes per lane per load; only 16 (dwordx4) is implemented
//   smem_offset - compile-time LDS byte offset of the first load
//   i_sstride   - LDS byte stride between consecutive loads
//   i_gstride   - global byte stride between consecutive loads
//   k_gstride   - global byte stride applied after the last load
//   offset      - per-lane runtime offset passed to every load
//
// `res` is advanced by UPDATE_WAVE_BUFFER_RESOURCE between loads so that all
// loads share the same `offset`. NOTE(review): `res` is taken by value, so
// the final k_gstride update is invisible to the caller — presumably the
// caller recomputes the source offset per k iteration; confirm.
template <int N, int smem_offset, int load_count, int i_sstride, int i_gstride,
          int k_gstride>
TL_DEVICE void cp_async_gs(void *lds_base_ptr, int32x4_t res, int offset) {
  // BUGFIX: the previous version referenced an undefined `current_offset`,
  // performed arithmetic on void*, and ended with the invalid statement
  // `not implemented;`.
  static_assert(N == 16, "cp_async_gs: only N == 16 (dwordx4) is implemented");
  if constexpr (load_count == 1) {
    async_buffer_load_dwordx4_v<smem_offset>(lds_base_ptr, res, offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, k_gstride);
  } else if constexpr (load_count == 2) {
    async_buffer_load_dwordx4_v<smem_offset>(lds_base_ptr, res, offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, i_gstride);
    async_buffer_load_dwordx4_v<smem_offset + i_sstride>(lds_base_ptr, res,
                                                         offset);
    // Undo the i-advance and step to the next k tile in a single update.
    UPDATE_WAVE_BUFFER_RESOURCE(res, k_gstride - i_gstride);
  } else if constexpr (load_count == 4) {
    async_buffer_load_dwordx4_v<smem_offset>(lds_base_ptr, res, offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, i_gstride);
    async_buffer_load_dwordx4_v<smem_offset + i_sstride>(lds_base_ptr, res,
                                                         offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, i_gstride);
    async_buffer_load_dwordx4_v<smem_offset + 2 * i_sstride>(lds_base_ptr, res,
                                                             offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, i_gstride);
    async_buffer_load_dwordx4_v<smem_offset + 3 * i_sstride>(lds_base_ptr, res,
                                                             offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, k_gstride - 3 * i_gstride);
  } else {
    // Generic fallback: the LDS offset cannot be a function of the runtime
    // loop index in the template argument, so advance the LDS pointer
    // (as char*, since void* arithmetic is ill-formed) instead.
#pragma unroll
    for (int i = 0; i < load_count - 1; ++i) {
      async_buffer_load_dwordx4_v<smem_offset>(
          static_cast<char *>(lds_base_ptr) + i * i_sstride, res, offset);
      UPDATE_WAVE_BUFFER_RESOURCE(res, i_gstride);
    }
    async_buffer_load_dwordx4_v<smem_offset>(
        static_cast<char *>(lds_base_ptr) + (load_count - 1) * i_sstride, res,
        offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, k_gstride - (load_count - 1) * i_gstride);
  }
}
// Legacy synchronous fallback, kept for reference only.
// BUGFIX: the stray uncommented `template <int N>` line that preceded this
// comment block had no definition attached and would have bound itself to the
// next declaration (`ds_read_vector`), producing a compile error; it is now
// commented out together with the body it belonged to.
// template <int N>
// TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
//   if constexpr (N == 16) {
//     *(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr;
//   } else if constexpr (N == 8) {
//     *(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr;
//   } else if constexpr (N == 4) {
//     async_buffer_load_dword_v(
//         lds_base_ptr,
//         make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
//         threadIdx.x * N /*assume 4 bytes*/);
//   }
// }
template <int M, int N, int offset> template <int M, int N, int offset>
TL_DEVICE void ds_read_vector(float4_& dst, uint32_t lds_base_ptr) TL_DEVICE void ds_read_vector(float4_& dst, uint32_t lds_base_ptr)
{ {
......
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/memory.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/arith/analyzer.h>
#include <vector>
#include <unordered_map>
#include <unordered_set>
using tvm::ffi::GetRef;
using tvm::ffi::make_object;
namespace tvm {
namespace tl {
using namespace tir;
using ffi::Array;
using ffi::String;
// ============================================================================
// Data structures
// ============================================================================

// One global -> shared copy discovered in the function body.
struct CopyInfo {
  Buffer dst_buffer;            // shared-memory destination buffer
  Buffer src_buffer;            // global-memory source buffer
  Array<PrimExpr> dst_indices;  // store indices into the shared buffer
  Array<PrimExpr> src_indices;  // load indices reduced to thread/loop vars only
  Stmt store_stmt;              // original BufferStore, matched later by pointer
};

struct CollectResult {
  std::vector<CopyInfo> copies;
  // Map: global buffer name -> DCU resource Var (used when rewriting stores).
  std::unordered_map<String, Var> global_to_res_var;
  // Map: buffer name -> the LetStmt binding (Var, init expr) to inject.
  // Keeping this per-buffer lets the injection point be chosen relative to
  // where the corresponding shared buffer is used.
  std::unordered_map<String, std::pair<Var, PrimExpr>> shared_alloc_to_binding;
  // Statement directly below the innermost thread_extent binding; the LetStmts
  // get wrapped around this node.
  const StmtNode* inject_target = nullptr;
};
// Expression mutator that substitutes a typed zero for every variable in a
// given removal set, leaving all other sub-expressions untouched. Used to
// strip thread/loop variables out of an index to obtain its base expression.
class VariableEliminator : public tvm::tir::ExprMutator {
 public:
  explicit VariableEliminator(const std::unordered_set<const tvm::tir::VarNode*>& vars)
      : vars_to_remove_(vars) {}

  PrimExpr VisitExpr_(const tvm::tir::VarNode* op) override {
    const bool should_drop = vars_to_remove_.count(op) != 0;
    return should_drop ? tvm::tir::make_zero(op->dtype) : GetRef<PrimExpr>(op);
  }

 private:
  const std::unordered_set<const tvm::tir::VarNode*>& vars_to_remove_;
};
// Expression mutator that keeps only the variables present in `keep_vars`
// and replaces every other Var with a typed zero — the dual of
// VariableEliminator. Used to isolate the thread/loop-variable part of a
// buffer index (the per-lane offset).
//
// BUGFIX: removed the per-variable LOG(INFO) debug tracing (the original
// logged every visited Var, flooding the build log on real modules) and the
// no-op BufferLoadNode override that merely forwarded to the base class.
class VariableKeeper : public tvm::tir::ExprMutator {
 public:
  explicit VariableKeeper(const std::unordered_set<const tvm::tir::VarNode*>& keep_vars)
      : keep_vars_(keep_vars) {}

  PrimExpr VisitExpr_(const tvm::tir::VarNode* op) override {
    if (keep_vars_.count(op)) {
      return GetRef<PrimExpr>(op);
    }
    return tvm::tir::make_zero(op->dtype);
  }

 private:
  const std::unordered_set<const tvm::tir::VarNode*>& keep_vars_;
};
// ============================================================================
// Phase 1: collect global -> shared copies and derive the per-buffer DCU
// resource bindings that will later be injected as LetStmts.
// ============================================================================
CollectResult CollectResources(const Stmt& body) {
  class Collector : public StmtExprVisitor {
   public:
    CollectResult result;

   private:
    // Variables that vary per lane / per iteration (threadIdx.* and loop vars).
    std::unordered_set<const tvm::tir::VarNode*> loop_vars_;
    // Path of statements from the root down to the node being visited.
    std::vector<const tvm::tir::StmtNode*> scope_stack_;

    bool IsSharedScope(const Buffer& buf) {
      auto s = buf.scope();
      return s == "shared" || s == "shared.dyn";
    }
    bool IsGlobalScope(const Buffer& buf) {
      auto s = buf.scope();
      return s == "global" || s == "";
    }

    void VisitStmt_(const AttrStmtNode* op) override {
      scope_stack_.push_back(op);
      const auto* iv = op->node.as<tvm::tir::IterVarNode>();
      if (op->attr_key == tvm::tir::attr::thread_extent && iv != nullptr &&
          iv->thread_tag.find("threadIdx") != std::string::npos) {
        // threadIdx.* contributes to the per-lane offset; record it so the
        // resource base expression can later zero it out. blockIdx stays in
        // the base address and is deliberately not recorded.
        tvm::tir::Var thread_var = iv->var;
        loop_vars_.insert(thread_var.get());
        StmtExprVisitor::VisitStmt_(op);
        loop_vars_.erase(thread_var.get());
      } else {
        // BUGFIX: the original only recursed inside the thread_extent branch,
        // which silently skipped the entire subtree of every other AttrStmt
        // (pragmas, storage scopes, blockIdx bindings, ...).
        StmtExprVisitor::VisitStmt_(op);
      }
      scope_stack_.pop_back();
    }

    void VisitStmt_(const SeqStmtNode* op) override {
      scope_stack_.push_back(op);
      StmtExprVisitor::VisitStmt_(op);
      scope_stack_.pop_back();
    }

    void VisitStmt_(const ForNode* op) override {
      scope_stack_.push_back(op);
      loop_vars_.insert(op->loop_var.get());
      StmtExprVisitor::VisitStmt_(op);
      loop_vars_.erase(op->loop_var.get());
      scope_stack_.pop_back();
    }

    void VisitStmt_(const BufferStoreNode* op) final {
      Buffer dst = op->buffer;
      if (IsSharedScope(dst) && op->value.defined()) {
        if (const auto* load = op->value.as<BufferLoadNode>()) {
          Buffer src = load->buffer;
          if (IsGlobalScope(src)) {
            if (result.inject_target == nullptr) {
              // Walk the scope stack from the innermost statement outwards,
              // looking for the innermost thread_extent binding; the node
              // directly below it is where the LetStmts get wrapped.
              for (int i = static_cast<int>(scope_stack_.size()) - 1; i >= 0;
                   --i) {
                if (scope_stack_[i]->IsInstance<AttrStmtNode>()) {
                  auto attr = static_cast<const AttrStmtNode*>(scope_stack_[i]);
                  if (attr->attr_key == tvm::tir::attr::thread_extent) {
                    if (i + 1 < static_cast<int>(scope_stack_.size())) {
                      result.inject_target = scope_stack_[i + 1];
                    }
                    break;
                  }
                }
              }
              // Fallback: first For/SeqStmt on the path.
              if (result.inject_target == nullptr && !scope_stack_.empty()) {
                for (const auto* node : scope_stack_) {
                  if (node->IsInstance<ForNode>() ||
                      node->IsInstance<SeqStmtNode>()) {
                    result.inject_target = node;
                    break;
                  }
                }
              }
              // Last resort: wrap the store itself.
              if (result.inject_target == nullptr) result.inject_target = op;
            }
            tvm::arith::Analyzer analyzer;
            // 1. Record the copy. Keep only thread/loop vars in the source
            //    index: the remaining terms form the per-lane offset.
            VariableKeeper keeper(loop_vars_);
            Array<PrimExpr> for_var_only_indices;
            for (const auto& idx : load->indices) {
              for_var_only_indices.push_back(analyzer.Simplify(keeper(idx)));
            }
            CopyInfo info{dst, src, op->indices, for_var_only_indices,
                          GetRef<Stmt>(op)};
            result.copies.push_back(info);
            // 2. Build a resource binding once per global buffer: zero out
            //    all thread/loop vars to obtain the block-level base address.
            if (result.global_to_res_var.find(src->name) ==
                result.global_to_res_var.end()) {
              Var res_var(src->name + "_dcu_res", DataType::Int(32, 4));
              VariableEliminator eliminator(loop_vars_);
              Array<PrimExpr> args;
              args.push_back(src->data);  // base pointer of the global buffer
              for (const auto& idx : load->indices) {
                args.push_back(analyzer.Simplify(eliminator(idx)));
              }
              PrimExpr val = Call(DataType::Int(32, 4),
                                  Op::Get("tl.make_dcu_resource"), args);
              result.global_to_res_var[src->name] = res_var;
              result.shared_alloc_to_binding[src->name] = {res_var, val};
            }
          }
        }
      }
      StmtExprVisitor::VisitStmt_(op);
    }
  };
  Collector col;
  col(body);
  return col.result;
}
// ============================================================================
// Phase 2: replace the recorded BufferStores with tl.dcu_async_copy calls.
// ============================================================================
class StoreReplacer : public StmtExprMutator {
 public:
  // Rewrites every store listed in `copies` inside `body`; all other
  // statements are left untouched.
  static Stmt Run(Stmt body, const std::vector<CopyInfo>& copies,
                  const std::unordered_map<String, Var>& global_to_var) {
    StoreReplacer replacer(copies, global_to_var);
    return replacer(std::move(body));
  }

 private:
  StoreReplacer(const std::vector<CopyInfo>& copies,
                const std::unordered_map<String, Var>& global_to_var)
      : copies_(copies), global_to_var_(global_to_var) {}

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    // Match by pointer identity against the stores captured during Phase 1;
    // this only works because the collection pass and this mutator see the
    // same underlying AST nodes.
    for (const auto& copy : copies_) {
      if (copy.store_stmt.same_as(GetRef<Stmt>(op))) {
        // Source side: the DCU resource variable bound for the global buffer
        // (e.g. A_dcu_res).
        Var src_res = global_to_var_.at(copy.src_buffer->name);
        // Destination side: the shared buffer's data pointer (e.g. A_shared.data).
        PrimExpr dst_res = copy.dst_buffer->data;
        PrimExpr copy_size = IntImm(DataType::Int(32), 1);
        PrimExpr predicate = Bool(true);
        return Evaluate(
            Call(DataType::Int(32), Op::Get("tl.dcu_async_copy"),
                 {dst_res, Flatten(copy.dst_indices),
                  src_res, Flatten(copy.src_indices),
                  copy_size, predicate}));
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  // Collapse a multi-dimensional index into one expression by summing its
  // components. NOTE(review): a plain sum ignores per-dimension strides, so
  // this presumably assumes indices are already linearized (1-D) by the time
  // this pass runs — confirm against the pass ordering.
  PrimExpr Flatten(const Array<PrimExpr>& idx) {
    if (idx.empty()) return IntImm(DataType::Int(32), 0);
    if (idx.size() == 1) return idx[0];
    PrimExpr r = idx[0];
    for (size_t i = 1; i < idx.size(); ++i) r = r + idx[i];
    return r;
  }

  const std::vector<CopyInfo>& copies_;
  const std::unordered_map<String, Var>& global_to_var_;
};
// ============================================================================
// Phase 3: inject the resource LetStmt bindings at the recorded target node.
// ============================================================================
class ResourceInjector : public tvm::tir::StmtExprMutator {
 public:
  // Wraps every (Var, init-expr) binding in `bindings` as a LetStmt around
  // the statement identified by the raw pointer `target`. A null target or
  // empty binding set makes this a no-op.
  static Stmt Run(Stmt body,
                  const std::unordered_map<String, std::pair<Var, PrimExpr>>& bindings,
                  const tvm::tir::StmtNode* target) {
    if (!target || bindings.empty()) return body;
    ResourceInjector mutator(bindings, target);
    return mutator(std::move(body));
  }

 private:
  ResourceInjector(const std::unordered_map<String, std::pair<Var, PrimExpr>>& bindings,
                   const tvm::tir::StmtNode* target)
      : bindings_(bindings), target_(target) {}

  Stmt VisitStmt(const Stmt& stmt) override {
    // When traversal reaches the AST node marked during collection...
    if (stmt.get() == target_) {
      // ...recurse into it first (the usual TVM mutator convention)...
      Stmt new_stmt = StmtExprMutator::VisitStmt(stmt);
      // ...then wrap all LetStmts around the (possibly rewritten) node.
      // NOTE: unordered_map iteration order is unspecified, so the nesting
      // order of the LetStmts is arbitrary; the bindings are independent of
      // each other, so that is harmless.
      for (const auto& item : bindings_) {
        Var res_var = item.second.first;
        PrimExpr init_expr = item.second.second;
        new_stmt = tvm::tir::LetStmt(res_var, init_expr, new_stmt);
      }
      return new_stmt;  // return the wrapped node
    }
    return StmtExprMutator::VisitStmt(stmt);
  }

  std::unordered_map<String, std::pair<Var, PrimExpr>> bindings_;
  const tvm::tir::StmtNode* target_;
};
// ============================================================================
// Pass entry point
// ============================================================================
PrimFunc LowerSharedGlobalCopy(PrimFunc f) {
  auto* n = f.CopyOnWrite();
  // 1. Collect copy info and locate the target injection point.
  auto res = CollectResources(n->body);
  if (res.copies.empty()) return f;
  // 2. Inject the LetStmts FIRST: at this point n->body is still the original
  //    AST, so the raw inject_target pointer is guaranteed to match a node.
  Stmt injected = ResourceInjector::Run(n->body, res.shared_alloc_to_binding, res.inject_target);
  // 3. Replace the copy statements. `injected` adds new LetStmts around the
  //    original AST, but the underlying BufferStore nodes are unchanged, so
  //    they can still be matched by pointer identity and rewritten.
  Stmt replaced = StoreReplacer::Run(injected, res.copies, res.global_to_res_var);
  // 4. Write the rewritten body back into the PrimFunc.
  n->body = std::move(replaced);
  return GetRef<PrimFunc>(n);
}
namespace transform {
using namespace tir::transform;

// Pass factory: wraps the PrimFunc-level rewrite above as a TVM module pass.
tvm::transform::Pass LowerSharedGlobalCopy() {
  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
    return tl::LowerSharedGlobalCopy(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedGlobalCopy", {});
}

// FFI registration so the pass is reachable from Python as
// tilelang.transform.LowerSharedGlobalCopy().
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LowerSharedGlobalCopy", LowerSharedGlobalCopy);
}

} // namespace transform
} // namespace tl
} // namespace tvm
\ No newline at end of file
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/arith/analyzer.h>
#include <vector>
using namespace tvm::tir;
using tvm::ffi::GetRef;
using tvm::ffi::make_object;
namespace tvm {
namespace tl {
using namespace tir;
using ffi::Array;
using ffi::String;
// Rewrites an unrolled loop whose body is a single tl.dcu_async_copy call
// into one strided call: the per-iteration (i) and per-k strides are pulled
// out of the offset expressions so codegen can emit a single templated
// cp_async_gs invocation.
class AsyncCopySimplifier : public StmtExprMutator {
public:
  static Stmt Run(Stmt stmt) {
    AsyncCopySimplifier mutator;
    return mutator(std::move(stmt));
  }
private:
  arith::Analyzer analyzer_;
  // Loop var of the enclosing "k" loop, if currently inside one.
  Var k_var_;
  // Extent of the "k" loop. NOTE(review): recorded but never read — either
  // use it or drop it.
  PrimExpr k_extent_;
  // Split `expr` into (base, stride) with respect to `var`:
  //   base   = expr[var := 0]
  //   stride = simplify(expr[var := 1] - base)
  // Assumes `expr` is affine in `var`; nonlinear terms silently yield a
  // wrong stride. An undefined `var` yields (expr, 0).
  std::pair<PrimExpr, PrimExpr> ExtractStride(PrimExpr expr, Var var) {
    if (!var.defined()) return {expr, make_zero(expr.dtype())};
    PrimExpr base = tvm::tir::Substitute(expr, {{var, make_zero(var.dtype())}});
    PrimExpr plus_one = tvm::tir::Substitute(expr, {{var, make_const(var.dtype(), 1)}});
    PrimExpr stride = analyzer_.Simplify(plus_one - base);
    return {analyzer_.Simplify(base), stride};
  }
  Stmt VisitStmt_(const ForNode* op) final {
    // 1. Record the k-loop info. NOTE(review): detection relies on the loop
    //    variable being named exactly "k" — fragile; confirm the producing
    //    pipeline guarantees this naming.
    bool is_k = (op->loop_var->name_hint == "k");
    if (is_k) {
      k_var_ = op->loop_var;
      k_extent_ = op->extent; // the k trip count (e.g. 64)
    }
    // 2. Visit children first.
    Stmt body = this->VisitStmt(op->body);
    // 3. Collapse an unrolled loop whose whole body is one dcu_async_copy.
    if (op->kind == ForKind::kUnrolled) {
      if (const EvaluateNode* eval = body.as<EvaluateNode>()) {
        if (const CallNode* call = eval->value.as<CallNode>()) {
          static const Op& dcu_copy_op = Op::Get("tl.dcu_async_copy");
          if (call->op.same_as(dcu_copy_op)) {
            Var i_var = op->loop_var;
            PrimExpr i_extent = op->extent; // the i trip count (e.g. 2)
            // For vectorized offsets, take the Ramp base before splitting.
            auto get_i_info = [&](PrimExpr offset) {
              if (const RampNode* ramp = offset.as<RampNode>()) {
                auto [base, stride] = ExtractStride(ramp->base, i_var);
                return std::make_pair(base, stride);
              }
              return ExtractStride(offset, i_var);
            };
            // Extract the i strides from dst/src offsets.
            auto [base_dst, i_stride_dst] = get_i_info(call->args[1]);
            auto [base_src, i_stride_src] = get_i_info(call->args[3]);
            // Extract the k stride by further decomposing the src base.
            auto [final_src_offset, k_stride_src] = ExtractStride(base_src, k_var_);
            // Rebuild the call with the strided argument list:
            // [dst, dst_off, src, src_off, i_extent, i_stride_dst,
            //  i_stride_src, k_stride_src] (the original size/predicate args
            // are dropped here).
            // NOTE(review): this early return skips the k_var_/k_extent_
            // reset below; harmless only if the unrolled loop is never the
            // "k" loop itself — confirm.
            Array<PrimExpr> new_args = {
              call->args[0],    // dst buffer
              base_dst,         // base dst offset
              call->args[2],    // src resource
              final_src_offset, // base src offset
              i_extent,         // i trip count
              i_stride_dst,     // dst stride per i
              i_stride_src,     // src stride per i
              k_stride_src      // src stride per k
            };
            return Evaluate(Call(call->dtype, call->op, new_args));
          }
        }
      }
    }
    if (is_k) {
      k_var_ = Var();
      k_extent_ = PrimExpr();
    }
    if (body.same_as(op->body)) return GetRef<Stmt>(op);
    auto n = CopyOnWrite(op);
    n->body = std::move(body);
    return Stmt(n);
  }
};
// ============================================================================
// Pass entry point
// ============================================================================

// Rewrites every eligible unrolled tl.dcu_async_copy loop in `f` into its
// single strided form. The function is mutated copy-on-write.
PrimFunc SimplifyDCUAsyncCopy(PrimFunc f) {
  PrimFuncNode* node = f.CopyOnWrite();
  Stmt simplified = AsyncCopySimplifier::Run(std::move(node->body));
  node->body = std::move(simplified);
  return GetRef<PrimFunc>(node);
}
namespace transform {
using namespace tir::transform;

// Pass factory: wraps the PrimFunc-level simplification above.
tvm::transform::Pass SimplifyDCUAsyncCopy() {
  auto pass_func = [=](PrimFunc f, const IRModule &m, tvm::transform::PassContext ctx) {
    return tl::SimplifyDCUAsyncCopy(std::move(f));
  };
  return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tl.SimplifyDCUAsyncCopy", {});
}

// FFI registration so the pass is reachable from Python as
// tilelang.transform.SimplifyDCUAsyncCopy().
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tl.transform.SimplifyDCUAsyncCopy", SimplifyDCUAsyncCopy);
}

} // namespace transform
} // namespace tl
} // namespace tvm
\ No newline at end of file
...@@ -26,6 +26,7 @@ import tvm_ffi ...@@ -26,6 +26,7 @@ import tvm_ffi
from tvm.base import py_str from tvm.base import py_str
import tvm.runtime import tvm.runtime
import tvm.target import tvm.target
from tvm.target import Target
from tvm.contrib import utils from tvm.contrib import utils
...@@ -286,3 +287,12 @@ def find_rocm_path(): ...@@ -286,3 +287,12 @@ def find_rocm_path():
if os.path.exists(os.path.join(rocm_path, "bin/hipcc")): if os.path.exists(os.path.join(rocm_path, "bin/hipcc")):
return rocm_path return rocm_path
raise RuntimeError("Cannot find ROCm path") raise RuntimeError("Cannot find ROCm path")
def is_dcu(target: "Target") -> bool:
    """Return True if ``target`` is a DCU device.

    A DCU target is a HIP/ROCm target whose ``mcpu`` attribute starts with
    ``gfx936``.

    Parameters
    ----------
    target : Target or None
        The target to inspect. ``None`` and targets without an ``mcpu``
        attribute are treated as non-DCU.

    Returns
    -------
    bool
        True only for HIP/ROCm targets with a gfx936* mcpu.
    """
    # Guard against None: callers such as dcu_async_copy_supported may pass
    # an optional target.
    if target is None:
        return False
    if target.kind.name not in ("hip", "rocm"):
        return False
    if "mcpu" in target.attrs:
        mcpu = str(target.attrs["mcpu"])
        return mcpu.startswith("gfx936")
    return False
\ No newline at end of file
...@@ -252,6 +252,7 @@ def lower( ...@@ -252,6 +252,7 @@ def lower(
func = func_or_mod func = func_or_mod
params = extrac_params(func) if not runtime_only else None params = extrac_params(func) if not runtime_only else None
mod = tvm.IRModule({func.attrs["global_symbol"]: func}) mod = tvm.IRModule({func.attrs["global_symbol"]: func})
print(mod)
if isinstance(target, str): if isinstance(target, str):
target = determine_target(target) target = determine_target(target)
...@@ -266,10 +267,11 @@ def lower( ...@@ -266,10 +267,11 @@ def lower(
# Before lowering, do semantic check # Before lowering, do semantic check
PreLowerSemanticCheck(mod) PreLowerSemanticCheck(mod)
print("1111111")
print(mod)
# Phase 1: Lower and legalize the IR # Phase 1: Lower and legalize the IR
mod = LowerAndLegalize(mod, target) mod = LowerAndLegalize(mod, target)
# print(mod)
# Phase 2: Optimize the IR for the target # Phase 2: Optimize the IR for the target
mod = OptimizeForTarget(mod, target) mod = OptimizeForTarget(mod, target)
......
...@@ -4,6 +4,7 @@ from tvm.target import Target ...@@ -4,6 +4,7 @@ from tvm.target import Target
import tilelang import tilelang
from tilelang.transform import PassContext from tilelang.transform import PassContext
from tilelang.contrib.nvcc import have_tma, is_hopper from tilelang.contrib.nvcc import have_tma, is_hopper
from tilelang.contrib.rocm import is_dcu
def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool: def allow_warp_specialized(pass_ctx: PassContext | None = None, target: Target | None = None) -> bool:
...@@ -69,6 +70,10 @@ def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: ...@@ -69,6 +70,10 @@ def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool:
enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False) enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False)
return enabled return enabled
def dcu_async_copy_supported(target: "Target | None" = None) -> bool:
    """Whether ``target`` supports the DCU global->shared async-copy lowering.

    Parameters
    ----------
    target : Target, optional
        The compilation target; ``None`` disables the feature.

    Returns
    -------
    bool
        True only for DCU (HIP/ROCm gfx936*) targets.
    """
    # Explicit None guard: the parameter is optional and is_dcu expects a
    # concrete target.
    return target is not None and is_dcu(target)
def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
if pass_ctx is None: if pass_ctx is None:
...@@ -271,6 +276,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -271,6 +276,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.InjectPTXAsyncCopy()(mod) mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
# Inject ds_read for shared to register memory copy on DCU # Inject ds_read for shared to register memory copy on DCU
mod = tilelang.transform.InjectDSRead()(mod) mod = tilelang.transform.InjectDSRead()(mod)
print("222222222")
print(mod) print(mod)
if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target): if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
...@@ -281,5 +287,13 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -281,5 +287,13 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Transform threadblock to persistent threadblock # Transform threadblock to persistent threadblock
mod = tilelang.transform.PersistThreadblock()(mod) mod = tilelang.transform.PersistThreadblock()(mod)
print("OptimizeForTarget")
print(mod)
if dcu_async_copy_supported(target):
mod = tilelang.transform.LowerSharedGlobalCopy()(mod)
print("OptimizeForTarget2")
print(mod)
mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
print("OptimizeForTarget3")
print(mod)
return mod return mod
...@@ -1900,6 +1900,7 @@ tvm_mmac = _dtype_forward(_tir_op.tvm_mmac) ...@@ -1900,6 +1900,7 @@ tvm_mmac = _dtype_forward(_tir_op.tvm_mmac)
tvm_mfma_store = _dtype_forward(_tir_op.tvm_mfma_store) tvm_mfma_store = _dtype_forward(_tir_op.tvm_mfma_store)
tvm_rdna_wmma = _dtype_forward(_tir_op.tvm_rdna_wmma) tvm_rdna_wmma = _dtype_forward(_tir_op.tvm_rdna_wmma)
tvm_rdna_wmma_store = _dtype_forward(_tir_op.tvm_rdna_wmma_store) tvm_rdna_wmma_store = _dtype_forward(_tir_op.tvm_rdna_wmma_store)
make_dcu_resource = _dtype_forward(_tir_op.make_dcu_resource)
broadcast = Broadcast broadcast = Broadcast
ramp = Ramp ramp = Ramp
...@@ -2222,4 +2223,5 @@ __all__ = [ ...@@ -2222,4 +2223,5 @@ __all__ = [
"CommReducer", "CommReducer",
"Range", "Range",
"vscale", "vscale",
"make_dcu_resource",
] ]
from .gemm_base import GemmBase from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout from tilelang.layout import make_swizzled_layout, make_linear_layout
from tilelang.intrinsics.mmac_macro_generator import ( from tilelang.intrinsics.mmac_macro_generator import (
MatrixCoreIntrinEmitter, MatrixCoreIntrinEmitter,
) )
......
...@@ -552,3 +552,11 @@ def LayoutReducer(): ...@@ -552,3 +552,11 @@ def LayoutReducer():
The transform pass object produced by the FFI backend. The transform pass object produced by the FFI backend.
""" """
return _ffi_api.LayoutReducer() # type: ignore return _ffi_api.LayoutReducer() # type: ignore
def LowerSharedGlobalCopy():
    """Lower global->shared buffer copies into DCU async-copy intrinsics.

    Rewrites eligible global-to-shared ``BufferStore`` statements into
    ``tl.dcu_async_copy`` calls and injects the ``tl.make_dcu_resource``
    bindings they need.

    Returns
    -------
    tvm.transform.Pass
        The transform pass object produced by the FFI backend.
    """
    return _ffi_api.LowerSharedGlobalCopy()  # type: ignore
def SimplifyDCUAsyncCopy():
    """Collapse unrolled ``tl.dcu_async_copy`` loops into strided calls.

    Extracts per-iteration and per-k strides so the codegen can emit a single
    templated async-copy per loop nest.

    Returns
    -------
    tvm.transform.Pass
        The transform pass object produced by the FFI backend.
    """
    return _ffi_api.SimplifyDCUAsyncCopy()  # type: ignore
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment