Commit 34e0883d authored by Lei Wang, committed by GitHub
Browse files

[Bugfix] Replace thread binding detector in LayoutInference Pass (#31)

* [Refactor] Rename AllocateCollector to ThreadBindingCollector and streamline thread binding logic

* [Refactor] Adjust formatting in ThreadBindingCollector for consistency

* [Refactor] Enhance clang-tidy check to handle cases with no changed C/C++ files

* [Refactor] Remove clang-tidy checks from format script to streamline formatting process
parent fee42951
......@@ -254,62 +254,6 @@ if ! git diff --quiet &>/dev/null; then
exit 1
fi
# Probe for clang-tidy on PATH. If it is missing, say so and set
# CLANG_TIDY_AVAILABLE=false; otherwise capture the third field of the first
# line of `clang-tidy --version` and hand it to tool_version_check together
# with the value read from the clang-tidy line of requirements-dev.txt.
if ! command -v clang-tidy &>/dev/null; then
    echo "clang-tidy not found. Skipping C++ static analysis."
    CLANG_TIDY_AVAILABLE=false
else
    CLANG_TIDY_VERSION=$(clang-tidy --version | head -n 1 | awk '{print $3}')
    tool_version_check \
        "clang-tidy" \
        "$CLANG_TIDY_VERSION" \
        "$(grep clang-tidy requirements-dev.txt | cut -d'=' -f3)"
fi
# Invoke clang-tidy on the arguments given to this function, passing
# -std=c++17 as the compiler flag after the `--` separator.
clang_tidy() {
    local -a compiler_args=(-std=c++17)
    clang-tidy "$@" -- "${compiler_args[@]}"
}
# Run clang-tidy on every C/C++ source and header below the current directory,
# excluding anything under ./3rdparty and ./build.
#
# Each path is substituted *before* the `--` separator. clang-tidy treats
# everything after `--` as compiler flags, so letting xargs append the file
# name at the end of the command line would leave clang-tidy without an input
# file. Paths are passed NUL-delimited so names containing whitespace or
# quotes reach clang-tidy intact.
clang_tidy_all() {
    find . -type f \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hpp' \) \
        -not -path "./3rdparty/*" -not -path "./build/*" -print0 \
        | xargs -0 -I{} clang-tidy {} -- -std=c++17
}
# Run clang-tidy on the C/C++ files that were added, copied or modified
# (--diff-filter=ACM) relative to the merge base with main.
#
# origin/main is used as the base when that remote ref exists, otherwise the
# local main branch. Nothing is run when the diff contains no matching files.
#
# Each path is substituted *before* the `--` separator: clang-tidy treats
# everything after `--` as compiler flags, so letting xargs append the file
# name at the end would leave clang-tidy without an input file. `git diff -z`
# plus `xargs -0` keeps paths containing whitespace intact.
clang_tidy_changed() {
    if git show-ref --verify --quiet refs/remotes/origin/main; then
        BASE_BRANCH="origin/main"
    else
        BASE_BRANCH="main"
    fi

    MERGEBASE="$(git merge-base "$BASE_BRANCH" HEAD)"

    if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' &>/dev/null; then
        git diff --name-only -z --diff-filter=ACM "$MERGEBASE" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' \
            | xargs -0 -I{} clang-tidy {} -- -std=c++17
    fi
}
# clang-tidy stage of the script. The first positional argument selects the
# mode:
#   --files <paths...>  analyze only the listed files
#   --all               analyze every eligible C/C++ file
#   (anything else)     analyze only C/C++ files changed relative to main
# The whole stage is skipped when CLANG_TIDY_AVAILABLE is "false".
echo 'tile-lang clang-tidy: Check Start'
if [[ "$CLANG_TIDY_AVAILABLE" == false ]]; then
    echo "clang-tidy is not available. Skipping static analysis."
else
    case "$1" in
        --files)
            clang_tidy "${@:2}"
            ;;
        --all)
            clang_tidy_all
            ;;
        *)
            clang_tidy_changed
            ;;
    esac
fi
echo 'tile-lang clang-tidy: Done'
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
......
......@@ -43,76 +43,21 @@ namespace tl {
using namespace tir;
using runtime::StorageRank;
using runtime::StorageScope;
/*!
 * \brief Test whether a buffer variable's pointer storage scope has rank
 *        kShared and the tag ".dyn".
 * \param buffer_var The buffer variable whose storage scope is inspected.
 * \return true only when both the rank and the tag match.
 */
static bool IsDynamicSharedMemory(Var buffer_var) {
  const StorageScope scope =
      runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
  if (scope.rank != runtime::StorageRank::kShared) {
    return false;
  }
  return scope.tag == ".dyn";
}
/*!
 * \brief Test whether a buffer variable's pointer storage scope has rank
 *        kShared and an empty tag.
 * \param buffer_var The buffer variable whose storage scope is inspected.
 * \return true only when both the rank and the tag match.
 */
static bool IsStaticSharedMemory(Var buffer_var) {
  const StorageScope scope =
      runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
  if (scope.rank != runtime::StorageRank::kShared) {
    return false;
  }
  return scope.tag == "";
}
/*!
 * \brief Test whether a buffer variable's pointer storage scope has rank
 *        kLocal and the tag ".fragment".
 * \param buffer_var The buffer variable whose storage scope is inspected.
 * \return true only when both the rank and the tag match.
 */
static bool isLocalFragment(Var buffer_var) {
  const StorageScope scope =
      runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
  if (scope.rank != runtime::StorageRank::kLocal) {
    return false;
  }
  return scope.tag == ".fragment";
}
/*!
 * \brief Statement visitor that records, while walking a statement tree, the
 *        thread_extent IterVars it encounters and the Allocate / Block nodes
 *        that declare shared-memory or local-fragment buffers.
 */
// NOTE(review): two class headers appear back to back below (the old name
// AllocateCollector and the new name ThreadBindingCollector). This looks like
// an unmarked diff rendering in which removed and added lines are interleaved;
// confirm which members belong to the surviving class before relying on the
// per-member notes in this block.
class AllocateCollector : public StmtExprVisitor {
class ThreadBindingCollector : public StmtExprVisitor {
public:
// Classify the allocated buffer var by storage scope and remember the
// AllocateNode under the matching map, then continue into the node's children.
void VisitStmt_(const AllocateNode *op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
dyn_shmem_allocs_[op->buffer_var.get()] = op;
} else if (IsStaticSharedMemory(op->buffer_var)) {
static_shmem_allocs_[op->buffer_var.get()] = op;
} else if (isLocalFragment(op->buffer_var)) {
local_fragment_allocs_[op->buffer_var.get()] = op;
}
StmtExprVisitor::VisitStmt_(op);
}
// Same classification for every buffer in the block's alloc_buffers; here the
// stored value is the enclosing BlockNode, keyed by the buffer's data var.
void VisitStmt_(const BlockNode *op) final {
for (auto buffer : op->alloc_buffers) {
if (IsDynamicSharedMemory(buffer->data)) {
dyn_shmem_allocs_[buffer->data.get()] = op;
} else if (IsStaticSharedMemory(buffer->data)) {
static_shmem_allocs_[buffer->data.get()] = op;
} else if (isLocalFragment(buffer->data)) {
local_fragment_allocs_[buffer->data.get()] = op;
}
}
StmtExprVisitor::VisitStmt_(op);
}
// Records nothing; only forwards to the base-class traversal.
void VisitStmt_(const AllocateConstNode *op) final {
StmtExprVisitor::VisitStmt_(op);
}
// Records nothing; only forwards to the base-class traversal.
void VisitStmt_(const SeqStmtNode *op) final {
StmtExprVisitor::VisitStmt_(op);
}
// For attributes keyed tir::attr::thread_extent, downcast the attribute node
// to an IterVar and store it under its loop variable; always continue into
// the children afterwards.
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
thread_binding_[iv->var.get()] = iv;
}
StmtExprVisitor::VisitStmt_(op);
}
// Buffer vars for which IsDynamicSharedMemory holds, mapped to the
// AllocateNode or BlockNode that declares them.
std::unordered_map<const VarNode *, const Object *> dyn_shmem_allocs_;
// Buffer vars for which IsStaticSharedMemory holds, mapped to the
// AllocateNode or BlockNode that declares them.
std::unordered_map<const VarNode *, const Object *> static_shmem_allocs_;
// Buffer vars for which isLocalFragment holds, mapped to the AllocateNode or
// BlockNode that declares them.
std::unordered_map<const VarNode *, const Object *> local_fragment_allocs_;
// IterVars found on thread_extent attributes, keyed by the IterVar's variable.
std::unordered_map<const VarNode *, IterVar> thread_binding_;
};
using namespace tir;
......@@ -477,15 +422,10 @@ private:
tvm::transform::Pass LayoutInference() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
AllocateCollector collector;
ThreadBindingCollector collector;
collector(f->body);
// TODO(Lei): This is a hack to avoid the issue of thread partition
// for cpu backend. We should remove this after we have a better
// solution for thread partition detect.
bool need_thread_partition = (collector.dyn_shmem_allocs_.size() > 1 ||
collector.static_shmem_allocs_.size() > 1 ||
collector.local_fragment_allocs_.size() > 1);
bool skip_thread_partition = !need_thread_partition;
bool has_thread_binding = collector.thread_binding_.size() > 0;
bool skip_thread_partition = !has_thread_binding;
return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
};
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment