"...git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "56adf7e9cc4fcf6592151281a727e96b625bc54f"
Commit 34e0883d authored by Lei Wang, committed by GitHub
Browse files

[Bugfix] Replace thread binding detector in LayoutInference Pass (#31)

* [Refactor] Rename AllocateCollector to ThreadBindingCollector and streamline thread binding logic

* [Refactor] Adjust formatting in ThreadBindingCollector for consistency

* [Refactor] Enhance clang-tidy check to handle cases with no changed C/C++ files

* [Refactor] Remove clang-tidy checks from format script to streamline formatting process
parent fee42951
...@@ -254,62 +254,6 @@ if ! git diff --quiet &>/dev/null; then ...@@ -254,62 +254,6 @@ if ! git diff --quiet &>/dev/null; then
exit 1 exit 1
fi fi
# Detect clang-tidy. When it is missing, flag that in CLANG_TIDY_AVAILABLE so
# the analysis step further down can be skipped. When it is present, hand the
# installed version and the version string read from requirements-dev.txt to
# tool_version_check.
if ! command -v clang-tidy &>/dev/null; then
    echo "clang-tidy not found. Skipping C++ static analysis."
    CLANG_TIDY_AVAILABLE=false
else
    # Third whitespace-separated field of the first line of `--version` output.
    CLANG_TIDY_VERSION=$(clang-tidy --version | head -n 1 | awk '{print $3}')
    tool_version_check \
        "clang-tidy" \
        "$CLANG_TIDY_VERSION" \
        "$(grep clang-tidy requirements-dev.txt | cut -d'=' -f3)"
fi
# Run clang-tidy on the files given as arguments.
# Arguments: the paths to analyze; "$@" is forwarded unchanged.
# The arguments placed after the literal `--` (-std=c++17) are passed to the
# compiler invocation rather than being treated as source files.
clang_tidy() {
clang-tidy "$@" -- -std=c++17
}
# Run clang-tidy on all C/C++ files (*.c, *.cc, *.cpp, *.h, *.hpp) under the
# current directory, excluding ./3rdparty and ./build.
clang_tidy_all() {
    # clang-tidy treats everything after `--` as compiler flags, so the file
    # name has to be placed BEFORE the `--`; plain `xargs cmd -- flags` would
    # append it after the flags instead. `-I{}` puts the name in the right
    # position, runs one file per invocation, and runs nothing when find
    # yields no files. NUL-delimited names keep paths with spaces intact.
    find . -type f \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hpp' \) \
        -not -path "./3rdparty/*" -not -path "./build/*" -print0 \
        | xargs -0 -I{} clang-tidy {} -- -std=c++17
}
# Run clang-tidy on the C/C++ files that were added, copied or modified
# (--diff-filter=ACM) between HEAD and its merge base with main.
# Sets the globals BASE_BRANCH and MERGEBASE as a side effect.
clang_tidy_changed() {
    # Prefer the remote-tracking branch; fall back to the local one.
    if git show-ref --verify --quiet refs/remotes/origin/main; then
        BASE_BRANCH="origin/main"
    else
        BASE_BRANCH="main"
    fi

    MERGEBASE="$(git merge-base "$BASE_BRANCH" HEAD)"

    # `git diff --quiet` exits non-zero when there are differences, so the
    # body only runs when at least one matching file changed.
    if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' &>/dev/null; then
        # The file name must precede `--`, because clang-tidy passes whatever
        # follows it to the compiler. `-I{}` places each name there; `-z`/`-0`
        # keep paths containing spaces or quotes intact.
        git diff --name-only -z --diff-filter=ACM "$MERGEBASE" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' \
            | xargs -0 -I{} clang-tidy {} -- -std=c++17
    fi
}
# clang-tidy stage of the script: choose which files to analyze from the first
# command-line argument, unless clang-tidy was flagged as unavailable.
echo 'tile-lang clang-tidy: Check Start'
if [[ "$CLANG_TIDY_AVAILABLE" == false ]]; then
    echo "clang-tidy is not available. Skipping static analysis."
else
    case "$1" in
        --files)
            # Analyze only the files listed after the flag.
            clang_tidy "${@:2}"
            ;;
        --all)
            # Analyze every eligible C/C++ file.
            clang_tidy_all
            ;;
        *)
            # Default: analyze only the changed C/C++ files.
            clang_tidy_changed
            ;;
    esac
fi
echo 'tile-lang clang-tidy: Done'
if ! git diff --quiet &>/dev/null; then if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.' echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:' echo 'Changes not staged for commit:'
......
...@@ -43,76 +43,21 @@ namespace tl { ...@@ -43,76 +43,21 @@ namespace tl {
using namespace tir; using namespace tir;
using runtime::StorageRank;
using runtime::StorageScope;
/*!
 * \brief Test whether a buffer variable's pointer storage scope is shared
 *        memory tagged ".dyn".
 * \param buffer_var The buffer data variable to inspect.
 * \return true iff the scope rank is kShared and the scope tag is ".dyn".
 */
static bool IsDynamicSharedMemory(Var buffer_var) {
  StorageScope scope =
      runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
  if (scope.rank != runtime::StorageRank::kShared) {
    return false;
  }
  return scope.tag == ".dyn";
}
/*!
 * \brief Test whether a buffer variable's pointer storage scope is shared
 *        memory with an empty tag.
 * \param buffer_var The buffer data variable to inspect.
 * \return true iff the scope rank is kShared and the scope tag is "".
 */
static bool IsStaticSharedMemory(Var buffer_var) {
  StorageScope scope =
      runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
  if (scope.rank != runtime::StorageRank::kShared) {
    return false;
  }
  return scope.tag == "";
}
/*!
 * \brief Test whether a buffer variable's pointer storage scope is local
 *        memory tagged ".fragment".
 * \param buffer_var The buffer data variable to inspect.
 * \return true iff the scope rank is kLocal and the scope tag is ".fragment".
 */
static bool isLocalFragment(Var buffer_var) {
  StorageScope scope =
      runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
  if (scope.rank != runtime::StorageRank::kLocal) {
    return false;
  }
  return scope.tag == ".fragment";
}
/*! /*!
* \brief collect the mapping from the buffer var to its allocate * \brief collect the mapping from the buffer var to its allocate
*/ */
class AllocateCollector : public StmtExprVisitor { class ThreadBindingCollector : public StmtExprVisitor {
public: public:
/*!
 * \brief Record an Allocate statement under the first matching storage
 *        category (dynamic shared, static shared, local fragment), then
 *        continue the default traversal.
 */
void VisitStmt_(const AllocateNode *op) final {
  const auto &data_var = op->buffer_var;
  const auto *key = data_var.get();
  // The checks are mutually exclusive; the first match wins.
  if (IsDynamicSharedMemory(data_var)) {
    dyn_shmem_allocs_[key] = op;
  } else if (IsStaticSharedMemory(data_var)) {
    static_shmem_allocs_[key] = op;
  } else if (isLocalFragment(data_var)) {
    local_fragment_allocs_[key] = op;
  }
  StmtExprVisitor::VisitStmt_(op);
}
/*!
 * \brief Record each buffer in the block's alloc_buffers under the first
 *        matching storage category, mapping its data var to the block, then
 *        continue the default traversal.
 */
void VisitStmt_(const BlockNode *op) final {
  for (auto buffer : op->alloc_buffers) {
    const auto &data_var = buffer->data;
    const auto *key = data_var.get();
    // The checks are mutually exclusive; the first match wins.
    if (IsDynamicSharedMemory(data_var)) {
      dyn_shmem_allocs_[key] = op;
    } else if (IsStaticSharedMemory(data_var)) {
      static_shmem_allocs_[key] = op;
    } else if (isLocalFragment(data_var)) {
      local_fragment_allocs_[key] = op;
    }
  }
  StmtExprVisitor::VisitStmt_(op);
}
// AllocateConst statements are not recorded; this override only forwards to
// the base-class visitor.
void VisitStmt_(const AllocateConstNode *op) final {
StmtExprVisitor::VisitStmt_(op);
}
// SeqStmt needs no special handling; this override only forwards to the
// base-class visitor.
void VisitStmt_(const SeqStmtNode *op) final {
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AttrStmtNode *op) final { void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
thread_binding_[iv->var.get()] = iv;
}
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
} }
// The dynamic mapping from the original buffer var to its allocate // The thread binding map
std::unordered_map<const VarNode *, const Object *> dyn_shmem_allocs_; std::unordered_map<const VarNode *, IterVar> thread_binding_;
// The static mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const Object *> static_shmem_allocs_;
// The local fragment mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const Object *> local_fragment_allocs_;
}; };
using namespace tir; using namespace tir;
...@@ -477,15 +422,10 @@ private: ...@@ -477,15 +422,10 @@ private:
tvm::transform::Pass LayoutInference() { tvm::transform::Pass LayoutInference() {
using namespace tir::transform; using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
AllocateCollector collector; ThreadBindingCollector collector;
collector(f->body); collector(f->body);
// TODO(Lei): This is a hack to avoid the issue of thread partition bool has_thread_binding = collector.thread_binding_.size() > 0;
// for cpu backend. We should remove this after we have a better bool skip_thread_partition = !has_thread_binding;
// solution for thread partition detect.
bool need_thread_partition = (collector.dyn_shmem_allocs_.size() > 1 ||
collector.static_shmem_allocs_.size() > 1 ||
collector.local_fragment_allocs_.size() > 1);
bool skip_thread_partition = !need_thread_partition;
return LayoutInferencer::Substitute(std::move(f), skip_thread_partition); return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment