Commit eff4082d authored by wangziyang
Browse files

fix ds_read pass

parent dd95e41b
...@@ -218,7 +218,7 @@ TIR_DEFINE_TL_BUILTIN(tma_store_wait) ...@@ -218,7 +218,7 @@ TIR_DEFINE_TL_BUILTIN(tma_store_wait)
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ds_read_vector) TIR_DEFINE_TL_BUILTIN(ds_read_vector)
.set_num_inputs(5) .set_num_inputs(3)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -839,18 +839,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -839,18 +839,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
func_name += "_trans"; func_name += "_trans";
print_extern_call_stmt(func_name, 2); print_extern_call_stmt(func_name, 2);
}else if(op->op.same_as(tl::ds_read_vector())){ }else if(op->op.same_as(tl::ds_read_vector())){
//ds_read_b64 %1, %2 offset:%3 // ds_read_m32x16_b16 %0, %1 offset:0
// ds_read_m32x16_b16 %0, %1 offset:%2
printf("[DEBUG VisitExpr_] Branch: ds_read_vector\n"); printf("[DEBUG VisitExpr_] Branch: ds_read_vector\n");
std::string dst = this->PrintExpr(op->args[0]); std::string dst = this->PrintExpr(op->args[0]);
std::string lds_base_ptr = this->PrintExpr(op->args[1]); std::string local_offset = this->PrintExpr(op->args[1]);
std::string m = this->PrintExpr(op->args[2]); std::string lds_offset = this->PrintExpr(op->args[2]);
std::string n = this->PrintExpr(op->args[3]); os << "tl::ds_read_vector("
std::string offset = this->PrintExpr(op->args[4]); << dst << " + " << local_offset
this->PrintIndent(); << ", "
this->stream << "tl::ds_read_vector<" << m << ", " << n <<", " << offset << ">" << lds_offset
<< "(*reinterpret_cast<float4_*>(" << dst << "), " << ")";
<< "reinterpret_cast<uintptr_t>(" << lds_base_ptr << "));\n";
}else if (op->op.same_as(tl::wait_wgmma())) { }else if (op->op.same_as(tl::wait_wgmma())) {
printf("[DEBUG VisitExpr_] Branch: wait_wgmma\n"); printf("[DEBUG VisitExpr_] Branch: wait_wgmma\n");
this->PrintIndent(); this->PrintIndent();
......
...@@ -176,36 +176,44 @@ template <int N> ...@@ -176,36 +176,44 @@ template <int N>
// } // }
// } // }
template <int M, int N, int offset> TL_DEVICE void ds_read_vector(void* dst, uint32_t lds_base_ptr)
TL_DEVICE void ds_read_vector(float4_& dst, uint32_t lds_base_ptr)
{ {
if constexpr (M == 16 && N == 32) asm volatile("ds_read_m32x16_b16 %0, %1 offset:0\n\t"
{
const int offset_in_bytes = offset * sizeof(half_t);
asm volatile("ds_read_m32x16_b16 %0, %1 offset:%2\n\t"
: "+v"(dst) : "+v"(dst)
: "v"(lds_base_ptr), : "v"(lds_base_ptr),
"n"(offset_in_bytes)
: "memory"); : "memory");
}
else if constexpr (M == 32 && N == 16)
{
const int offset_in_bytes0 = offset * sizeof(half_t);
const int offset_in_bytes1 = offset_in_bytes0 + 4096;
float2_& front = *reinterpret_cast<float2_*>(&dst);
float2_& rear = *(reinterpret_cast<float2_*>(&dst) + 1);
asm volatile(
"ds_read_b64 %1, %2 offset:%3\n\t"
"ds_read_b64 %0, %2 offset:%4\n\t"
: "+v"(rear), "+v"(front)
: "v"(lds_base_ptr), "n"(offset_in_bytes0), "n"(offset_in_bytes1)
: "memory"
);
}
} }
// template <int M, int N, int offset>
// TL_DEVICE void ds_read_vector(float4_& dst, uint32_t lds_base_ptr)
// {
// if constexpr (M == 16 && N == 32)
// {
// const int offset_in_bytes = offset * sizeof(half_t);
// asm volatile("ds_read_m32x16_b16 %0, %1 offset:%2\n\t"
// : "+v"(dst)
// : "v"(lds_base_ptr),
// "n"(offset_in_bytes)
// : "memory");
// }
// else if constexpr (M == 32 && N == 16)
// {
// const int offset_in_bytes0 = offset * sizeof(half_t);
// const int offset_in_bytes1 = offset_in_bytes0 + 4096;
// float2_& front = *reinterpret_cast<float2_*>(&dst);
// float2_& rear = *(reinterpret_cast<float2_*>(&dst) + 1);
// asm volatile(
// "ds_read_b64 %1, %2 offset:%3\n\t"
// "ds_read_b64 %0, %2 offset:%4\n\t"
// : "+v"(rear), "+v"(front)
// : "v"(lds_base_ptr), "n"(offset_in_bytes0), "n"(offset_in_bytes1)
// : "memory"
// );
// }
// }
template <int N> template <int N>
TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr, TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
void *global_base_ptr, bool cond) { void *global_base_ptr, bool cond) {
......
...@@ -42,88 +42,63 @@ using namespace tir; ...@@ -42,88 +42,63 @@ using namespace tir;
class DSReadInjector : public StmtExprMutator { class DSReadInjector : public StmtExprMutator {
public: public:
/*! bool IsBLocalBuffer(const Buffer& buffer) {
* \brief Visit EvaluateNode to handle explicit ds_read_vector call std::string name = buffer->name;
* ds_read_vector Call is wrapped in Evaluate to become a statement return name.find("B_local") != std::string::npos;
* Parameters m, n, offset are passed explicitly via CallNode args }
*/
Stmt VisitStmt_(const EvaluateNode* op) override { private:
std::cout << "[DEBUG VisitStmt_] Visiting EvaluateNode" << std::endl;
const CallNode* call = op->value.as<CallNode>(); Stmt VisitStmt_(const BufferStoreNode* op) final {
std::cout << "[DEBUG VisitStmt_] CallNode ptr: " << call << std::endl; Buffer buffer = op->buffer;
if (call != nullptr && call->op.same_as(ds_read_vector())) {
ICHECK(call->args.size() == 5) if (!IsBLocalBuffer(buffer)) {
<< "ds_read_vector expects 5 arguments: (dst, src, m, n, offset)";
// Print args for debugging - these are the actual CallNode args passed in
std::cout << "[DEBUG ds_read_vector] args[0] (dst): " << call->args[0] << std::endl;
std::cout << "[DEBUG ds_read_vector] args[1] (src): " << call->args[1] << std::endl;
std::cout << "[DEBUG ds_read_vector] args[2] (m): " << call->args[2] << std::endl;
std::cout << "[DEBUG ds_read_vector] args[3] (n): " << call->args[3] << std::endl;
std::cout << "[DEBUG ds_read_vector] args[4] (offset): " << call->args[4] << std::endl;
}
// Continue with default traversal (don't replace the existing call)
return StmtExprMutator::VisitStmt_(op); return StmtExprMutator::VisitStmt_(op);
} }
/*! const BufferLoadNode* load = op->value.as<BufferLoadNode>();
* \brief Visit BufferStoreNode to inject ds_read_vector call if (!load) {
* Pattern: local_buffer[...] = shared_buffer[...] (BufferLoad) return StmtExprMutator::VisitStmt_(op);
* Parameters m, n, offset are passed via a CallNode (tl.ds_read_config) }
*/
Stmt VisitStmt_(const BufferStoreNode* op) override {
std::cout << "[DEBUG VisitStmt_] Visiting BufferStoreNode" << std::endl;
// Check if the store is to a local register (not shared memory)
bool is_local = op->buffer.scope() == "local" ||
op->buffer.scope() == "local.fragment";
std::cout << "[DEBUG BufferStore] is_local: " << is_local
<< ", scope: " << op->buffer.scope() << std::endl;
if (!is_local) {
return StmtExprMutator::VisitStmt_(op);
}
// Check if the value is a BufferLoad from shared memory // local offset
const BufferLoadNode* load = op->value.as<BufferLoadNode>(); ICHECK(op->indices.size() == 1);
if (load == nullptr) { PrimExpr local_index = op->indices[0];
std::cout << "[DEBUG BufferStore] value is not BufferLoad" << std::endl; PrimExpr local_offset;
return StmtExprMutator::VisitStmt_(op);
} if (const RampNode* ramp = local_index.as<RampNode>()) {
local_offset = ramp->base;
} else {
local_offset = local_index;
}
bool is_shared_load = load->buffer.scope() == "shared" || // lds offset
load->buffer.scope() == "shared.dyn"; ICHECK(load->indices.size() == 1);
std::cout << "[DEBUG BufferStore] is_shared_load: " << is_shared_load PrimExpr lds_index = load->indices[0];
<< ", load scope: " << load->buffer.scope() << std::endl; PrimExpr lds_offset ;
if (!is_shared_load) { if (const RampNode* ramp = lds_index.as<RampNode>()) {
return StmtExprMutator::VisitStmt_(op); lds_offset = ramp->base;
} } else {
lds_offset = lds_index;
}
// For A_shared, use the actual shared memory base pointer // buffer pointer
PrimExpr m = make_const(DataType::Int(32), 32); PrimExpr buffer_ptr = buffer->data;
PrimExpr n = make_const(DataType::Int(32), 16);
PrimExpr offset; Array<PrimExpr> args = {
// Extract the shared memory offset from the load indices buffer_ptr,
if (!load->indices.empty()) { local_offset,
offset = load->indices[0]; lds_offset
} else { };
offset = make_const(DataType::Int(32), 0);
} Call call = Call(
DataType::Handle(),
ds_read_vector(),
args
);
// Use buffer data vars directly return Evaluate(call);
Array<PrimExpr> new_args = {
load->buffer->data, // src
op->buffer->data, // dst
m,
n,
offset
};
// Create the ds_read call
Call ds_read_call = Call(DataType::Handle(), ds_read_vector(), new_args);
return Evaluate(ds_read_call);
} }
}; };
......
...@@ -297,6 +297,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -297,6 +297,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LowerSharedGlobalCopy()(mod) mod = tilelang.transform.LowerSharedGlobalCopy()(mod)
mod = tilelang.transform.FixDCUWaitCount()(mod) mod = tilelang.transform.FixDCUWaitCount()(mod)
mod = tilelang.transform.InjectBLocalLayoutTransform()(mod) mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
print("InjectBLocalLayoutTransform ............")
print(mod)
mod = tilelang.transform.InjectDSRead()(mod)
print("InjectDSRead ............")
print(mod)
# mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod) # mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
print("OptimizeForTarget3") print("OptimizeForTarget3")
print(mod) print(mod)
......
...@@ -90,7 +90,7 @@ def __ldg(load_or_buf: BufferLoad | tir.Buffer, index: PrimExpr | int | None = N ...@@ -90,7 +90,7 @@ def __ldg(load_or_buf: BufferLoad | tir.Buffer, index: PrimExpr | int | None = N
def ds_read_vector(dst: tir.Var, shared_ptr: tir.Var, m: int, n: int, offset: int) -> Call: def ds_read_vector(dst: tir.Var, local_offset: tir.Var, shared_ptr: tir.Var) -> Call:
""" """
Load from shared memory using ds_read_b64 instruction. Load from shared memory using ds_read_b64 instruction.
...@@ -104,14 +104,12 @@ def ds_read_vector(dst: tir.Var, shared_ptr: tir.Var, m: int, n: int, offset: in ...@@ -104,14 +104,12 @@ def ds_read_vector(dst: tir.Var, shared_ptr: tir.Var, m: int, n: int, offset: in
Load from shared memory using ds_read_m32x16_b16 instruction. Load from shared memory using ds_read_m32x16_b16 instruction.
The ds_read_vector intrinsic has signature: The ds_read_vector intrinsic has signature:
ds_read_vector<M,N,offset>(float4 & dst, int lds_base_ptr) ds_read_vector(dst + local_offset, shared_ptr)
Args: Args:
dst: Destination pointer (register/local buffer). dst: Destination pointer (register/local buffer).
local_offset: Local offset in bytes for the destination register.
lds_base_ptr: Source pointer (shared memory buffer data). lds_base_ptr: Source pointer (shared memory buffer data).
M: Number of columns in the matrix to load (for ds_read_m32x16_b16 / ds_read_b64).
N: Number of rows in the matrix to load (for ds_read_m32x16_b16 / ds_read_b64).
offset: address offset into shared memory.
Returns: Returns:
Call: A TIR call intrinsic for the ds_read_b64 instruction. Call: A TIR call intrinsic for the ds_read_b64 instruction.
...@@ -120,10 +118,8 @@ def ds_read_vector(dst: tir.Var, shared_ptr: tir.Var, m: int, n: int, offset: in ...@@ -120,10 +118,8 @@ def ds_read_vector(dst: tir.Var, shared_ptr: tir.Var, m: int, n: int, offset: in
"handle", "handle",
tir.op.Op.get("tl.ds_read_vector"), tir.op.Op.get("tl.ds_read_vector"),
dst, dst,
shared_ptr, local_offset,
m, shared_ptr
n,
offset
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment