Commit dd95e41b authored by wangziyang

add B_local layout transformation with loop optimization

parent bba13746
@@ -41,163 +41,162 @@ namespace tl {
 using namespace tir;
 
 /*!
- * \brief Check if a statement contains B_local stores
- */
-bool ContainsBLocalStore(const Stmt& stmt) {
-  bool found = false;
-  tir::PreOrderVisit(stmt, [&](const ObjectRef& node) -> bool {
-    if (found) {
-      return false;
-    }
-    if (const auto* store = node.as<BufferStoreNode>()) {
-      std::string name = store->buffer->name;
-      if (name.find("B_local") != std::string::npos) {
-        found = true;
-        return false;
-      }
-    }
-    return true;
-  });
-  return found;
-}
-
-/*!
- * \brief Check if this is a B_local store pattern
- *
- * Pattern to match:
- *   B_local[index] = B_shared[index_expr]
- *
- * where B_shared[index_expr] is a complex expression involving:
- * - thread_binding (threadIdx.x, threadIdx.y, etc.)
- * - ki (iteration variable)
- * - j and local_id (loop variables)
- */
-bool IsBLocalStorePattern(const BufferStoreNode* op, Var* local_var,
-                          Var* shared_var, PrimExpr* shared_offset) {
-  // The store must target a local buffer named B_local.
-  std::string buffer_name = op->buffer->name;
-  if (buffer_name.find("B_local") == std::string::npos) {
-    return false;
-  }
-  // Must have exactly one index: B_local[index].
-  if (op->indices.size() != 1) {
-    return false;
-  }
-  // The stored value must be a BufferLoad from shared memory.
-  const BufferLoadNode* load = op->value.as<BufferLoadNode>();
-  if (load == nullptr) {
-    return false;
-  }
-  std::string load_buffer_name = load->buffer->name;
-  std::cout << "[DEBUG IsBLocalStorePattern] load buffer name: " << load_buffer_name << std::endl;
-  if (load_buffer_name.find("B_shared") == std::string::npos) {
-    return false;
-  }
-  // Extract the buffer variables.
-  *local_var = op->buffer->data;
-  *shared_var = load->buffer->data;
-  // Extract the shared memory offset from the load indices.
-  if (!load->indices.empty()) {
-    *shared_offset = load->indices[0];
-  } else {
-    *shared_offset = make_const(DataType::Int(32), 0);
-  }
-  return true;
-}
-
-class BLocalLayoutTransformer : public StmtExprMutator {
- public:
-  BLocalLayoutTransformer(const IRModule& module) : module_(module) {}
-
-  Stmt VisitStmt_(const BufferStoreNode* op) override {
-    // Match the pattern BEFORE visiting children so the original buffer->data
-    // vars are captured (VisitStmt_ may replace them with new Vars).
-    Var local_var;
-    Var shared_var;
-    PrimExpr shared_offset;
-    if (!IsBLocalStorePattern(op, &local_var, &shared_var, &shared_offset)) {
-      return Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
-    }
-    std::cout << "[DEBUG BLocalLayoutTransformer VisitStmt_] BufferStoreNode buffer name: " << op->buffer->name << std::endl;
-    // For ds_read_vector(dst, src, m, n, offset), m and n describe the 2D
-    // layout of the shared memory tile. For B_local (16x32 tile): m=16, n=32.
-    PrimExpr m = make_const(DataType::Int(32), 16);
-    PrimExpr n = make_const(DataType::Int(32), 32);
-    PrimExpr offset = shared_offset;
-    // Use the vars directly; visiting them would create new Vars.
-    Array<PrimExpr> ds_read_args = {
-        local_var,         // dst: local buffer pointer
-        op->buffer->data,  // src: shared memory pointer
-        m,                 // m: rows in the shared memory tile
-        n,                 // n: columns in the shared memory tile
-        offset             // offset: starting offset in shared memory
-    };
-    Call ds_read_call = Call(DataType::Handle(), ds_read_vector(), ds_read_args);
-    // Replace the BufferStore with the ds_read call.
-    return Evaluate(ds_read_call);
-  }
-
- private:
-  const IRModule& module_;
-};
-
-/*!
- * \brief Inject prefetch for B_local using ds_read_vector
- */
-class BLocalPrefetchInjector : public StmtMutator {
- public:
-  BLocalPrefetchInjector(const IRModule& module) : module_(module) {}
-
-  Stmt VisitStmt_(const ForNode* op) override {
-    if (op->kind == ForKind::kParallel || op->kind == ForKind::kSerial ||
-        op->kind == ForKind::kVectorized) {
-      Stmt body = VisitStmt(op->body);
-      // If the body stores to B_local, inject a prefetch before the loop.
-      if (ContainsBLocalStore(body)) {
-        Stmt prefetch = GenerateBLocalPrefetch();
-        return SeqStmt({prefetch, For(op->loop_var, op->min, op->extent,
-                                      op->kind, body, op->thread_binding,
-                                      op->annotations)});
-      }
-      return For(op->loop_var, op->min, op->extent, op->kind, body,
-                 op->thread_binding, op->annotations);
-    }
-    return StmtMutator::VisitStmt_(op);
-  }
-
- private:
-  Stmt GenerateBLocalPrefetch() {
-    // Placeholder: the actual implementation depends on the specific shared
-    // memory layout and thread block configuration.
-    return Evaluate(0);
-  }
-
-  const IRModule& module_;
-};
+ * \brief Transformer that handles the B_local layout transformation with
+ *        loop optimization.
+ *
+ * This transformer handles two cases:
+ * 1. B_local store with an outer loop: shrink the loop extent by the expand
+ *    factor and widen the copy vectors accordingly (expand = 2 halves the
+ *    extent and doubles the offset).
+ * 2. B_local store without an outer loop: just widen the offset.
+ */
+class BLocalLayoutTransformer : public StmtExprMutator {
+ public:
+  explicit BLocalLayoutTransformer(int expand) : expand_(expand) {}
+
+ private:
+  int expand_;
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    // Only rewrite serial outer loops.
+    if (op->kind != ForKind::kSerial) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    // The loop body must be a single store into B_local.
+    auto store = op->body.as<BufferStoreNode>();
+    if (!store) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    if (!IsBLocal(store->buffer)) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    const auto* extent_imm = op->extent.as<IntImmNode>();
+    ICHECK(extent_imm) << "B_local copy loop extent must be a constant.";
+    int64_t old_extent = extent_imm->value;
+    ICHECK(old_extent % expand_ == 0)
+        << "Loop extent must be divisible by the expand factor.";
+    int64_t new_extent = old_extent / expand_;
+    // Shrink the loop extent and rewrite the store inside it.
+    return For(op->loop_var, op->min, Integer(new_extent), op->kind,
+               MutateStore(store, op->loop_var));
+  }
+
+  bool IsBLocal(const Buffer& buffer) {
+    std::string name = buffer->name;
+    return name.find("B_local") != std::string::npos;
+  }
+
+  Stmt MutateStore(const BufferStoreNode* store, const Var& loop_var) {
+    Array<PrimExpr> new_indices = store->indices;
+    PrimExpr new_value = store->value;
+    // Widen the store slice: the original slice j*vec : j*vec + vec becomes
+    // j*vec*expand : j*vec*expand + vec*expand. For expand = 2 this turns
+    // Ramp(j*4, 1, 4) into Ramp(j*8, 1, 8).
+    PrimExpr idx = store->indices[0];
+    if (const auto* ramp = idx.as<RampNode>()) {
+      PrimExpr base = ramp->base;
+      PrimExpr stride = ramp->stride;
+      int old_lanes = ramp->lanes.as<IntImmNode>()->value;
+      int new_lanes = old_lanes * expand_;
+      // Match base = loop_var * stride_val (in either operand order).
+      if (const auto* mul = base.as<MulNode>()) {
+        if (mul->a.same_as(loop_var)) {
+          int64_t old_stride = mul->b.as<IntImmNode>()->value;
+          int64_t new_stride = old_stride * expand_;
+          PrimExpr new_base =
+              loop_var * make_const(DataType::Int(32), new_stride);
+          new_indices.Set(0, Ramp(new_base, stride, new_lanes));
+        } else if (mul->b.same_as(loop_var)) {
+          int64_t old_stride = mul->a.as<IntImmNode>()->value;
+          int64_t new_stride = old_stride * expand_;
+          PrimExpr new_base =
+              make_const(DataType::Int(32), new_stride) * loop_var;
+          new_indices.Set(0, Ramp(new_base, stride, new_lanes));
+        }
+      }
+    }
+    // Widen the RHS in step: a BufferLoad with a region access
+    // B_shared[start : end] has end - start = lanes, so its lane count must
+    // expand together with the store (e.g. 4 lanes -> 8 lanes).
+    if (auto* load = new_value.as<BufferLoadNode>()) {
+      Array<PrimExpr> value_indices = load->indices;
+      if (auto* old_ramp = load->indices[0].as<RampNode>()) {
+        PrimExpr scalar_base = old_ramp->base;  // base must be a scalar
+        PrimExpr stride = old_ramp->stride;
+        int old_lanes = old_ramp->lanes.as<IntImmNode>()->value;
+        int new_lanes = old_lanes * expand_;
+        value_indices.Set(0, Ramp(scalar_base, stride, new_lanes));
+        new_value = BufferLoad(load->buffer, value_indices);
+      }
+    }
+    return BufferStore(store->buffer, new_value, new_indices);
+  }
+};
+
+Stmt InjectBLocalLayoutTransformPass(Stmt stmt, int expand) {
+  return BLocalLayoutTransformer(expand)(std::move(stmt));
+}
 
 using namespace tir::transform;
 
 tvm::transform::Pass InjectBLocalLayoutTransform() {
@@ -209,33 +208,16 @@ tvm::transform::Pass InjectBLocalLayoutTransform() {
     }
     auto* n = f.CopyOnWrite();
-    n->body = BLocalLayoutTransformer(m)(n->body);
+    n->body = InjectBLocalLayoutTransformPass(n->body, 2);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.InjectBLocalLayoutTransform", {});
 }
 
-tvm::transform::Pass InjectBLocalLayoutTransformWithPrefetch() {
-  auto pass_func = [=](PrimFunc f, const IRModule& m, const PassContext& ctx) {
-    // Only apply to DCU targets.
-    if (!IsDCUTarget(m)) {
-      return f;
-    }
-    auto* n = f.CopyOnWrite();
-    n->body = BLocalPrefetchInjector(m)(n->body);
-    n->body = BLocalLayoutTransformer(m)(n->body);
-    return f;
-  };
-  return CreatePrimFuncPass(pass_func, 0,
-                            "tl.InjectBLocalLayoutTransformWithPrefetch", {});
-}
-
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.InjectBLocalLayoutTransform",
                         InjectBLocalLayoutTransform);
-  refl::GlobalDef().def("tl.transform.InjectBLocalLayoutTransformWithPrefetch",
-                        InjectBLocalLayoutTransformWithPrefetch);
 }
 
 } // namespace tl
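
Note: the loop rewrite above is value-preserving; it only changes how the same B_local elements are grouped into vector operations. A minimal standalone Python sketch (hypothetical sizes vec = 4, extent = 8, expand = 2, matching the Ramp comments in MutateStore) that checks the index math:

    # Sketch only: models Ramp(base, stride, lanes) as an explicit index list.
    VEC, EXTENT, EXPAND = 4, 8, 2  # hypothetical values from the comments above

    def ramp(base, stride, lanes):
        return [base + i * stride for i in range(lanes)]

    # Original loop: for j in range(8), the store covers Ramp(j*4, 1, 4).
    original = [e for j in range(EXTENT) for e in ramp(j * VEC, 1, VEC)]
    # Transformed loop: for j in range(4), the store covers Ramp(j*8, 1, 8).
    transformed = [
        e
        for j in range(EXTENT // EXPAND)
        for e in ramp(j * VEC * EXPAND, 1, VEC * EXPAND)
    ]
    assert original == transformed == list(range(EXTENT * VEC))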
......
@@ -235,8 +235,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.ConfigIndexBitwidth()(mod)
     mod = tir.transform.Simplify()(mod)
     mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
-    # Transform B_local layout from shared memory thread-interleaved to local row-major
-    # mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
     mod = tilelang.transform.StorageRewrite()(mod)
     mod = tir.transform.UnrollLoop()(mod)
     mod = tir.transform.RenormalizeSplitPattern()(mod)
@@ -295,9 +293,10 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     # Transform threadblock to persistent threadblock
     mod = tilelang.transform.PersistThreadblock()(mod)
     if dcu_async_copy_supported(target):
-        print("--------------support dcu async copy------------------")
         mod = tilelang.transform.LowerSharedGlobalCopy()(mod)
         mod = tilelang.transform.FixDCUWaitCount()(mod)
+        # mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
         # mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
     print("OptimizeForTarget3")
     print(mod)
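
For reference, if the commented-out call above were enabled, the DCU branch would read as sketched below. That InjectBLocalLayoutTransform is safe to run after FixDCUWaitCount is an assumption of this sketch, not something the commit establishes:

    if dcu_async_copy_supported(target):
        mod = tilelang.transform.LowerSharedGlobalCopy()(mod)
        mod = tilelang.transform.FixDCUWaitCount()(mod)
        # Assumption: enabling the pass here; this commit keeps it disabled.
        mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)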
......
@@ -386,20 +386,6 @@ def InjectBLocalLayoutTransform():
     return _ffi_api.InjectBLocalLayoutTransform()  # type: ignore
 
-
-def InjectBLocalLayoutTransformWithPrefetch():
-    """Transform B_local layout with prefetch injection.
-
-    This pass is similar to InjectBLocalLayoutTransform but also injects
-    prefetch operations for B_local before the main transformation.
-
-    Returns
-    -------
-    fpass : tvm.transform.Pass
-        The result pass
-    """
-    return _ffi_api.InjectBLocalLayoutTransformWithPrefetch()  # type: ignore
-
-
 def LowerDeviceStorageAccessInfo():
     """Lower attached storage access information on device.
......