Unverified Commit cdc5d8d3 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Lint] Introduce clang-tidy into format.sh (#777)

* [Refactor] Update Clang-Tidy Checks and Improve Code Consistency

- Enhanced .clang-tidy configuration by adding specific checks for better bug detection and performance optimization.
- Refactored function signatures across multiple files to use `const` references for parameters, improving performance and code clarity.
- Updated various methods to ensure consistent handling of parameters, particularly in `AddPredicate`, `Substitute`, and `PlanLoopPartition` functions.
- Improved readability by replacing size checks with `empty()` method calls in several locations, ensuring clearer intent in the code.
- General code cleanup and adherence to best practices for better maintainability.

* [Refactor] Enhance Code Consistency and Clang-Tidy Configuration

- Updated .clang-tidy configuration to include additional checks for improved code quality and performance.
- Refactored function signatures across multiple files to use `const` references, enhancing performance and clarity.
- Replaced size checks with `empty()` method calls in various locations for clearer intent.
- Improved handling of parameters in several functions, ensuring consistent usage of `std::move` where applicable.
- General code cleanup to adhere to best practices and improve maintainability.

* [Refactor] Integrate Clang-Tidy Checks and Enhance Code Consistency

- Added clang-tidy checks to the format script for improved code quality assurance.
- Refactored function signatures across multiple files to consistently use `const` references, enhancing performance and clarity.
- Updated the requirements-lint.txt file to include clang-tidy as a dependency.
- General code cleanup to adhere to best practices and improve maintainability.

* [CI] Update AMD CI Workflow to Include Build Directory Creation

- Added steps to create a build directory and configure CMake with ROCm support during the format check process.
- Ensured cleanup of the build directory after the format check to maintain a clean workspace.

* [Refactor] Remove Unused Member Variables in AtomicAddNode and CopyNode

- Removed the `args_` member variable from both `AtomicAddNode` and `CopyNode` classes to streamline the code and eliminate unnecessary data members.
- This change enhances code clarity and maintainability by focusing on relevant attributes for each class.

* [Refactor] Update Clang-Tidy Integration and Code Improvements

- Modified the format script to include the `-fix` option in the clang-tidy command for automatic code fixes.
- Refactored the `AtomicAddVectorizePlanner` class to improve variable handling and consistency, including changes to member variable types and function signatures.
- Enhanced code clarity by removing unnecessary `std::move` calls and ensuring consistent usage of types across the class.
- General code cleanup to adhere to best practices and improve maintainability.

* [Refactor] Improve Parameter Handling and Consistency in AtomicAddVectorize

- Updated function signatures in `AtomicAddVectorizePlanResult` and `AtomicAddVectorizeRewriter` to use `const` references and `std::move` for better performance and clarity.
- Enhanced the `UpdateVectorSize` method to accept `const Array<PrimExpr>&` for improved efficiency.
- General code cleanup to maintain consistency and adhere to best practices.

* [CI] Add Git Submodule Initialization to CI Workflow

- Included a step to initialize and update git submodules recursively in the CI workflow.
- This change ensures that all necessary submodules are available during the format check process, improving build reliability.

* [CI] Add Git Submodule Update Step to Format Check

- Included a command to initialize and update git submodules recursively in the CI workflow during the format check process.
- This enhancement ensures that all required submodules are available, contributing to improved build reliability.

* [Refactor] Update Function Signatures in AtomicAddVectorize

- Modified the `VectorizeAtomicAdd` function signature to use `const` references for `thread_var` and `thread_bounds`, enhancing performance and code clarity.
- This change aligns with previous refactoring efforts to improve parameter handling and consistency across the codebase.
parent 471cc7f8
Checks: > Checks: >
# 1. Retained categories: easier to find bugs/performance issues
clang-analyzer-*, clang-analyzer-*,
cppcoreguidelines-*, cppcoreguidelines-pro-type-static-cast-downcast,
modernize-*, cppcoreguidelines-pro-type-member-init,
cppcoreguidelines-pro-bounds-array-to-pointer-decay,
cppcoreguidelines-pro-bounds-pointer-arithmetic,
cppcoreguidelines-slicing,
cppcoreguidelines-narrowing-conversions,
performance-*, performance-*,
readability-*,
-readability-identifier-length # 2. Readability: only keep useful rules
readability-braces-around-statements,
readability-container-size-empty,
readability-delete-null-pointer,
readability-redundant-member-init,
readability-redundant-smartptr-get,
readability-redundant-string-cstr,
# 3. Disable all intrusive/style-breaking rules
-readability-identifier-length,
-readability-avoid-const-params-in-decls,
-readability-else-after-return,
-cppcoreguidelines-avoid-magic-numbers,
-modernize-use-trailing-return-type,
-modernize-use-nodiscard,
-modernize-use-auto,
-modernize-pass-by-value,
-modernize-return-braced-init-list,
-modernize-use-default-member-init,
-modernize-loop-convert,
-modernize-concat-nested-namespaces,
-llvm-include-order,
-bugprone-unused-return-value,
-clang-diagnostic-unused-result,
-cppcoreguidelines-special-member-functions,
-performance-noexcept-move-constructor,
-cppcoreguidelines-narrowing-conversions,
-clang-diagnostic-error,
-cppcoreguidelines-pro-type-member-init,
-clang-analyzer-optin.cplusplus.UninitializedObject,
-cppcoreguidelines-pro-type-static-cast-downcast,
-performance-unnecessary-value-param,
WarningsAsErrors: '*' WarningsAsErrors: '*'
HeaderFilterRegex: '^(?!.*(3rdparty|build)).*$' HeaderFilterRegex: '^(?!.*(3rdparty|build)).*$'
\ No newline at end of file
...@@ -48,6 +48,9 @@ jobs: ...@@ -48,6 +48,9 @@ jobs:
- name: Run format check - name: Run format check
run: | run: |
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
git submodule update --init --recursive
mkdir -p build
cd build; cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DUSE_ROCM=ON; cd ..
if ! output=$(./format.sh 2>&1); then if ! output=$(./format.sh 2>&1); then
echo "------------------------------------" echo "------------------------------------"
echo "message:" echo "message:"
...@@ -56,6 +59,7 @@ jobs: ...@@ -56,6 +59,7 @@ jobs:
echo "------------------------------------" echo "------------------------------------"
exit 1 exit 1
fi fi
rm -rf build
- name: Commit and Push Changes - name: Commit and Push Changes
uses: stefanzweifel/git-auto-commit-action@v5 uses: stefanzweifel/git-auto-commit-action@v5
......
...@@ -47,6 +47,10 @@ jobs: ...@@ -47,6 +47,10 @@ jobs:
- name: Run format check - name: Run format check
run: | run: |
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
git submodule update --init --recursive
mkdir -p build
# run cmake to create the build directory with compile_commands.json
cd build; cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DUSE_CUDA=ON; cd ..
if ! output=$(./format.sh 2>&1); then if ! output=$(./format.sh 2>&1); then
echo "------------------------------------" echo "------------------------------------"
echo "message:" echo "message:"
...@@ -55,6 +59,7 @@ jobs: ...@@ -55,6 +59,7 @@ jobs:
echo "------------------------------------" echo "------------------------------------"
exit 1 exit 1
fi fi
rm -rf build
- name: Commit and Push Changes - name: Commit and Push Changes
uses: stefanzweifel/git-auto-commit-action@v5 uses: stefanzweifel/git-auto-commit-action@v5
......
...@@ -249,6 +249,73 @@ else ...@@ -249,6 +249,73 @@ else
fi fi
echo 'tile-lang clang-format: Done' echo 'tile-lang clang-format: Done'
echo 'tile-lang clang-tidy: Check Start'
# If clang-tidy is available, run it; otherwise, skip
if command -v run-clang-tidy &>/dev/null; then
# Check if clang-tidy is available
if ! command -v clang-tidy &>/dev/null; then
echo "clang-tidy not found. Skipping clang-tidy checks."
else
# Get clang-tidy version
CLANG_TIDY_VERSION=$(clang-tidy --version | head -n1 | awk '{print $4}')
echo "Using clang-tidy version: $CLANG_TIDY_VERSION"
# Check if build directory exists
if [ ! -d "build" ]; then
echo "Build directory not found. Skipping clang-tidy checks."
else
# Run clang-tidy on specified files
clang_tidy_files() {
run-clang-tidy -j 64 "$@" -p build
}
# Run clang-tidy on all C/C++ source files
clang_tidy_all() {
run-clang-tidy -j 64 src/*.cc -p build
}
# Run clang-tidy on changed C/C++ files relative to main
clang_tidy_changed() {
if git show-ref --verify --quiet refs/remotes/origin/main; then
BASE_BRANCH="origin/main"
else
BASE_BRANCH="main"
fi
MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)"
# Get changed C/C++ files
CHANGED_FILES=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' 2>/dev/null || true)
if [ -n "$CHANGED_FILES" ]; then
echo "Running clang-tidy on changed files:"
echo "$CHANGED_FILES"
# Convert newline-separated files to space-separated and run clang-tidy once
CHANGED_FILES_SPACE=$(echo "$CHANGED_FILES" | tr '\n' ' ')
run-clang-tidy -j 64 $CHANGED_FILES_SPACE -p build -fix
else
echo "No C/C++ files changed. Skipping clang-tidy."
fi
}
if [[ "$1" == '--files' ]]; then
# If --files is given, run clang-tidy only on the provided files
clang_tidy_files "${@:2}"
elif [[ "$1" == '--all' ]]; then
# If --all is given, run clang-tidy on all source files
clang_tidy_all
else
# Otherwise, run clang-tidy only on changed C/C++ files
clang_tidy_changed
fi
fi
fi
else
echo "run-clang-tidy not found. Skipping clang-tidy checks."
echo "To install clang-tidy tools, you may need to install clang-tidy and run-clang-tidy."
fi
echo 'tile-lang clang-tidy: Done'
# Check if there are any uncommitted changes after all formatting steps. # Check if there are any uncommitted changes after all formatting steps.
# If there are, ask the user to review and stage them. # If there are, ask the user to review and stage them.
if ! git diff --quiet &>/dev/null; then if ! git diff --quiet &>/dev/null; then
......
...@@ -5,3 +5,4 @@ tomli==2.0.1 ...@@ -5,3 +5,4 @@ tomli==2.0.1
ruff==0.6.5 ruff==0.6.5
codespell==2.3.0 codespell==2.3.0
clang-format==15.0.7 clang-format==15.0.7
clang-tidy==18.1.8
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
#include <tvm/script/ir_builder/tir/ir.h> #include <tvm/script/ir_builder/tir/ir.h>
#include <utility>
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -19,8 +21,8 @@ using namespace script::ir_builder::tir; ...@@ -19,8 +21,8 @@ using namespace script::ir_builder::tir;
static Var CreateEnvThread(String name, String thread_tag, DataType dtype) { static Var CreateEnvThread(String name, String thread_tag, DataType dtype) {
using namespace tvm::tir; using namespace tvm::tir;
using namespace tvm::script::ir_builder; using namespace tvm::script::ir_builder;
IterVar iter_var(Range{nullptr}, Var(name, dtype), IterVar iter_var(Range{nullptr}, Var(std::move(name), dtype),
tvm::tir::IterVarType::kThreadIndex, thread_tag); tvm::tir::IterVarType::kThreadIndex, std::move(thread_tag));
Var var = iter_var->var; Var var = iter_var->var;
if (Optional<PrimFuncFrame> opt_frame = if (Optional<PrimFuncFrame> opt_frame =
IRBuilder::Current()->FindFrame<PrimFuncFrame>()) { IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
...@@ -31,15 +33,15 @@ static Var CreateEnvThread(String name, String thread_tag, DataType dtype) { ...@@ -31,15 +33,15 @@ static Var CreateEnvThread(String name, String thread_tag, DataType dtype) {
return var; return var;
} }
static ForFrame MakeIterVarFrame(std::string name, PrimExpr dom) { static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
using namespace tvm::tir; using namespace tvm::tir;
Var var = Var(name, dom->dtype); Var var = Var(name, dom->dtype);
// Create a frame that represents a loop over the given domain. // Create a frame that represents a loop over the given domain.
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
n->vars.push_back(var); n->vars.push_back(var);
n->doms.push_back(Range(0, dom)); n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](Array<Var> vars, Array<Range> doms, n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
Stmt body) -> Stmt { const Stmt &body) -> Stmt {
ICHECK_EQ(vars.size(), 1); ICHECK_EQ(vars.size(), 1);
ICHECK_EQ(doms.size(), 1); ICHECK_EQ(doms.size(), 1);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body); return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body);
...@@ -47,8 +49,8 @@ static ForFrame MakeIterVarFrame(std::string name, PrimExpr dom) { ...@@ -47,8 +49,8 @@ static ForFrame MakeIterVarFrame(std::string name, PrimExpr dom) {
return ForFrame(n); return ForFrame(n);
} }
ForFrame ParallelFor(Array<PrimExpr> extents, ForFrame ParallelFor(const Array<PrimExpr> &extents,
Map<String, ObjectRef> annotations) { const Map<String, ObjectRef> &annotations) {
using namespace tvm::tir; using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
n->vars.reserve(extents.size()); n->vars.reserve(extents.size());
...@@ -58,32 +60,33 @@ ForFrame ParallelFor(Array<PrimExpr> extents, ...@@ -58,32 +60,33 @@ ForFrame ParallelFor(Array<PrimExpr> extents,
n->vars.push_back(Var("v", extent.dtype())); n->vars.push_back(Var("v", extent.dtype()));
n->doms.push_back(Range(make_const(dtype, 0), extent)); n->doms.push_back(Range(make_const(dtype, 0), extent));
} }
n->f_make_for_loop = [annotations](Array<Var> vars, Array<Range> doms, n->f_make_for_loop = [annotations](const Array<Var> &vars,
const Array<Range> &doms,
Stmt body) -> Stmt { Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size()); ICHECK_EQ(vars.size(), doms.size());
int n = vars.size(); int n = vars.size();
for (int i = n - 1; i >= 0; --i) { for (int i = n - 1; i >= 0; --i) {
Range dom = doms[i]; Range dom = doms[i];
Var var = vars[i]; Var var = vars[i];
body = body = For(var, dom->min, dom->extent, ForKind::kParallel, body,
For(var, dom->min, dom->extent, ForKind::kParallel, std::move(body), /*thread_binding=*/std::nullopt, /*annotations=*/annotations);
/*thread_binding=*/std::nullopt, /*annotations=*/annotations);
} }
return body; return body;
}; };
return ForFrame(n); return ForFrame(n);
} }
ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages, ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
Array<PrimExpr> order, Array<PrimExpr> stages, const Array<PrimExpr> &order,
Array<Array<PrimExpr>> sync, const Array<PrimExpr> &stages,
Array<Array<PrimExpr>> groups) { const Array<Array<PrimExpr>> &sync,
const Array<Array<PrimExpr>> &groups) {
using namespace tvm::tir; using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
DataType dtype = stop.dtype(); DataType dtype = stop.dtype();
n->vars.push_back(Var("v", dtype)); n->vars.push_back(Var("v", dtype));
n->doms.push_back(Range(start, stop)); n->doms.push_back(Range(std::move(start), stop));
n->f_make_for_loop = [=](Array<Var> vars, Array<Range> doms, n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms,
Stmt body) -> Stmt { Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size()); ICHECK_EQ(vars.size(), doms.size());
int n = vars.size(); int n = vars.size();
...@@ -91,26 +94,25 @@ ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages, ...@@ -91,26 +94,25 @@ ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages,
Map<String, ObjectRef> anno; Map<String, ObjectRef> anno;
if (num_stages > 0) if (num_stages > 0)
anno.Set("num_stages", PrimExpr(num_stages)); anno.Set("num_stages", PrimExpr(num_stages));
if (order.size() > 0) if (!order.empty())
anno.Set("tl_pipeline_order", order); anno.Set("tl_pipeline_order", order);
if (stages.size() > 0) if (!stages.empty())
anno.Set("tl_pipeline_stage", stages); anno.Set("tl_pipeline_stage", stages);
if (sync.size() > 0) if (!sync.empty())
anno.Set("tl_pipeline_sync", sync); anno.Set("tl_pipeline_sync", sync);
if (groups.size() > 0) if (!groups.empty())
anno.Set("tl_pipeline_group", groups); anno.Set("tl_pipeline_group", groups);
body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
std::move(body),
/*thread_binding=*/std::nullopt, /*annotations=*/anno); /*thread_binding=*/std::nullopt, /*annotations=*/anno);
return body; return body;
}; };
return ForFrame(n); return ForFrame(n);
} }
ForFrame PersistentFor(Array<PrimExpr> domain, PrimExpr wave_size, ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
PrimExpr index, PrimExpr group_size) { const PrimExpr &index, PrimExpr group_size) {
using namespace tvm::tir; using namespace tvm::tir;
ICHECK(domain.size() > 0); ICHECK(!domain.empty());
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
n->vars.reserve(domain.size()); n->vars.reserve(domain.size());
n->doms.reserve(domain.size()); n->doms.reserve(domain.size());
...@@ -139,8 +141,8 @@ ForFrame PersistentFor(Array<PrimExpr> domain, PrimExpr wave_size, ...@@ -139,8 +141,8 @@ ForFrame PersistentFor(Array<PrimExpr> domain, PrimExpr wave_size,
} }
grouped_domain.push_back(group_size); grouped_domain.push_back(group_size);
n->f_make_for_loop = [=](Array<Var> vars, Array<Range> doms, n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms,
Stmt body) -> Stmt { const Stmt &body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size()); ICHECK_EQ(vars.size(), doms.size());
Map<String, ObjectRef> anno; Map<String, ObjectRef> anno;
Array<PrimExpr> idxs(grouped_domain.size(), PrimExpr()); Array<PrimExpr> idxs(grouped_domain.size(), PrimExpr());
...@@ -220,9 +222,9 @@ public: ...@@ -220,9 +222,9 @@ public:
KernelLaunchFrameNode); KernelLaunchFrameNode);
}; };
KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size, KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
Optional<Array<PrimExpr>> block_size_opt, const Optional<Array<PrimExpr>> &block_size_opt,
Map<String, ffi::Any> attrs) { const Map<String, ffi::Any> &attrs) {
ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>(); ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
// If the kernel is a CPU kernel, we don't need to launch any threads. // If the kernel is a CPU kernel, we don't need to launch any threads.
...@@ -234,7 +236,7 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size, ...@@ -234,7 +236,7 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
if (is_cpu_kernel_frame) { if (is_cpu_kernel_frame) {
// Launch CPU Kernel // Launch CPU Kernel
ICHECK(grid_size.size() >= 0); ICHECK(grid_size.size() >= 0);
ICHECK(block_size.size() == 0) << "CPU kernel cannot have block size"; ICHECK(block_size.empty()) << "CPU kernel cannot have block size";
ICHECK(attrs.defined()); ICHECK(attrs.defined());
// create grid loop var // create grid loop var
for (int i = 0; i < grid_size.size(); i++) { for (int i = 0; i < grid_size.size(); i++) {
...@@ -244,7 +246,7 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size, ...@@ -244,7 +246,7 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
} else { } else {
// Launch GPU Kernel // Launch GPU Kernel
ICHECK(grid_size.size() <= 3); ICHECK(grid_size.size() <= 3);
if (grid_size.size() > 0) if (!grid_size.empty())
n->frames.push_back(LaunchThread( n->frames.push_back(LaunchThread(
CreateEnvThread("bx", "blockIdx.x", grid_size[0].dtype()), CreateEnvThread("bx", "blockIdx.x", grid_size[0].dtype()),
grid_size[0])); grid_size[0]));
...@@ -258,7 +260,7 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size, ...@@ -258,7 +260,7 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
grid_size[2])); grid_size[2]));
if (block_size.defined()) { if (block_size.defined()) {
ICHECK(block_size.size() <= 3); ICHECK(block_size.size() <= 3);
if (block_size.size() > 0) { if (!block_size.empty()) {
n->frames.push_back(LaunchThread( n->frames.push_back(LaunchThread(
CreateEnvThread("tx", "threadIdx.x", block_size[0].dtype()), CreateEnvThread("tx", "threadIdx.x", block_size[0].dtype()),
block_size[0])); block_size[0]));
...@@ -333,12 +335,13 @@ public: ...@@ -333,12 +335,13 @@ public:
WarpSpecializeFrameNode); WarpSpecializeFrameNode);
}; };
WarpSpecializeFrame WarpSpecialize(Array<IntImm> warp_group_ids, WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
PrimExpr thread_idx, const PrimExpr &thread_idx,
int warp_group_size = 128) { int warp_group_size = 128) {
ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>(); ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>();
PrimExpr condition; PrimExpr condition;
std::vector<int> warp_groups; std::vector<int> warp_groups;
warp_groups.reserve(warp_group_ids.size());
for (int i = 0; i < warp_group_ids.size(); i++) { for (int i = 0; i < warp_group_ids.size(); i++) {
warp_groups.push_back(Downcast<IntImm>(warp_group_ids[i])->value); warp_groups.push_back(Downcast<IntImm>(warp_group_ids[i])->value);
} }
......
...@@ -90,8 +90,6 @@ using namespace tir; ...@@ -90,8 +90,6 @@ using namespace tir;
class AtomicAddNode : public TileOperatorNode { class AtomicAddNode : public TileOperatorNode {
public: public:
Array<PrimExpr> args_;
Buffer src, dst; Buffer src, dst;
Array<Range> src_range, dst_range; Array<Range> src_range, dst_range;
IntImm coalesced_width; IntImm coalesced_width;
......
...@@ -21,7 +21,7 @@ using namespace tir; ...@@ -21,7 +21,7 @@ using namespace tir;
/*! /*!
* \brief Copy instruction type. * \brief Copy instruction type.
*/ */
enum class CopyInst { enum class CopyInst : uint8_t {
kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy
kLDSM = 1, // ldmatrix memory copy kLDSM = 1, // ldmatrix memory copy
kSTSM = 2, // stmatrix memory copy kSTSM = 2, // stmatrix memory copy
...@@ -307,8 +307,6 @@ struct TMAIm2ColDesc { ...@@ -307,8 +307,6 @@ struct TMAIm2ColDesc {
*/ */
class CopyNode : public TileOperatorNode { class CopyNode : public TileOperatorNode {
public: public:
Array<PrimExpr> args_; // Copy parameters (indices, sizes, etc.)
Buffer src, dst; // Source and destination buffers Buffer src, dst; // Source and destination buffers
Array<Range> src_range, dst_range; // Ranges for each dimension in src and dst Array<Range> src_range, dst_range; // Ranges for each dimension in src and dst
IntImm coalesced_width; // Width (in elements) for coalesced memory access IntImm coalesced_width; // Width (in elements) for coalesced memory access
...@@ -316,13 +314,13 @@ public: ...@@ -316,13 +314,13 @@ public:
mutable ParallelOp par_op_; // Optional associated parallelization operator mutable ParallelOp par_op_; // Optional associated parallelization operator
enum class EvictionPolicy { enum class EvictionPolicy : uint8_t {
kEvictNormal = 0, kEvictNormal = 0,
kEvictFirst = 1, kEvictFirst = 1,
kEvictLast = 2, kEvictLast = 2,
}; };
int eviction_policy; // Policy for cache eviction uint8_t eviction_policy; // Policy for cache eviction
static constexpr const char *_type_key = "tl.Copy"; static constexpr const char *_type_key = "tl.Copy";
TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode); TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode);
......
...@@ -82,7 +82,7 @@ namespace tl { ...@@ -82,7 +82,7 @@ namespace tl {
using namespace tir; using namespace tir;
enum class GemmWarpPolicy { enum class GemmWarpPolicy : uint8_t {
kSquare = 0, kSquare = 0,
kFullRow = 1, kFullRow = 1,
kFullCol = 2, kFullCol = 2,
...@@ -117,7 +117,7 @@ public: ...@@ -117,7 +117,7 @@ public:
private: private:
// Target GEMM instruction // Target GEMM instruction
enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA }; enum class GemmInst : uint8_t { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const; GemmInst GetGemmInst(int block_size, Target target) const;
std::pair<int, int> ComputeWarpPartition(int num_warps, GemmInst gemm_inst, std::pair<int, int> ComputeWarpPartition(int num_warps, GemmInst gemm_inst,
......
...@@ -72,7 +72,7 @@ class GemmSPNode : public TileOperatorNode { ...@@ -72,7 +72,7 @@ class GemmSPNode : public TileOperatorNode {
public: public:
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
enum class GemmWarpPolicy { enum class GemmWarpPolicy : uint8_t {
kSquare = 0, kSquare = 0,
kFullRow = 1, kFullRow = 1,
kFullCol = 2, kFullCol = 2,
......
...@@ -25,7 +25,7 @@ using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>; ...@@ -25,7 +25,7 @@ using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>; using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>; using BufferMap = Map<Var, Buffer>;
enum class InferLevel { enum class InferLevel : uint8_t {
kFree = 0, kFree = 0,
kCommon = 1, kCommon = 1,
kStrict = 2, kStrict = 2,
...@@ -51,91 +51,6 @@ struct LayoutInferArgs { ...@@ -51,91 +51,6 @@ struct LayoutInferArgs {
class TileOperatorNode; class TileOperatorNode;
class TileOperator; class TileOperator;
/**
* Abstract base class for tile-level operators.
*
* Implementations must provide lowering to TIR, layout inference, and cloning.
*/
/**
* Lower this tile operator to a TIR statement.
*
* @param T Lowering context and utilities (target, thread bounds, layout
* mappings, buffer remapping, and AddWorkspace callback for requesting
* temporary buffers).
* @param analyzer Arithmetic analyzer used during lowering.
* @return A TIR Stmt representing the lowered operator.
*/
/**
* Infer buffer layouts for this operator.
*
* The returned LayoutMap associates input/output Buffers with inferred Layouts.
* The `level` controls how strictly layouts are determined (kFree, kCommon,
* kStrict).
*
* @param T Layout inference context (target, thread bounds, existing
* layout_map, buffer_remap).
* @param level Inference strictness level.
* @return A LayoutMap mapping Buffers to their inferred Layouts.
*/
/**
* Create a deep copy of this TileOperator.
*
* @return A TileOperator referencing a cloned operator instance.
*/
/**
* Reference wrapper for TileOperatorNode.
*
* Use this ObjectRef to hold and pass tile operator instances within the
* runtime.
*/
/**
* Extract the underlying Var from an access pointer expression.
*
* If `expr` represents an access pointer that directly refers to a variable,
* returns that Var; otherwise returns a null/default Var.
*
* @param expr The pointer/access expression to inspect.
* @return The extracted Var, or a null Var if none can be found.
*/
/**
* Parse a Call into a TileOperator using the provided buffer mapping.
*
* @param call The Call node representing a tile operator invocation.
* @param vmap Mapping from TIR Vars to Buffers for resolving buffer arguments.
* @return A TileOperator constructed from the call and buffer map.
*/
/**
* Parse a Stmt into a TileOperator using the provided buffer mapping.
*
* @param stmt The Stmt representing a tile operator region or call.
* @param vmap Mapping from TIR Vars to Buffers for resolving buffer references.
* @return A TileOperator constructed from the statement and buffer map.
*/
/**
* Function type for TL operator builders exposed to the FFI.
*
* Builder functions take an array of PrimExpr arguments and a BufferMap, and
* return a constructed TileOperator.
*/
/**
* Register a TL operator and its builder with TVM's op registry.
*
* Entry should be a type providing a static `Get()` and a constructor taking
* `(Array<PrimExpr>, BufferMap)`. This macro registers the operator under the
* name "tl.OpName" and sets an FFI builder attribute that constructs
* Entry(args, vmap).
*
* Usage: TIR_REGISTER_TL_OP(MyOpEntry, MyOp)
*/
class TileOperatorNode : public Object { class TileOperatorNode : public Object {
public: public:
virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const = 0; virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const = 0;
......
...@@ -184,7 +184,7 @@ public: ...@@ -184,7 +184,7 @@ public:
Optional<PrimExpr> GetPredicate(Var thread_var) const; Optional<PrimExpr> GetPredicate(Var thread_var) const;
// Clone this operator. // Clone this operator.
TileOperator Clone() const; TileOperator Clone() const override;
private: private:
// Complete the fragment layout for a given buffer. // Complete the fragment layout for a given buffer.
...@@ -192,7 +192,7 @@ private: ...@@ -192,7 +192,7 @@ private:
// Check if the buffer is accessed with common indices (i.e., loop variables). // Check if the buffer is accessed with common indices (i.e., loop variables).
bool IsCommonAccessIndice(const Buffer &buffer) const; bool IsCommonAccessIndice(const Buffer &buffer) const;
// Add a predicate to the current predicate expression. // Add a predicate to the current predicate expression.
void AddPredicate(PrimExpr expr) const { void AddPredicate(const PrimExpr &expr) const {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr; predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
} }
// Allow ParallelLoopNestVisitor to access private members. // Allow ParallelLoopNestVisitor to access private members.
...@@ -218,7 +218,7 @@ class ParallelOp : public TileOperator { ...@@ -218,7 +218,7 @@ class ParallelOp : public TileOperator {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode); TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode);
ParallelOp(For root) { ParallelOp(const For &root) {
auto op = make_object<ParallelOpNode>(root); auto op = make_object<ParallelOpNode>(root);
data_ = std::move(op); data_ = std::move(op);
} }
......
...@@ -154,7 +154,7 @@ namespace tl { ...@@ -154,7 +154,7 @@ namespace tl {
using namespace tir; using namespace tir;
enum class ReduceType { enum class ReduceType : uint8_t {
kSum, kSum,
kAbsSum, kAbsSum,
kMax, kMax,
......
...@@ -92,7 +92,7 @@ public: ...@@ -92,7 +92,7 @@ public:
int GetAccessMask() const { return access_mask_; } int GetAccessMask() const { return access_mask_; }
bool IsFullRegion() const; bool IsFullRegion() const;
TileOperator Clone() const; TileOperator Clone() const override;
}; };
class RegionOp : public TileOperator { class RegionOp : public TileOperator {
......
...@@ -25,7 +25,7 @@ public: ...@@ -25,7 +25,7 @@ public:
explicit TileLangAlignDynamicSharedMemoryAllocations(int align_bytes) explicit TileLangAlignDynamicSharedMemoryAllocations(int align_bytes)
: align_bytes_(align_bytes) {} : align_bytes_(align_bytes) {}
static Stmt Substitute(int align_bytes, Stmt stmt) { static Stmt Substitute(int align_bytes, const Stmt &stmt) {
TileLangAlignDynamicSharedMemoryAllocations smem_rewriter(align_bytes); TileLangAlignDynamicSharedMemoryAllocations smem_rewriter(align_bytes);
return smem_rewriter.VisitStmt(stmt); return smem_rewriter.VisitStmt(stmt);
} }
...@@ -138,7 +138,8 @@ private: ...@@ -138,7 +138,8 @@ private:
tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) { tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) {
using namespace tir::transform; using namespace tir::transform;
auto pass_func = [align_bytes](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [align_bytes](PrimFunc f, const IRModule &m,
const PassContext &ctx) {
auto *n = f.CopyOnWrite(); auto *n = f.CopyOnWrite();
n->body = TileLangAlignDynamicSharedMemoryAllocations::Substitute( n->body = TileLangAlignDynamicSharedMemoryAllocations::Substitute(
align_bytes, n->body); align_bytes, n->body);
......
...@@ -31,6 +31,8 @@ ...@@ -31,6 +31,8 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <utility>
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -39,7 +41,7 @@ using namespace tir; ...@@ -39,7 +41,7 @@ using namespace tir;
class DeviceRegionAnnotater : public StmtMutator { class DeviceRegionAnnotater : public StmtMutator {
public: public:
explicit DeviceRegionAnnotater(Target device_target) explicit DeviceRegionAnnotater(Target device_target)
: device_target_(device_target) {} : device_target_(std::move(device_target)) {}
Stmt VisitStmt_(const AttrStmtNode *op) final { Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::attr::kTarget) { if (op->attr_key == tvm::attr::kTarget) {
...@@ -64,8 +66,8 @@ private: ...@@ -64,8 +66,8 @@ private:
tvm::transform::Pass AnnotateDeviceRegions() { tvm::transform::Pass AnnotateDeviceRegions() {
using namespace tir::transform; using namespace tir::transform;
auto pass_func = [](PrimFunc func, IRModule mod, auto pass_func = [](PrimFunc func, const IRModule &mod,
tvm::transform::PassContext ctx) -> PrimFunc { const tvm::transform::PassContext &ctx) -> PrimFunc {
auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget); auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute";
Target target = opt_target.value(); Target target = opt_target.value();
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "../op/builtin.h" #include "../op/builtin.h"
...@@ -32,8 +33,8 @@ private: ...@@ -32,8 +33,8 @@ private:
void VisitStmt_(const EvaluateNode *op) final { void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) { if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(set_max_nreg())) { if (call->op.same_as(set_max_nreg())) {
int reg_hint = call->args[0].as<IntImmNode>()->value; auto reg_hint = call->args[0].as<IntImmNode>()->value;
int is_inc = call->args[1].as<IntImmNode>()->value; auto is_inc = call->args[1].as<IntImmNode>()->value;
ICHECK(reg_hint <= 240 && reg_hint >= 24) ICHECK(reg_hint <= 240 && reg_hint >= 24)
<< "Invalid reg hint: " << reg_hint; << "Invalid reg hint: " << reg_hint;
ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc; ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc;
...@@ -97,8 +98,8 @@ private: ...@@ -97,8 +98,8 @@ private:
Optional<Stmt> consumer_body = if_then_else->else_case; Optional<Stmt> consumer_body = if_then_else->else_case;
ICHECK(consumer_body.defined()) << "Consumer body is undefined"; ICHECK(consumer_body.defined()) << "Consumer body is undefined";
int dec_reg = nreg_[0].as<IntImmNode>()->value; auto dec_reg = nreg_[0].as<IntImmNode>()->value;
int inc_reg = nreg_[1].as<IntImmNode>()->value; auto inc_reg = nreg_[1].as<IntImmNode>()->value;
auto inc_reg_stmt = Evaluate(0); auto inc_reg_stmt = Evaluate(0);
auto dec_reg_stmt = Evaluate(0); auto dec_reg_stmt = Evaluate(0);
...@@ -109,10 +110,14 @@ private: ...@@ -109,10 +110,14 @@ private:
bool has_simt_copy = false; // Placeholder bool has_simt_copy = false; // Placeholder
if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) { if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) {
inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(), auto inc_reg_num =
{inc_reg == 0 ? 240 : inc_reg, 1})); IntImm(DataType::Int(32), inc_reg == 0 ? 240 : inc_reg);
dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(), auto dec_reg_num =
{dec_reg == 0 ? 24 : dec_reg, 0})); IntImm(DataType::Int(32), dec_reg == 0 ? 24 : dec_reg);
inc_reg_stmt = Evaluate(
Call(DataType::Handle(), set_max_nreg(), {inc_reg_num, 1}));
dec_reg_stmt = Evaluate(
Call(DataType::Handle(), set_max_nreg(), {dec_reg_num, 0}));
} }
// Inject register setting statements // Inject register setting statements
...@@ -145,8 +150,9 @@ private: ...@@ -145,8 +150,9 @@ private:
using namespace tir::transform; using namespace tir::transform;
tvm::transform::Pass AnnotateWarpGroupRegAlloc() { tvm::transform::Pass AnnotateWarpGroupRegAlloc() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) -> PrimFunc { auto pass_func = [](PrimFunc f, const IRModule &m,
return SetMaxNRegInjector::Inject(f); const PassContext &ctx) -> PrimFunc {
return SetMaxNRegInjector::Inject(std::move(f));
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {}); return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {});
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <utility>
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -35,8 +36,8 @@ public: ...@@ -35,8 +36,8 @@ public:
AtomicAddVectorizePlanResult Plan(const For &node, Var thread_var, AtomicAddVectorizePlanResult Plan(const For &node, Var thread_var,
Range thread_bounds, int vectorize_hint) { Range thread_bounds, int vectorize_hint) {
this->max_vector_size = vectorize_hint; this->max_vector_size = vectorize_hint;
this->thread_var = thread_var; this->thread_var = std::move(thread_var);
this->thread_bounds = thread_bounds; this->thread_bounds = std::move(thread_bounds);
this->operator()(node); this->operator()(node);
return {vector_size_, dynamic_, condition_}; return {vector_size_, dynamic_, condition_};
} }
...@@ -79,7 +80,7 @@ private: ...@@ -79,7 +80,7 @@ private:
return arith::IRVisitorWithAnalyzer::VisitExpr_(node); return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
} }
void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) { void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
if (!inner_for_) if (!inner_for_)
return; return;
auto extent_ptr = inner_for_->extent.as<IntImmNode>(); auto extent_ptr = inner_for_->extent.as<IntImmNode>();
...@@ -141,12 +142,14 @@ private: ...@@ -141,12 +142,14 @@ private:
class AtomicAddVectorizeRewriter : public StmtExprMutator { class AtomicAddVectorizeRewriter : public StmtExprMutator {
public: public:
AtomicAddVectorizeRewriter(AtomicAddVectorizePlanResult plan, Var thread_var, AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan,
PrimExpr by_var, PrimExpr bx_var, Var thread_var, PrimExpr by_var, PrimExpr bx_var,
Range thread_bounds, int stride_y, int stride_x) const Range &thread_bounds, int stride_y,
int stride_x)
: vector_size_(plan.vector_size), condition_(plan.condition), : vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic), tx_var_(thread_var), by_var_(by_var), dynamic_(plan.dynamic), tx_var_(std::move(thread_var)),
bx_var_(bx_var), stride_y_(stride_y), stride_x_(stride_x) { by_var_(std::move(by_var)), bx_var_(std::move(bx_var)),
stride_y_(stride_y), stride_x_(stride_x) {
const int64_t *tx_ext = as_const_int(thread_bounds->extent); const int64_t *tx_ext = as_const_int(thread_bounds->extent);
ICHECK(tx_ext) ICHECK(tx_ext)
<< "thread_bounds->extent must be a constant for vectorization."; << "thread_bounds->extent must be a constant for vectorization.";
...@@ -324,8 +327,8 @@ static int GetVectorizeSizeMax(int compute_capability, DataType dtype) { ...@@ -324,8 +327,8 @@ static int GetVectorizeSizeMax(int compute_capability, DataType dtype) {
return 1; return 1;
} }
For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds, For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
int compute_capability) { const Range &thread_bounds, int compute_capability) {
int vectorize_size_max = 1; int vectorize_size_max = 1;
int stride_x = -1, stride_y = -1; int stride_x = -1, stride_y = -1;
...@@ -382,4 +385,4 @@ For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds, ...@@ -382,4 +385,4 @@ For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds,
} }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
...@@ -14,8 +14,8 @@ namespace tl { ...@@ -14,8 +14,8 @@ namespace tl {
using namespace tir; using namespace tir;
For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds, For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
int compute_capability); const Range &thread_bounds, int compute_capability);
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
......
...@@ -66,8 +66,15 @@ public: ...@@ -66,8 +66,15 @@ public:
} }
if (mem_reuse_max > 0) { if (mem_reuse_max > 0) {
cluster_tag = std::string tag_str = cluster_tag; // Convert to std::string
"clusterIdx" + String(cluster_tag.c_str() + strlen("blockIdx")); if (tag_str.rfind("blockIdx", 0) == 0) {
// starts with "blockIdx"
tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx"));
} else {
// Unexpected format — maybe just prefix
tag_str = "clusterIdx" + tag_str;
}
cluster_tag = tvm::ffi::String(tag_str); // Convert back
return WithAttr(f, cluster_tag, Integer(cluster_size_)); return WithAttr(f, cluster_tag, Integer(cluster_size_));
} else { } else {
return f; return f;
...@@ -109,7 +116,7 @@ PrimFunc ClusterPlanning(PrimFunc f) { return ClusterPlanner::Substitute(f); } ...@@ -109,7 +116,7 @@ PrimFunc ClusterPlanning(PrimFunc f) { return ClusterPlanner::Substitute(f); }
namespace transform { namespace transform {
tvm::transform::Pass ClusterPlanning() { tvm::transform::Pass ClusterPlanning() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return ClusterPlanning(std::move(f)); return ClusterPlanning(std::move(f));
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {}); return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment