"tools/cfgs/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "98281d051622bb982bb97ab6df33d859539ea230"
Commit 74e57416 authored by wangziyang's avatar wangziyang
Browse files

add inject_blocal_layout

parent 15599a93
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_blocal_layout_transform.cc
* \brief Transform B_local layout from shared memory thread-interleaved layout
* to local row-major layout using ds_read_vector hardware instructions.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"
#include "inject_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
* \brief Check if a statement contains B_local stores
*/
/*!
 * \brief Check whether a statement subtree contains any store to a buffer
 *        whose name includes the substring "B_local".
 * \param stmt The statement to scan.
 * \return true if at least one matching BufferStore is found.
 */
bool ContainsBLocalStore(const Stmt& stmt) {
  bool has_b_local_store = false;
  tir::PreOrderVisit(stmt, [&has_b_local_store](const ObjectRef& node) -> bool {
    // Once a match is found, stop descending into further subtrees.
    if (has_b_local_store) {
      return false;
    }
    const auto* store = node.as<BufferStoreNode>();
    if (store != nullptr) {
      const std::string buffer_name = store->buffer->name;
      if (buffer_name.find("B_local") != std::string::npos) {
        has_b_local_store = true;
        return false;
      }
    }
    return true;
  });
  return has_b_local_store;
}
/*!
 * \brief Check if a BufferStore matches the B_local-from-B_shared copy pattern.
 *
 * Pattern to match:
 *   B_local[index] = B_shared[index_expr]
 *
 * where index_expr is typically a complex expression involving the thread
 * binding (threadIdx.*), the reduction iteration variable, and inner loop
 * variables.
 *
 * \param op The BufferStore to inspect.
 * \param local_var Output: data var of the local (destination) buffer.
 * \param shared_var Output: data var of the shared (source) buffer.
 * \param shared_offset Output: first (flattened) index of the shared load,
 *        or the constant 0 if the load carries no indices.
 * \return true when the pattern matched and the outputs were populated.
 */
bool IsBLocalStorePattern(const BufferStoreNode* op,
                          Var* local_var,
                          Var* shared_var,
                          PrimExpr* shared_offset) {
  // Destination must be a buffer whose name contains "B_local".
  const std::string buffer_name = op->buffer->name;
  if (buffer_name.find("B_local") == std::string::npos) {
    return false;
  }
  // Must use flat 1-D addressing: B_local[index].
  if (op->indices.size() != 1) {
    return false;
  }
  // The stored value must be a direct BufferLoad (no arithmetic wrapper).
  const BufferLoadNode* load = op->value.as<BufferLoadNode>();
  if (load == nullptr) {
    return false;
  }
  // The load must come from a buffer named like "B_shared".
  // (Leftover std::cout debug print removed; <iostream> was not even
  // included by this translation unit.)
  const std::string load_buffer_name = load->buffer->name;
  if (load_buffer_name.find("B_shared") == std::string::npos) {
    return false;
  }
  // Report the underlying data vars of both buffers.
  *local_var = op->buffer->data;
  *shared_var = load->buffer->data;
  // The shared-memory offset is the first index of the load expression.
  if (!load->indices.empty()) {
    *shared_offset = load->indices[0];
  } else {
    *shared_offset = make_const(DataType::Int(32), 0);
  }
  return true;
}
/*!
 * \brief Rewrites matched `B_local[i] = B_shared[expr]` stores into calls to
 *        the tl.ds_read_vector intrinsic.
 */
class BLocalLayoutTransformer : public StmtExprMutator {
 public:
  explicit BLocalLayoutTransformer(const IRModule& module) : module_(module) {}

  Stmt VisitStmt_(const BufferStoreNode* op) override {
    // Match against the original node BEFORE mutating it so that the
    // buffer->data vars refer to the un-rewritten buffers.
    Var local_var;
    Var shared_var;
    PrimExpr shared_offset;
    if (!IsBLocalStorePattern(op, &local_var, &shared_var, &shared_offset)) {
      // Not our target pattern: mutate children as usual.
      return Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    }
    // ds_read_vector(dst, src, m, n, offset): m x n describes the 2D layout
    // of the shared-memory tile.
    // NOTE(review): 16x32 is hard-coded for the current B_local tile shape;
    // confirm against the kernel configuration before reusing elsewhere.
    PrimExpr m = make_const(DataType::Int(32), 16);
    PrimExpr n = make_const(DataType::Int(32), 32);
    // Use the extracted vars directly; visiting them would create new Vars.
    Array<PrimExpr> ds_read_args = {
        local_var,      // dst: local (register) buffer pointer
        shared_var,     // src: shared-memory pointer.  BUG FIX: the previous
                        // code passed op->buffer->data here, which is the
                        // *destination* B_local buffer, not the shared source.
        m,              // m: rows in the shared-memory tile
        n,              // n: columns in the shared-memory tile
        shared_offset   // offset: starting offset in shared memory
    };
    Call ds_read_call = Call(DataType::Handle(), ds_read_vector(), ds_read_args);
    // Replace the whole BufferStore with the intrinsic call.
    return Evaluate(ds_read_call);
  }

 private:
  // Held by value (IRModule is a cheap ObjectRef handle) to avoid a
  // dangling reference if the transformer outlives the caller's module.
  IRModule module_;
};
/*!
 * \brief Injects a B_local prefetch statement immediately before loops whose
 *        body stores into a B_local buffer (intended to pair with
 *        ds_read_vector lowering).
 *
 * NOTE(review): GenerateBLocalPrefetch() is currently a no-op placeholder;
 * the injected statement is Evaluate(0) until the real prefetch sequence is
 * implemented.
 */
class BLocalPrefetchInjector : public StmtMutator {
 public:
  explicit BLocalPrefetchInjector(const IRModule& module) : module_(module) {}

  Stmt VisitStmt_(const ForNode* op) override {
    // Only loop kinds that can carry a B_local copy body are considered.
    if (op->kind == ForKind::kParallel || op->kind == ForKind::kSerial ||
        op->kind == ForKind::kVectorized) {
      Stmt body = VisitStmt(op->body);
      // Rebuild the loop once; reused for both branches below.
      For rebuilt(op->loop_var, op->min, op->extent, op->kind, body,
                  op->thread_binding, op->annotations);
      if (ContainsBLocalStore(body)) {
        // Emit the prefetch right before the (rewritten) loop.
        return SeqStmt({GenerateBLocalPrefetch(), rebuilt});
      }
      return rebuilt;
    }
    return StmtMutator::VisitStmt_(op);
  }

 private:
  /*!
   * \brief Placeholder prefetch. The real implementation depends on the
   *        shared-memory layout and thread-block configuration, which are
   *        not yet plumbed through to this pass.
   */
  Stmt GenerateBLocalPrefetch() {
    return Evaluate(0);
  }
  // Held by value to avoid a dangling reference to the caller's module.
  IRModule module_;
};
using namespace tir::transform;
/*!
 * \brief Create the pass that rewrites B_local <- B_shared copies into
 *        tl.ds_read_vector intrinsic calls.
 * \return A PrimFunc-level pass; a no-op on non-DCU targets.
 */
tvm::transform::Pass InjectBLocalLayoutTransform() {
  auto pass_func = [](PrimFunc f, const IRModule& m, const PassContext& ctx) {
    // The ds_read_vector intrinsic only exists on AMD DCU targets.
    // (Leftover per-compile std::cout debug print removed -- it fired on
    // every non-DCU compilation.)
    if (!IsDCUTarget(m)) {
      return f;
    }
    auto* n = f.CopyOnWrite();
    n->body = BLocalLayoutTransformer(m)(n->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectBLocalLayoutTransform", {});
}
/*!
 * \brief Like InjectBLocalLayoutTransform, but additionally injects a
 *        B_local prefetch before loops that copy into B_local.
 * \return A PrimFunc-level pass; a no-op on non-DCU targets.
 */
tvm::transform::Pass InjectBLocalLayoutTransformWithPrefetch() {
  auto pass_func = [=](PrimFunc func, const IRModule& mod,
                       const PassContext& pass_ctx) {
    // Non-DCU targets are left untouched.
    if (!IsDCUTarget(mod)) {
      return func;
    }
    auto* node = func.CopyOnWrite();
    // Prefetch injection runs first so the original BufferStore pattern is
    // still visible to the injector before the layout rewrite replaces it.
    node->body = BLocalPrefetchInjector(mod)(node->body);
    node->body = BLocalLayoutTransformer(mod)(node->body);
    return func;
  };
  return CreatePrimFuncPass(pass_func, 0,
                            "tl.InjectBLocalLayoutTransformWithPrefetch", {});
}
// Register both passes with the TVM FFI registry so they are reachable from
// Python as tl.transform.InjectBLocalLayoutTransform[WithPrefetch].
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectBLocalLayoutTransform",
InjectBLocalLayoutTransform);
refl::GlobalDef().def("tl.transform.InjectBLocalLayoutTransformWithPrefetch",
InjectBLocalLayoutTransformWithPrefetch);
}
} // namespace tl
} // namespace tvm
......@@ -33,31 +33,13 @@
#include "../op/builtin.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"
#include "inject_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
 * \brief Check if the target is AMD DCU (gfx936, gfx942, etc.)
 *
 * Scans every function in the module and inspects the "target" attribute of
 * the first PrimFunc that carries one.
 *
 * \param module The module whose functions are inspected.
 * \return true when the first function-level target with an "mcpu" attribute
 *         names an mcpu starting with "gfx936"; false otherwise (including
 *         when no function carries a target or an "mcpu" attribute).
 */
bool IsDCUTarget(const IRModule& module) {
for (auto& p : module->functions) {
if (auto* prim_func = p.second.as<PrimFuncNode>()) {
if (auto opt_target = prim_func->GetAttr<Target>("target")) {
Target target = opt_target.value();
if (target->attrs.count("mcpu")) {
std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
// A DCU target is identified by an mcpu string that starts with "gfx936".
// NOTE(review): despite the doc mentioning gfx942, only the gfx936
// prefix is checked here -- returns false for gfx942.
return mcpu.find("gfx936") == 0;
}
}
}
}
return false;
}
class DSReadInjector : public StmtExprMutator {
public:
/*!
......@@ -87,7 +69,7 @@ class DSReadInjector : public StmtExprMutator {
/*!
* \brief Visit BufferStoreNode to inject ds_read_vector call
* Pattern: local_buffer[...] = shared_buffer[...] (BufferLoad)
* Parameters m, n, offset are passed via a preceding CallNode
* Parameters m, n, offset are passed via a CallNode (tl.ds_read_config)
*/
Stmt VisitStmt_(const BufferStoreNode* op) override {
std::cout << "[DEBUG VisitStmt_] Visiting BufferStoreNode" << std::endl;
......@@ -118,24 +100,18 @@ class DSReadInjector : public StmtExprMutator {
return StmtExprMutator::VisitStmt_(op);
}
// Found pattern: local = BufferLoad(shared)
// The m, n, offset parameters should come from a CallNode in the IR
// For now, use default values that will be replaced when CallNode is processed
std::cout << "[DEBUG BufferStore] Injecting ds_read_vector call!" << std::endl;
// Get parameters from the Store's indices or use default values
// In a full implementation, these would come from a preceding CallNode
PrimExpr m = make_const(DataType::Int(32), 16);
// For A_shared, use the actual shared memory base pointer
PrimExpr m = make_const(DataType::Int(32), 32);
PrimExpr n = make_const(DataType::Int(32), 16);
PrimExpr offset = make_const(DataType::Int(32), 0);
// Visit all arguments to transform any nested expressions
// Use buffer data vars directly
Array<PrimExpr> new_args = {
VisitExpr(load->buffer.access_ptr(0, DataType::Handle(), 1, 0)), // src
VisitExpr(op->buffer->data), // dst
VisitExpr(m),
VisitExpr(n),
VisitExpr(offset)
load->buffer->data, // src
op->buffer->data, // dst
m,
n,
offset
};
// Create the ds_read call
......@@ -169,4 +145,4 @@ TVM_FFI_STATIC_INIT_BLOCK() {
}
} // namespace tl
} // namespace tvm
} // namespace tvm
\ No newline at end of file
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_utils.cc
* \brief Common utilities for injection transforms.
*/
#include "inject_utils.h"
#include "../target/utils.h"
#include <iostream>
namespace tvm {
namespace tl {
using namespace tir;
/*!
 * \brief Check whether the compilation target is an AMD DCU.
 * \param module Unused; retained for API compatibility with callers that
 *        pass the module being transformed.
 * \return true when the target currently in scope is recognized as DCU.
 *
 * NOTE(review): this consults the ambient Target::Current() rather than any
 * target attached to \p module -- confirm the pass is always run inside a
 * target scope, otherwise this silently returns false.
 */
bool IsDCUTarget([[maybe_unused]] const IRModule& module) {
  // `false`: do not raise an error when no target is in scope.
  return TargetIsDCU(Target::Current(false));
}
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_utils.h
* \brief Common utilities for injection transforms.
*/
#ifndef TVM_TL_INJECT_UTILS_H_
#define TVM_TL_INJECT_UTILS_H_

// Include-what-you-use: IRModule is referenced directly in a declaration
// below, so depend on its header explicitly instead of transitively.
#include <tvm/ir/module.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

namespace tvm {
namespace tl {

/*!
 * \brief Check if the compilation target is an AMD DCU (gfx936, gfx942, etc.).
 * \param module The module being transformed (currently unused by the
 *        implementation, which consults the ambient Target scope).
 * \return true when the target is recognized as a DCU.
 */
bool IsDCUTarget(const IRModule& module);

}  // namespace tl
}  // namespace tvm
#endif  // TVM_TL_INJECT_UTILS_H_
......@@ -181,9 +181,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# print("********************")
# print(mod)
# print("********************")
pass_ctx = tilelang.transform.get_pass_context()
# Lower the barrier.arrive into specific initialization slot
mod = tilelang.transform.LowerSharedBarrier()(mod)
......@@ -229,6 +227,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
# Transform B_local layout from shared memory thread-interleaved to local row-major
mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
mod = tilelang.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
......@@ -265,11 +265,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
# Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
# Inject ds_read for shared to register memory copy on DCU
mod = tilelang.transform.InjectDSRead()(mod)
print(mod)
if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
mod = tilelang.transform.MakePackedAPI()(mod)
......
......@@ -115,8 +115,7 @@ class MatrixCoreIntrinEmitter:
if a_dtype.bits == 32:
self.k_dim = 4
elif a_dtype.bits in {16, 8}:
# self.k_dim = 16
self.k_dim = 256
self.k_dim = 16
else:
raise ValueError(f"Unsupported a_dtype = {a_dtype}")
......
......@@ -360,6 +360,46 @@ def InjectDSRead():
return _ffi_api.InjectDSRead() # type: ignore
def InjectBLocalLayoutTransform():
    """Transform B_local layout from shared memory thread-interleaved to local row-major.

    This pass specifically handles the B_local buffer layout transformation in GEMM kernels
    for AMD DCU (gfx936, gfx942, etc.). It converts complex indexed BufferStore patterns
    from shared memory into vectorized ds_read_vector hardware instructions.

    B Layout Transformation:
    - Shared Memory Layout (per thread in warp, 16 elements):
        Thread 0: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15]
        Thread 1: [16,17,18,... ,31 ]
        ...
    - Local Register Layout (16x32, row-major):
        Row 0: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15, 0, 1,...]
        Row 1: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15, 0, 1,...]
        ...

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    # NOTE: the previous version printed a debug message here, which fired
    # every time the pass object was constructed (i.e. on every pipeline
    # build). Removed to keep pass construction side-effect free.
    return _ffi_api.InjectBLocalLayoutTransform()  # type: ignore
def InjectBLocalLayoutTransformWithPrefetch():
    """Transform B_local layout with prefetch injection.

    This pass is similar to InjectBLocalLayoutTransform but also injects
    prefetch operations for B_local before the main transformation.
    Like that pass, it only applies on AMD DCU targets.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.InjectBLocalLayoutTransformWithPrefetch()  # type: ignore
def LowerDeviceStorageAccessInfo():
"""Lower attached storage access information on device.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment