Commit dd95e41b authored by wangziyang's avatar wangziyang
Browse files

add B_local layout transformation with loop optimization

parent bba13746
......@@ -41,162 +41,161 @@ namespace tl {
using namespace tir;
/*!
* \brief Check if a statement contains B_local stores
* \brief Transformer that handles B_local layout transformation with loop optimization
*
* This transformer handles two cases:
* 1. B_local store with outer loop: halve the loop extent and double the offset
* 2. B_local store without outer loop: just double the offset
*/
bool ContainsBLocalStore(const Stmt& stmt) {
bool found = false;
tir::PreOrderVisit(stmt, [&](const ObjectRef& node) -> bool {
if (found) {
return false;
class BLocalLayoutTransformer : public StmtExprMutator {
public:
explicit BLocalLayoutTransformer(int expand)
: expand_(expand) {}
private:
int expand_;
Stmt VisitStmt_(const ForNode* op) final {
// 只处理 serial 外层循环
if (op->kind != ForKind::kSerial) {
return StmtExprMutator::VisitStmt_(op);
}
if (const auto* store = node.as<BufferStoreNode>()) {
std::string name = store->buffer->name;
if (name.find("B_local") != std::string::npos) {
found = true;
return false;
}
// 判断是否是 B_local 写循环
auto store = op->body.as<BufferStoreNode>();
if (!store) {
return StmtExprMutator::VisitStmt_(op);
}
return true;
});
return found;
}
/*!
* \brief Check if this is a B_local store pattern
*
* Pattern to match:
* B_local[index] = B_shared[index_expr]
*
* Where B_shared[index_expr] is a complex expression involving:
* - thread_binding (threadIdx.x, threadIdx.y, etc.)
* - ki (iteration variable)
* - j and local_id (loop variables)
*/
bool IsBLocalStorePattern(const BufferStoreNode* op,
Var* local_var,
Var* shared_var,
PrimExpr* shared_offset) {
// Check if store is to a local buffer named B_local
std::string buffer_name = op->buffer->name;
if (buffer_name.find("B_local") == std::string::npos) {
return false;
}
if (!IsBLocal(store->buffer)) {
return StmtExprMutator::VisitStmt_(op);
}
// Must have exactly one index: B_local[index]
if (op->indices.size() != 1) {
return false;
}
int64_t old_extent = op->extent.as<IntImmNode>()->value;
ICHECK(old_extent % expand_ == 0)
<< "Loop extent must be divisible by expand factor.";
int64_t new_extent = old_extent / expand_;
// 修改循环范围
For new_for =
For(op->loop_var,
op->min,
Integer(new_extent),
op->kind,
MutateStore(store, op->loop_var));
// Check if value is a BufferLoad from shared memory
const BufferLoadNode* load = op->value.as<BufferLoadNode>();
if (load == nullptr) {
return false;
return new_for;
}
// Check if load is from shared memory
std::string load_buffer_name = load->buffer->name;
std::cout<<"[DEBUG IsBLocalStorePattern] load buffer name: " << load_buffer_name << std::endl;
if (load_buffer_name.find("B_shared") == std::string::npos) {
return false;
bool IsBLocal(const Buffer& buffer) {
std::string name = buffer->name;
return name.find("B_local") != std::string::npos;
}
// Get buffer variables
*local_var = op->buffer->data;
*shared_var = load->buffer->data;
Stmt MutateStore(const BufferStoreNode* store,
const Var& loop_var) {
// Extract the shared memory offset from the load indices
if (!load->indices.empty()) {
*shared_offset = load->indices[0];
} else {
*shared_offset = make_const(DataType::Int(32), 0);
}
Array<PrimExpr> new_indices = store->indices;
return true;
}
PrimExpr new_value = store->value;
class BLocalLayoutTransformer : public StmtExprMutator {
public:
BLocalLayoutTransformer(const IRModule& module) : module_(module) {}
Stmt VisitStmt_(const BufferStoreNode* op) override {
// Check if this is a B_local store pattern BEFORE visiting
// to get the original buffer->data vars (not mutated by VisitStmt_)
Var local_var;
Var shared_var;
PrimExpr shared_offset;
if (!IsBLocalStorePattern(op, &local_var, &shared_var, &shared_offset)) {
// Only visit if not our target pattern
return Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
}
std::cout<<"[DEBUG BLocalLayoutTransformer VisitStmt_] BufferStoreNode buffer name: " << op->buffer->name << std::endl;
// For ds_read_vector: ds_read_vector(dst, src, m, n, offset)
// m, n describe the 2D layout of the shared memory tile
// For B_local (16x32 tile): m=16, n=32
PrimExpr m = make_const(DataType::Int(32), 16);
PrimExpr n = make_const(DataType::Int(32), 32);
PrimExpr offset = shared_offset;
// Create the ds_read call
// ds_read_vector(local_ptr, shared_ptr, m, n, offset)
// Use the vars directly - don't call VisitExpr on them as that creates new Vars
Array<PrimExpr> ds_read_args = {
local_var, // dst: local buffer pointer
op->buffer->data, // src: shared memory pointer
m, // m: rows in shared memory tile
n, // n: columns in shared memory tile
offset // offset: starting offset in shared memory
};
Call ds_read_call = Call(DataType::Handle(), ds_read_vector(), ds_read_args);
// Replace the BufferStore with the ds_read call
return Evaluate(ds_read_call);
}
// 修改切片跨度:
// 原来 j*vec : j*vec+vec
// 改为 j*vec : j*vec*expand + vec
private:
const IRModule& module_;
};
PrimExpr idx = store->indices[0];
//T.Ramp(j * 4, 1, 4) -> Ramp(j*8, 1, 4)
std::cout << idx << std::endl;
/*!
* \brief Inject prefetch for B_local using ds_read_vector
*/
class BLocalPrefetchInjector : public StmtMutator {
public:
BLocalPrefetchInjector(const IRModule& module) : module_(module) {}
Stmt VisitStmt_(const ForNode* op) override {
if (op->kind == ForKind::kParallel || op->kind == ForKind::kSerial ||
op->kind == ForKind::kVectorized) {
Stmt body = VisitStmt(op->body);
// Check if body contains B_local stores
if (ContainsBLocalStore(body)) {
// Inject prefetch before the loop
Stmt prefetch = GenerateBLocalPrefetch();
return SeqStmt({prefetch, For(op->loop_var, op->min, op->extent,
op->kind, body, op->thread_binding,
op->annotations)});
// 解析 j*vec 结构
// 假设结构为 j * vec + const
// 不改 RHS
// PrimExpr value = store->value;
// 修改写入向量宽度
// 原 value 是 Ramp(base=j*4, stride=1, lanes=4)
// 匹配 j * stride
// Ramp(base=j*8, stride=1, lanes=8)
if (const auto* ramp = idx.as<RampNode>()) {
PrimExpr base = ramp->base;
PrimExpr stride = ramp->stride;
int old_lanes = ramp->lanes.as<IntImmNode>()->value;
int new_lanes = old_lanes * expand_;
// 匹配 base = j * stride_val
if (const auto* mul = base.as<MulNode>()) {
if (mul->a.same_as(loop_var)) {
int64_t old_stride =
mul->b.as<IntImmNode>()->value;
int64_t new_stride =
old_stride * expand_;
PrimExpr new_base =
loop_var *
make_const(DataType::Int(32), new_stride);
new_indices.Set(
0,
Ramp(new_base, stride, new_lanes));
}
else if (mul->b.same_as(loop_var)) {
int64_t old_stride =
mul->a.as<IntImmNode>()->value;
int64_t new_stride =
old_stride * expand_;
PrimExpr new_base =
make_const(DataType::Int(32), new_stride) *
loop_var;
new_indices.Set(
0,
Ramp(new_base, stride, new_lanes));
}
}
}
return For(op->loop_var, op->min, op->extent, op->kind, body,
op->thread_binding, op->annotations);
if (auto* load = new_value.as<BufferLoadNode>()) {
// BufferLoad with region access: B_shared[start : end]
// end - start = lanes,需要同步扩展
Array<PrimExpr> value_indices = load->indices;
if (auto* old_ramp = load->indices[0].as<RampNode>()) {
PrimExpr scalar_base = old_ramp->base; // 必须是 scalar
PrimExpr stride = old_ramp->stride;
//RHS 4 lane
int old_lanes = old_ramp->lanes.as<IntImmNode>()->value;
//RHS 8 lane
int new_lanes = old_lanes * expand_;
value_indices.Set(
0,
Ramp(scalar_base, stride, new_lanes)
);
new_value = BufferLoad(load->buffer, value_indices);
}
}
return StmtMutator::VisitStmt_(op);
return BufferStore(store->buffer,
new_value,
new_indices);
}
};
private:
Stmt GenerateBLocalPrefetch() {
// Placeholder: actual implementation depends on the specific
// shared memory layout and thread block configuration
return Evaluate(0);
}
const IRModule& module_;
};
Stmt InjectBLocalLayoutTransformPass(Stmt stmt, int expand) {
return BLocalLayoutTransformer(expand)(std::move(stmt));
}
using namespace tir::transform;
......@@ -209,33 +208,16 @@ tvm::transform::Pass InjectBLocalLayoutTransform() {
}
auto* n = f.CopyOnWrite();
n->body = BLocalLayoutTransformer(m)(n->body);
n->body = InjectBLocalLayoutTransformPass(n->body, 2);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.InjectBLocalLayoutTransform", {});
}
tvm::transform::Pass InjectBLocalLayoutTransformWithPrefetch() {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
// Only apply to DCU targets
if (!IsDCUTarget(m)) {
return f;
}
auto* n = f.CopyOnWrite();
n->body = BLocalPrefetchInjector(m)(n->body);
n->body = BLocalLayoutTransformer(m)(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.InjectBLocalLayoutTransformWithPrefetch", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectBLocalLayoutTransform",
InjectBLocalLayoutTransform);
refl::GlobalDef().def("tl.transform.InjectBLocalLayoutTransformWithPrefetch",
InjectBLocalLayoutTransformWithPrefetch);
}
} // namespace tl
......
......@@ -235,8 +235,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
# Transform B_local layout from shared memory thread-interleaved to local row-major
# mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
mod = tilelang.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
......@@ -295,9 +293,10 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Transform threadblock to persistent threadblock
mod = tilelang.transform.PersistThreadblock()(mod)
if dcu_async_copy_supported(target):
print("--------------support dcu async copy------------------")
mod = tilelang.transform.LowerSharedGlobalCopy()(mod)
mod = tilelang.transform.FixDCUWaitCount()(mod)
#
mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
# mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
print("OptimizeForTarget3")
print(mod)
......
......@@ -386,20 +386,6 @@ def InjectBLocalLayoutTransform():
return _ffi_api.InjectBLocalLayoutTransform() # type: ignore
def InjectBLocalLayoutTransformWithPrefetch():
"""Transform B_local layout with prefetch injection.
This pass is similar to InjectBLocalLayoutTransform but also injects
prefetch operations for B_local before the main transformation.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InjectBLocalLayoutTransformWithPrefetch() # type: ignore
def LowerDeviceStorageAccessInfo():
"""Lower attached storage access information on device.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment