Commit 15599a93 authored by wangziyang's avatar wangziyang
Browse files

print MatrixCore init local size

parent 3852d58b
......@@ -21,6 +21,7 @@
* \brief Replace shared memory BufferLoad with ds_read hardware instructions
* \file inject_ds_read.cc
*/
#include <iostream>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
......@@ -57,138 +58,104 @@ bool IsDCUTarget(const IRModule& module) {
return false;
}
class DSReadInjector : public StmtMutator {
class DSReadInjector : public StmtExprMutator {
public:
Stmt VisitStmt_(const BufferStoreNode* store) final {
/*!
* \brief Visit EvaluateNode to validate explicit ds_read_vector calls.
*
* A ds_read_vector Call is wrapped in an Evaluate node so it can appear as a
* statement; the m, n and offset parameters travel as CallNode arguments.
* This visitor only checks the arity of such calls and then continues the
* default traversal -- it never replaces the existing call.
*
* Fix: dropped the per-statement std::cout debug spam (it fired on every
* Evaluate node in the module and std::endl forced a flush each time).
*/
Stmt VisitStmt_(const EvaluateNode* op) override {
const CallNode* call = op->value.as<CallNode>();
if (call != nullptr && call->op.same_as(ds_read_vector())) {
// NOTE(review): this message documents (dst, src, m, n, offset) but the
// injection site appears to build (src, dst, m, n, offset) -- confirm
// the intended argument order before relying on either.
ICHECK(call->args.size() == 5)
<< "ds_read_vector expects 5 arguments: (dst, src, m, n, offset)";
}
// Continue with default traversal (don't replace the existing call)
return StmtExprMutator::VisitStmt_(op);
}
/*!
* \brief Visit BufferStoreNode to inject ds_read_vector call
* Pattern: local_buffer[...] = shared_buffer[...] (BufferLoad)
* Parameters m, n, offset are passed via a preceding CallNode
*/
Stmt VisitStmt_(const BufferStoreNode* op) override {
std::cout << "[DEBUG VisitStmt_] Visiting BufferStoreNode" << std::endl;
// Check if the store is to a local register (not shared memory)
bool is_local = store->buffer.scope() == "local" ||
store->buffer.scope() == "local.fragment";
bool is_local = op->buffer.scope() == "local" ||
op->buffer.scope() == "local.fragment";
std::cout << "[DEBUG BufferStore] is_local: " << is_local
<< ", scope: " << op->buffer.scope() << std::endl;
if (!is_local) {
return StmtMutator::VisitStmt_(store);
return StmtExprMutator::VisitStmt_(op);
}
// Check if the value is a BufferLoad from shared memory
if (auto* load = store->value.as<BufferLoadNode>()) {
bool is_shared_load = load->buffer.scope() == "shared" ||
load->buffer.scope() == "shared.dyn";
if (!is_shared_load) {
return StmtMutator::VisitStmt_(store);
}
// Skip if indices are vectorized (contain Ramp expressions)
// ds_read is a scalar instruction, cannot handle vectorized indices
if (HasVectorizedIndices(store->indices) || HasVectorizedIndices(load->indices)) {
return StmtMutator::VisitStmt_(store);
}
// Check if the buffer is large enough for ds_read_vector
// ds_read_vector<32, 16> with half_t reads 16 bytes (8 elements)
// For small buffers (less than 16 bytes), skip this transformation
if (store->buffer.defined()) {
const auto& buffer_shape = store->buffer->shape;
if (buffer_shape.size() == 1) {
if (auto* int_shape = buffer_shape[0].as<IntImmNode>()) {
int extent = int_shape->value;
int dtype_bytes = load->dtype.bytes();
// ds_read_vector<32,16> with half_t reads 16 bytes minimum
// For buffers smaller than what ds_read_vector needs, skip
if (extent * dtype_bytes < 16) {
return StmtMutator::VisitStmt_(store);
}
}
}
}
// Analyze the load pattern to determine which ds_read to use
return InjectDSRead(store, load);
const BufferLoadNode* load = op->value.as<BufferLoadNode>();
if (load == nullptr) {
std::cout << "[DEBUG BufferStore] value is not BufferLoad" << std::endl;
return StmtExprMutator::VisitStmt_(op);
}
return StmtMutator::VisitStmt_(store);
}
private:
bool is_shared_load = load->buffer.scope() == "shared" ||
load->buffer.scope() == "shared.dyn";
std::cout << "[DEBUG BufferStore] is_shared_load: " << is_shared_load
<< ", load scope: " << load->buffer.scope() << std::endl;
// PrimExpr VisitExpr_(const CallNode *op) {
// Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
// if (call->op.same_as(builtin::tvm_access_ptr())) {
// return RewriteBufferAccess(call, {1});
// }
// return call;
// }
/*!
* \brief Check if any index expression contains a Ramp (vectorized) expression
*/
bool HasVectorizedIndices(const Array<PrimExpr>& indices) {
for (const auto& idx : indices) {
if (idx.as<RampNode>()) {
return true;
}
}
return false;
}
Stmt InjectDSRead(const BufferStoreNode* store, const BufferLoadNode* load) {
const Buffer& shared_buf = load->buffer;
const Buffer& local_buf = store->buffer;
// Analyze indices to determine the byte offset
// PrimExpr offset = load->indices.size() > 0 ? load->indices[0] : make_zero(DataType::UInt(0));
// Calculate buffer size in bytes
int buffer_bytes = 0;
if (local_buf.defined() && local_buf->shape.size() == 1) {
if (auto* int_shape = local_buf->shape[0].as<IntImmNode>()) {
int num_elements = int_shape->value;
int dtype_bytes = local_buf->dtype.bytes();
buffer_bytes = num_elements * dtype_bytes;
}
if (!is_shared_load) {
return StmtExprMutator::VisitStmt_(op);
}
// Determine which ds_read to use based on buffer size
// ds_read_b64 loads 8 bytes (64 bits) = 1 element for half_t, 2 for float32
// ds_read_m32x16_b16 loads 32 bytes (256 bits)
int dtype_bits = local_buf->dtype.bits();
int m = 16;
// For buffer < 16 bytes, use single ds_read_b64 (M=32, N=1)
// For buffer >= 16 bytes, use double ds_read_b64 (M=32, N=16)
// ds_read_b64 reads 8 bytes per call
int n = (buffer_bytes >= 32) ? 32 : 16;
int offset = 0;
return EmitDSRead(local_buf, shared_buf, m, n, offset);
}
Stmt EmitDSRead(const Buffer& local_buf,
const Buffer& shared_buf, int m, int n, int offset) {
// ds_read_vector takes: (dst, shared_ptr, m, n, offset)
Array<PrimExpr> args = {
local_buf->data, // dst: local buffer data pointer
shared_buf.access_ptr(0, DataType::Handle(), 1, 0), // src: shared buffer data pointer
make_const(DataType::Int(32), m),
make_const(DataType::Int(32), n),
make_const(DataType::Int(32), offset) // byte_offset: offset into shared memory
// Found pattern: local = BufferLoad(shared)
// The m, n, offset parameters should come from a CallNode in the IR
// For now, use default values that will be replaced when CallNode is processed
std::cout << "[DEBUG BufferStore] Injecting ds_read_vector call!" << std::endl;
// Get parameters from the Store's indices or use default values
// In a full implementation, these would come from a preceding CallNode
PrimExpr m = make_const(DataType::Int(32), 16);
PrimExpr n = make_const(DataType::Int(32), 16);
PrimExpr offset = make_const(DataType::Int(32), 0);
// Visit all arguments to transform any nested expressions
Array<PrimExpr> new_args = {
VisitExpr(load->buffer.access_ptr(0, DataType::Handle(), 1, 0)), // src
VisitExpr(op->buffer->data), // dst
VisitExpr(m),
VisitExpr(n),
VisitExpr(offset)
};
Stmt ds_read_stmt = Evaluate(
Call(DataType::Handle(), ds_read_vector(), args));
return ds_read_stmt;
// Create the ds_read call
Call ds_read_call = Call(DataType::Handle(), ds_read_vector(), new_args);
return Evaluate(ds_read_call);
}
};
using namespace tir::transform;
tvm::transform::Pass InjectDSRead() {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
std::cout << "[DEBUG InjectDSRead] Pass is being executed" << std::endl;
// Only apply to DCU targets
if (!IsDCUTarget(m)) {
std::cout << "[DEBUG InjectDSRead] Not a DCU target, skipping" << std::endl;
return f;
}
std::cout << "[DEBUG InjectDSRead] Is DCU target, applying injector" << std::endl;
auto *n = f.CopyOnWrite();
n->body = DSReadInjector()(n->body);
return f;
......
......@@ -181,6 +181,9 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# print("********************")
# print(mod)
# print("********************")
pass_ctx = tilelang.transform.get_pass_context()
# Lower the barrier.arrive into specific initialization slot
mod = tilelang.transform.LowerSharedBarrier()(mod)
......@@ -204,6 +207,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.RewriteWgmmaSync()(mod)
mod = tilelang.transform.InjectFenceProxy()(mod)
else:
mod = tilelang.transform.IfStmtBinding()(mod)
mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tilelang.transform.PipelinePlanning()(mod)
......@@ -214,12 +218,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# so we need to inject a fence proxy before it
mod = tilelang.transform.InjectFenceProxy()(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.Simplify()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tilelang.transform.FlattenBuffer()(mod)
# ConfigIndexBitwidth must be applied after FlattenBuffer
# as it will flatten index computing
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
......
......@@ -38,6 +38,7 @@ def shared_16x16_to_ldmatrix_64x4_layout(ind):
def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
    """Map a (thread_id, local_id) pair to (i, j) inside a 16x16 A-operand tile.

    Presumably thread_id is a lane index in [0, 64) and local_id a register
    slot in [0, 4) -- confirm against the MFMA layout callers.  The row is
    thread_id mod 16; the column packs the lane group and the local slot.

    Fix: removed the committed debug print -- layout helpers run during
    lowering and must not write to stdout.

    Args:
        thread_id: lane index within the wavefront.
        local_id: per-thread element index.

    Returns:
        Tuple (i, j): row and column in the 16x16 tile.
    """
    i = thread_id % 16
    j = (thread_id // 16) * 4 + local_id
    return i, j
......@@ -50,6 +51,7 @@ def shared_16x16_to_local_64x4_layout_A(i, j):
def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
    """Map a (thread_id, local_id) pair to (i, j) inside a 16x16 B-operand tile.

    Transposed companion of the A-operand layout: the lane group and local
    slot select the row, and thread_id mod 16 selects the column.  Presumably
    thread_id is in [0, 64) and local_id in [0, 4) -- confirm against callers.

    Fix: removed the committed debug print -- layout helpers run during
    lowering and must not write to stdout.

    Args:
        thread_id: lane index within the wavefront.
        local_id: per-thread element index.

    Returns:
        Tuple (i, j): row and column in the 16x16 tile.
    """
    i = local_id + (thread_id // 16) * 4
    j = thread_id % 16
    return i, j
......
......@@ -115,7 +115,8 @@ class MatrixCoreIntrinEmitter:
if a_dtype.bits == 32:
self.k_dim = 4
elif a_dtype.bits in {16, 8}:
self.k_dim = 16
# self.k_dim = 16
self.k_dim = 256
else:
raise ValueError(f"Unsupported a_dtype = {a_dtype}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment