Commit 15599a93 authored by wangziyang's avatar wangziyang
Browse files

print MatrixCore init local size

parent 3852d58b
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* \brief Replace shared memory BufferLoad with ds_read hardware instructions * \brief Replace shared memory BufferLoad with ds_read hardware instructions
* \file inject_ds_read.cc * \file inject_ds_read.cc
*/ */
#include <iostream>
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
...@@ -57,138 +58,104 @@ bool IsDCUTarget(const IRModule& module) { ...@@ -57,138 +58,104 @@ bool IsDCUTarget(const IRModule& module) {
return false; return false;
} }
class DSReadInjector : public StmtMutator { class DSReadInjector : public StmtExprMutator {
public: public:
Stmt VisitStmt_(const BufferStoreNode* store) final { /*!
// Check if the store is to a local register (not shared memory) * \brief Visit EvaluateNode to handle explicit ds_read_vector call
bool is_local = store->buffer.scope() == "local" || * ds_read_vector Call is wrapped in Evaluate to become a statement
store->buffer.scope() == "local.fragment"; * Parameters m, n, offset are passed explicitly via CallNode args
*/
if (!is_local) { Stmt VisitStmt_(const EvaluateNode* op) override {
return StmtMutator::VisitStmt_(store); std::cout << "[DEBUG VisitStmt_] Visiting EvaluateNode" << std::endl;
} const CallNode* call = op->value.as<CallNode>();
std::cout << "[DEBUG VisitStmt_] CallNode ptr: " << call << std::endl;
// Check if the value is a BufferLoad from shared memory if (call != nullptr && call->op.same_as(ds_read_vector())) {
if (auto* load = store->value.as<BufferLoadNode>()) { ICHECK(call->args.size() == 5)
bool is_shared_load = load->buffer.scope() == "shared" || << "ds_read_vector expects 5 arguments: (dst, src, m, n, offset)";
load->buffer.scope() == "shared.dyn";
if (!is_shared_load) {
return StmtMutator::VisitStmt_(store);
}
// Skip if indices are vectorized (contain Ramp expressions)
// ds_read is a scalar instruction, cannot handle vectorized indices
if (HasVectorizedIndices(store->indices) || HasVectorizedIndices(load->indices)) {
return StmtMutator::VisitStmt_(store);
}
// Check if the buffer is large enough for ds_read_vector
// ds_read_vector<32, 16> with half_t reads 16 bytes (8 elements)
// For small buffers (less than 16 bytes), skip this transformation
if (store->buffer.defined()) {
const auto& buffer_shape = store->buffer->shape;
if (buffer_shape.size() == 1) {
if (auto* int_shape = buffer_shape[0].as<IntImmNode>()) {
int extent = int_shape->value;
int dtype_bytes = load->dtype.bytes();
// ds_read_vector<32,16> with half_t reads 16 bytes minimum
// For buffers smaller than what ds_read_vector needs, skip
if (extent * dtype_bytes < 16) {
return StmtMutator::VisitStmt_(store);
}
}
}
}
// Analyze the load pattern to determine which ds_read to use // Print args for debugging - these are the actual CallNode args passed in
return InjectDSRead(store, load); std::cout << "[DEBUG ds_read_vector] args[0] (dst): " << call->args[0] << std::endl;
std::cout << "[DEBUG ds_read_vector] args[1] (src): " << call->args[1] << std::endl;
std::cout << "[DEBUG ds_read_vector] args[2] (m): " << call->args[2] << std::endl;
std::cout << "[DEBUG ds_read_vector] args[3] (n): " << call->args[3] << std::endl;
std::cout << "[DEBUG ds_read_vector] args[4] (offset): " << call->args[4] << std::endl;
} }
// Continue with default traversal (don't replace the existing call)
return StmtMutator::VisitStmt_(store); return StmtExprMutator::VisitStmt_(op);
} }
private:
// PrimExpr VisitExpr_(const CallNode *op) {
// Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
// if (call->op.same_as(builtin::tvm_access_ptr())) {
// return RewriteBufferAccess(call, {1});
// }
// return call;
// }
/*! /*!
* \brief Check if any index expression contains a Ramp (vectorized) expression * \brief Visit BufferStoreNode to inject ds_read_vector call
* Pattern: local_buffer[...] = shared_buffer[...] (BufferLoad)
* Parameters m, n, offset are passed via a preceding CallNode
*/ */
bool HasVectorizedIndices(const Array<PrimExpr>& indices) { Stmt VisitStmt_(const BufferStoreNode* op) override {
for (const auto& idx : indices) { std::cout << "[DEBUG VisitStmt_] Visiting BufferStoreNode" << std::endl;
if (idx.as<RampNode>()) {
return true;
}
}
return false;
}
Stmt InjectDSRead(const BufferStoreNode* store, const BufferLoadNode* load) {
const Buffer& shared_buf = load->buffer;
const Buffer& local_buf = store->buffer;
// Analyze indices to determine the byte offset
// PrimExpr offset = load->indices.size() > 0 ? load->indices[0] : make_zero(DataType::UInt(0));
// Calculate buffer size in bytes
int buffer_bytes = 0;
if (local_buf.defined() && local_buf->shape.size() == 1) {
if (auto* int_shape = local_buf->shape[0].as<IntImmNode>()) {
int num_elements = int_shape->value;
int dtype_bytes = local_buf->dtype.bytes();
buffer_bytes = num_elements * dtype_bytes;
}
}
// Determine which ds_read to use based on buffer size // Check if the store is to a local register (not shared memory)
// ds_read_b64 loads 8 bytes (64 bits) = 1 element for half_t, 2 for float32 bool is_local = op->buffer.scope() == "local" ||
// ds_read_m32x16_b16 loads 32 bytes (256 bits) op->buffer.scope() == "local.fragment";
int dtype_bits = local_buf->dtype.bits(); std::cout << "[DEBUG BufferStore] is_local: " << is_local
int m = 16; << ", scope: " << op->buffer.scope() << std::endl;
// For buffer < 16 bytes, use single ds_read_b64 (M=32, N=1) if (!is_local) {
// For buffer >= 16 bytes, use double ds_read_b64 (M=32, N=16) return StmtExprMutator::VisitStmt_(op);
// ds_read_b64 reads 8 bytes per call }
int n = (buffer_bytes >= 32) ? 32 : 16;
int offset = 0;
return EmitDSRead(local_buf, shared_buf, m, n, offset); // Check if the value is a BufferLoad from shared memory
const BufferLoadNode* load = op->value.as<BufferLoadNode>();
if (load == nullptr) {
std::cout << "[DEBUG BufferStore] value is not BufferLoad" << std::endl;
return StmtExprMutator::VisitStmt_(op);
} }
Stmt EmitDSRead(const Buffer& local_buf, bool is_shared_load = load->buffer.scope() == "shared" ||
const Buffer& shared_buf, int m, int n, int offset) { load->buffer.scope() == "shared.dyn";
std::cout << "[DEBUG BufferStore] is_shared_load: " << is_shared_load
<< ", load scope: " << load->buffer.scope() << std::endl;
// ds_read_vector takes: (dst, shared_ptr, m, n, offset) if (!is_shared_load) {
Array<PrimExpr> args = { return StmtExprMutator::VisitStmt_(op);
local_buf->data, // dst: local buffer data pointer }
shared_buf.access_ptr(0, DataType::Handle(), 1, 0), // src: shared buffer data pointer
make_const(DataType::Int(32), m), // Found pattern: local = BufferLoad(shared)
make_const(DataType::Int(32), n), // The m, n, offset parameters should come from a CallNode in the IR
make_const(DataType::Int(32), offset) // byte_offset: offset into shared memory // For now, use default values that will be replaced when CallNode is processed
std::cout << "[DEBUG BufferStore] Injecting ds_read_vector call!" << std::endl;
// Get parameters from the Store's indices or use default values
// In a full implementation, these would come from a preceding CallNode
PrimExpr m = make_const(DataType::Int(32), 16);
PrimExpr n = make_const(DataType::Int(32), 16);
PrimExpr offset = make_const(DataType::Int(32), 0);
// Visit all arguments to transform any nested expressions
Array<PrimExpr> new_args = {
VisitExpr(load->buffer.access_ptr(0, DataType::Handle(), 1, 0)), // src
VisitExpr(op->buffer->data), // dst
VisitExpr(m),
VisitExpr(n),
VisitExpr(offset)
}; };
Stmt ds_read_stmt = Evaluate( // Create the ds_read call
Call(DataType::Handle(), ds_read_vector(), args)); Call ds_read_call = Call(DataType::Handle(), ds_read_vector(), new_args);
return Evaluate(ds_read_call);
return ds_read_stmt;
} }
}; };
using namespace tir::transform; using namespace tir::transform;
tvm::transform::Pass InjectDSRead() { tvm::transform::Pass InjectDSRead() {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
std::cout << "[DEBUG InjectDSRead] Pass is being executed" << std::endl;
// Only apply to DCU targets // Only apply to DCU targets
if (!IsDCUTarget(m)) { if (!IsDCUTarget(m)) {
std::cout << "[DEBUG InjectDSRead] Not a DCU target, skipping" << std::endl;
return f; return f;
} }
std::cout << "[DEBUG InjectDSRead] Is DCU target, applying injector" << std::endl;
auto *n = f.CopyOnWrite(); auto *n = f.CopyOnWrite();
n->body = DSReadInjector()(n->body); n->body = DSReadInjector()(n->body);
return f; return f;
......
...@@ -181,6 +181,9 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: ...@@ -181,6 +181,9 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# print("********************")
# print(mod)
# print("********************")
pass_ctx = tilelang.transform.get_pass_context() pass_ctx = tilelang.transform.get_pass_context()
# Lower the barrier.arrive into specific initialization slot # Lower the barrier.arrive into specific initialization slot
mod = tilelang.transform.LowerSharedBarrier()(mod) mod = tilelang.transform.LowerSharedBarrier()(mod)
...@@ -204,6 +207,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -204,6 +207,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.RewriteWgmmaSync()(mod) mod = tilelang.transform.RewriteWgmmaSync()(mod)
mod = tilelang.transform.InjectFenceProxy()(mod) mod = tilelang.transform.InjectFenceProxy()(mod)
else: else:
mod = tilelang.transform.IfStmtBinding()(mod) mod = tilelang.transform.IfStmtBinding()(mod)
mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod) mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tilelang.transform.PipelinePlanning()(mod) mod = tilelang.transform.PipelinePlanning()(mod)
...@@ -214,12 +218,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -214,12 +218,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# so we need to inject a fence proxy before it # so we need to inject a fence proxy before it
mod = tilelang.transform.InjectFenceProxy()(mod) mod = tilelang.transform.InjectFenceProxy()(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod) mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.Simplify()(mod) mod = tilelang.transform.Simplify()(mod)
mod = tir.transform.NarrowDataType(32)(mod) mod = tir.transform.NarrowDataType(32)(mod)
mod = tilelang.transform.FlattenBuffer()(mod) mod = tilelang.transform.FlattenBuffer()(mod)
# ConfigIndexBitwidth must be applied after FlattenBuffer # ConfigIndexBitwidth must be applied after FlattenBuffer
# as it will flatten index computing # as it will flatten index computing
mod = tilelang.transform.ConfigIndexBitwidth()(mod) mod = tilelang.transform.ConfigIndexBitwidth()(mod)
mod = tir.transform.Simplify()(mod) mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod) mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
......
...@@ -38,6 +38,7 @@ def shared_16x16_to_ldmatrix_64x4_layout(ind): ...@@ -38,6 +38,7 @@ def shared_16x16_to_ldmatrix_64x4_layout(ind):
def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id): def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
print("mfma_layout thread_id_shared_access_64x4_to_16x16_layout_A:", thread_id, local_id)
i = thread_id % 16 i = thread_id % 16
j = (thread_id // 16) * 4 + local_id j = (thread_id // 16) * 4 + local_id
return i, j return i, j
...@@ -50,6 +51,7 @@ def shared_16x16_to_local_64x4_layout_A(i, j): ...@@ -50,6 +51,7 @@ def shared_16x16_to_local_64x4_layout_A(i, j):
def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id): def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
print("mfma_layout thread_id_shared_access_64x4_to_16x16_layout_B:", thread_id, local_id)
i = local_id + (thread_id // 16) * 4 i = local_id + (thread_id // 16) * 4
j = thread_id % 16 j = thread_id % 16
return i, j return i, j
......
...@@ -115,7 +115,8 @@ class MatrixCoreIntrinEmitter: ...@@ -115,7 +115,8 @@ class MatrixCoreIntrinEmitter:
if a_dtype.bits == 32: if a_dtype.bits == 32:
self.k_dim = 4 self.k_dim = 4
elif a_dtype.bits in {16, 8}: elif a_dtype.bits in {16, 8}:
self.k_dim = 16 # self.k_dim = 16
self.k_dim = 256
else: else:
raise ValueError(f"Unsupported a_dtype = {a_dtype}") raise ValueError(f"Unsupported a_dtype = {a_dtype}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment