Unverified commit cdc5d8d3, authored by Lei Wang, committed by GitHub

[Lint] Introduce clang-tidy into format.sh (#777)

* [Refactor] Update Clang-Tidy Checks and Improve Code Consistency

- Enhanced .clang-tidy configuration by adding specific checks for better bug detection and performance optimization.
- Refactored function signatures across multiple files to use `const` references for parameters, improving performance and code clarity.
- Updated various methods to ensure consistent handling of parameters, particularly in `AddPredicate`, `Substitute`, and `PlanLoopPartition` functions.
- Improved readability by replacing size checks with `empty()` method calls in several locations, ensuring clearer intent in the code.
- General code cleanup and adherence to best practices for better maintainability.

* [Refactor] Enhance Code Consistency and Clang-Tidy Configuration

- Updated .clang-tidy configuration to include additional checks for improved code quality and performance.
- Refactored function signatures across multiple files to use `const` references, enhancing performance and clarity.
- Replaced size checks with `empty()` method calls in various locations for clearer intent.
- Improved handling of parameters in several functions, ensuring consistent usage of `std::move` where applicable.
- General code cleanup to adhere to best practices and improve maintainability.

* [Refactor] Integrate Clang-Tidy Checks and Enhance Code Consistency

- Added clang-tidy checks to the format script for improved code quality assurance.
- Refactored function signatures across multiple files to consistently use `const` references, enhancing performance and clarity.
- Updated the requirements-lint.txt file to include clang-tidy as a dependency.
- General code cleanup to adhere to best practices and improve maintainability.
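A `.clang-tidy` configuration enabling checks that correspond to the refactors in this PR might look roughly like the sketch below. The exact check list in the repository may differ; the check names themselves are real clang-tidy checks (`performance-unnecessary-value-param` flags by-value parameters that should be const references, `readability-container-size-empty` flags `size()` comparisons that should be `empty()`, and `modernize-pass-by-value` suggests the sink-parameter-plus-`std::move` constructor idiom).

```yaml
# Hypothetical sketch of a .clang-tidy config; the repo's actual list may differ.
Checks: >
  bugprone-*,
  performance-unnecessary-value-param,
  performance-for-range-copy,
  modernize-pass-by-value,
  readability-container-size-empty
WarningsAsErrors: ''
HeaderFilterRegex: '.*'
```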

* [CI] Update AMD CI Workflow to Include Build Directory Creation

- Added steps to create a build directory and configure CMake with ROCm support during the format check process.
- Ensured cleanup of the build directory after the format check to maintain a clean workspace.

* [Refactor] Remove Unused Member Variables in AtomicAddNode and CopyNode

- Removed the `args_` member variable from both `AtomicAddNode` and `CopyNode` classes to streamline the code and eliminate unnecessary data members.
- This change enhances code clarity and maintainability by focusing on relevant attributes for each class.

* [Refactor] Update Clang-Tidy Integration and Code Improvements

- Modified the format script to include the `-fix` option in the clang-tidy command for automatic code fixes.
- Refactored the `AtomicAddVectorizePlanner` class to improve variable handling and consistency, including changes to member variable types and function signatures.
- Enhanced code clarity by removing unnecessary `std::move` calls and ensuring consistent usage of types across the class.
- General code cleanup to adhere to best practices and improve maintainability.

* [Refactor] Improve Parameter Handling and Consistency in AtomicAddVectorize

- Updated function signatures in `AtomicAddVectorizePlanResult` and `AtomicAddVectorizeRewriter` to use `const` references and `std::move` for better performance and clarity.
- Enhanced the `UpdateVectorSize` method to accept `const Array<PrimExpr>&` for improved efficiency.
- General code cleanup to maintain consistency and adhere to best practices.

* [CI] Add Git Submodule Initialization to CI Workflow

- Included a step to initialize and update git submodules recursively in the CI workflow.
- This change ensures that all necessary submodules are available during the format check process, improving build reliability.

* [CI] Add Git Submodule Update Step to Format Check

- Included a command to initialize and update git submodules recursively in the CI workflow during the format check process.
- This enhancement ensures that all required submodules are available, contributing to improved build reliability.

* [Refactor] Update Function Signatures in AtomicAddVectorize

- Modified the `VectorizeAtomicAdd` function signature to use `const` references for `thread_var` and `thread_bounds`, enhancing performance and code clarity.
- This change aligns with previous refactoring efforts to improve parameter handling and consistency across the codebase.
parent 471cc7f8
@@ -12,6 +12,7 @@
 #include <tvm/tir/stmt_functor.h>
 #include <numeric>
+#include <utility>
 #include "../layout/layout.h"
 #include "../layout/utils.h"
@@ -32,7 +33,8 @@ struct VectorizePlanResult {
   PrimExpr condition;
 };
-bool IndiceCanVectorizeDynamic(PrimExpr expr, Var var, PrimExpr iter_var_size,
+bool IndiceCanVectorizeDynamic(const PrimExpr &expr, Var var,
+                               const PrimExpr &iter_var_size,
                                int target_vectorized_size,
                                arith::Analyzer *analyzer) {
   ICHECK(target_vectorized_size >= 1);
@@ -136,7 +138,7 @@ private:
     // TODO: may perform some checks here
   }
-  void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
+  void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
     if (!inner_for_)
       return;
     auto extent_ptr = inner_for_->extent.as<IntImmNode>();
@@ -198,7 +200,7 @@ private:
   int vector_size_;
-  const ForNode *inner_for_;
+  const ForNode *inner_for_{};
   Map<Var, Range> iter_map_;
   bool has_nonlocal_memory_access_ = false;
   // conditionally vectorize
@@ -210,8 +212,8 @@ class VectorizedBodyMutator : public StmtExprMutator {
 public:
   VectorizedBodyMutator(Var inner_var, int vector_size,
                         std::vector<PrimExpr> conditions)
-      : inner_var_(inner_var), vector_size_(vector_size),
-        conditions_(conditions) {}
+      : inner_var_(std::move(inner_var)), vector_size_(vector_size),
+        conditions_(std::move(conditions)) {}
 private:
   PrimExpr VisitExpr_(const CallNode *op) final {
@@ -244,7 +246,7 @@ private:
 class VectorizedConditionExtracter : public StmtExprVisitor {
 public:
   VectorizedConditionExtracter() = default;
-  std::vector<PrimExpr> GetConditions(Stmt body) {
+  std::vector<PrimExpr> GetConditions(const Stmt &body) {
     this->VisitStmt(body);
     return conditions_;
   }
@@ -269,7 +271,7 @@ private:
 class NestedLoopChecker : public StmtExprVisitor {
 public:
   NestedLoopChecker() : loop_num_(0) {}
-  int GetNestLoopNum(Stmt body) {
+  int GetNestLoopNum(const Stmt &body) {
     this->VisitStmt(body);
     return loop_num_;
   }
@@ -286,7 +288,7 @@ private:
 class VectorizedConditionMutator : public StmtExprMutator {
 public:
   VectorizedConditionMutator(Var inner_var, int extent)
-      : inner_var_(inner_var), vector_size_(extent) {}
+      : inner_var_(std::move(inner_var)), vector_size_(extent) {}
 private:
   PrimExpr VisitExpr_(const GENode *node) final {
@@ -343,7 +345,7 @@ private:
 class VectorizeRewriterDynamic : public StmtExprMutator {
 public:
-  VectorizeRewriterDynamic(VectorizePlanResult plan,
+  VectorizeRewriterDynamic(const VectorizePlanResult &plan,
                            bool disable_dynamic_tail_split)
       : vector_size_(plan.vector_size), condition_(plan.condition),
        dynamic_(plan.dynamic),
@@ -396,7 +398,7 @@ private:
     // Adaptively set vectorized variable to the min/max value of the extent
     PrimExpr condition_bound;
-    if (conditions.size() > 0) {
+    if (!conditions.empty()) {
       condition_bound = condition_mutator(conditions[0]);
       for (int i = 1; i < conditions.size(); ++i) {
         condition_bound = condition_bound && condition_mutator(conditions[i]);
@@ -413,7 +415,7 @@ private:
     For vectorize_for =
         For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body);
     For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body);
-    if (conditions.size() > 0) {
+    if (!conditions.empty()) {
       body = IfThenElse(condition_bound, vectorize_for, serial_for);
     } else {
       body = vectorize_for;
@@ -436,7 +438,7 @@ private:
     }
   }
-  const ForNode *inner_for_;
+  const ForNode *inner_for_{};
   int vector_size_;
   const PrimExpr condition_;
   const bool dynamic_;
@@ -484,7 +486,6 @@ private:
       // non-vectorized loop
       return for_node;
     }
-    int vectorize_hint = res.vector_size;
     auto rewriter = VectorizeRewriterDynamic(res, disable_dynamic_tail_split_);
     return Downcast<For>(rewriter(for_node));
   }
@@ -509,7 +510,7 @@ public:
 tvm::transform::Pass LoopVectorizeDynamic() {
   using namespace tir::transform;
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
     bool disable_dynamic_tail_split =
         ctx->GetConfig<Bool>(kDisableDynamicTailSplit, Bool(true)).value();
     int dynamic_alignment =
......
@@ -356,7 +356,7 @@ namespace transform {
 tvm::transform::Pass LowerDeviceKernelLaunch() {
   auto pass_func = [](IRModule mod,
-                      tir::transform::PassContext ctx) -> IRModule {
+                      const tir::transform::PassContext &ctx) -> IRModule {
     auto mutator = [&mod]() {
       std::unordered_map<const GlobalVarNode *, KernelInfo> device_info_map;
       for (const auto &[gvar, base_func] : mod->functions) {
@@ -380,7 +380,7 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
       }
     }
-    if (updates->functions.size()) {
+    if (!updates->functions.empty()) {
       mod.CopyOnWrite()->Update(updates);
     }
   }
@@ -396,7 +396,7 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
       }
     }
-    if (updates->functions.size()) {
+    if (!updates->functions.empty()) {
       mod.CopyOnWrite()->Update(updates);
     }
   }
......
@@ -44,7 +44,7 @@ class StorageAccessInfoLower : public StmtExprMutator {
 public:
   Stmt VisitStmt_(const AllocateNode *op) final {
     auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var));
-    if (scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".var" &&
+    if (!scope.tag.empty() && scope.tag != ".dyn" && scope.tag != ".var" &&
         scope.tag != ".barrier") {
       auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var));
       ICHECK(info.defined())
@@ -105,8 +105,8 @@ private:
     return AddressOffset(buffer_var, dtype, offset);
   }
-  PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var,
-                               DataType dtype, PrimExpr offset,
+  PrimExpr MakeTaggedAccessPtr(DataType ptr_type, const Var &buffer_var,
+                               DataType dtype, const PrimExpr &offset,
                                const MemoryInfo &info) {
     if (ptr_type.is_handle()) {
       ICHECK(info->head_address.defined())
@@ -134,7 +134,7 @@ namespace transform {
 using namespace tir::transform;
 Pass LowerDeviceStorageAccessInfo() {
-  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
     auto *n = f.CopyOnWrite();
     n->body = StorageAccessInfoLower()(std::move(n->body));
     return f;
......
@@ -26,7 +26,7 @@ public:
     LowerHopperIntrin substituter(disable_shuffle_elect);
     fptr->body = substituter.VisitStmt(f->body);
     Map<String, Array<PrimExpr>> init_desc_arg_map;
-    for (auto [call, var] : substituter.desc_map_) {
+    for (const auto &[call, var] : substituter.desc_map_) {
       // Should allocate 128 bytes for TensorMap on stack
       Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
                              {StringImm("arg_value"), 16});
@@ -117,7 +117,7 @@ public:
       }
       return var;
     } else if (call->op.same_as(create_list_of_mbarrier())) {
-      ICHECK(init_mbarrier_calls_.size() == 0);
+      ICHECK(init_mbarrier_calls_.empty());
       int num_barriers = static_cast<int>(call->args.size());
       for (int i = 0; i < num_barriers; i++) {
         PrimExpr mbarrier = Call(DataType::Handle(), get_mbarrier(), {i});
@@ -143,7 +143,7 @@ private:
 using namespace tir::transform;
 tvm::transform::Pass LowerHopperIntrin() {
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
     bool disable_shuffle_elect =
         ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();
     return LowerHopperIntrin::Substitute(f, disable_shuffle_elect);
......
@@ -47,7 +47,7 @@ public:
       l2_persistent_arguments.push_back(size_in_bytes);
       init_l2_persistent_map.Set(buffer->name, l2_persistent_arguments);
     }
-    if (init_l2_persistent_map.size() > 0) {
+    if (!init_l2_persistent_map.empty()) {
       f = WithAttr(std::move(f), attr::kL2PersistentMap,
                    init_l2_persistent_map);
     }
@@ -92,7 +92,7 @@ private:
 using namespace tir::transform;
 tvm::transform::Pass LowerL2Persistent() {
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
     return LowerL2Persistent::Substitute(f);
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {});
......
@@ -25,6 +25,8 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
+#include <utility>
 #include "tir/transforms/ir_utils.h"
 namespace tvm {
@@ -144,8 +146,8 @@ private:
   }
   static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var,
-                               String thread_tag, Stmt body) {
-    IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent),
+                               const String &thread_tag, Stmt body) {
+    IterVar iter_var(/*dom=*/Range::FromMinExtent(std::move(min), extent),
                      /*var=*/std::move(var),
                      /*iter_type=*/IterVarType::kThreadIndex,
                      /*thread_tag=*/thread_tag);
@@ -223,7 +225,7 @@ PrimFunc TLLowerOpaqueBlock(PrimFunc f) {
 tir::transform::Pass LowerOpaqueBlock() {
   using namespace tir::transform;
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
     return TLLowerOpaqueBlock(std::move(f));
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {});
......
@@ -13,6 +13,8 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
+#include <utility>
 namespace tvm {
 namespace tl {
@@ -22,7 +24,7 @@ class SharedBarrierRewriter : public StmtExprMutator {
 public:
   static Stmt Rewrite(Stmt body, bool disable_shuffle_elect = false) {
     SharedBarrierRewriter rewriter(disable_shuffle_elect);
-    return rewriter(body);
+    return rewriter(std::move(body));
   }
 private:
@@ -43,7 +45,7 @@ private:
     Array<Buffer> barrier_buffers;
-    for (auto [data, buffer] : buffer_map_) {
+    for (const auto &[data, buffer] : buffer_map_) {
       const auto *ptr_type =
           buffer->data->type_annotation.as<PointerTypeNode>();
       auto storage_scope = ptr_type->storage_scope;
@@ -53,7 +55,7 @@ private:
       }
     }
-    if (barrier_buffers.size() == 0) {
+    if (barrier_buffers.empty()) {
       return StmtExprMutator::VisitStmt_(op);
     }
@@ -189,7 +191,7 @@ namespace transform {
 using namespace tir::transform;
 tvm::transform::Pass LowerSharedBarrier() {
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
     bool disable_shuffle_elect =
         ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();
     return tl::LowerSharedBarrier(std::move(f), disable_shuffle_elect);
......
@@ -30,6 +30,7 @@
 #include <tvm/tir/transform.h>
 #include <unordered_set>
+#include <utility>
 #include "runtime/thread_storage_scope.h"
 #include "tir/transforms/ir_utils.h"
@@ -49,17 +50,17 @@ class AllocateCollector : public StmtExprVisitor {
 private:
   bool IsDynamicSharedMemory(Var buffer_var) {
-    StorageScope storage_scope =
-        runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+    StorageScope storage_scope = runtime::StorageScope::Create(
+        GetPtrStorageScope(std::move(buffer_var)));
     return storage_scope.rank == runtime::StorageRank::kShared &&
            storage_scope.tag == ".dyn";
   }
   bool IsStaticSharedMemory(Var buffer_var) {
-    StorageScope storage_scope =
-        runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+    StorageScope storage_scope = runtime::StorageScope::Create(
+        GetPtrStorageScope(std::move(buffer_var)));
     return storage_scope.rank == runtime::StorageRank::kShared &&
-           storage_scope.tag == "";
+           storage_scope.tag.empty();
   }
 public:
@@ -175,7 +176,6 @@ public:
     }
     BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
-    op = load.get();
     if (auto opt = GetRemappedBuffer(load->buffer)) {
       load.CopyOnWrite()->buffer = opt.value();
@@ -197,7 +197,7 @@ private:
   struct ThreadEntry {
     runtime::ThreadScope scope;
     IterVar iv;
-    int extent;
+    int extent{};
     // comparator
     bool operator<(const ThreadEntry &other) const {
       return scope.dim_index < other.scope.dim_index;
@@ -532,7 +532,7 @@ private:
     // Fix all local allocations as all statements are built.
     Stmt body = SeqStmt::Flatten(seq);
-    for (Buffer buf : new_alloc_bufs) {
+    for (const Buffer &buf : new_alloc_bufs) {
      body = DeclBuffer(buf, body);
      body = Allocate(buf->data, buf->dtype, buf->shape,
                      const_true(buf->dtype.lanes()), body);
@@ -542,12 +542,13 @@ private:
   }
   std::pair<std::vector<PrimExpr>, std::vector<Buffer>>
   MakeWarpAllreduce(std::vector<PrimExpr> src_values,       //
                     std::vector<DataType> dtypes,           //
                     const CommReducerNode *combiner,        //
-                    PrimExpr reduce_index, int reduce_extent, //
-                    PrimExpr group_index,                   //
-                    PrimExpr mask, Optional<PrimExpr> predicate, //
+                    const PrimExpr &reduce_index, int reduce_extent, //
+                    const PrimExpr &group_index,            //
+                    const PrimExpr &mask,
+                    const Optional<PrimExpr> &predicate,    //
                     std::vector<Stmt> *seq) {
     int n_buffers = src_values.size();
@@ -785,7 +786,7 @@ private:
                     int *out_total_extent) {
     int &total_extent = *out_total_extent;
     total_extent = 1;
-    if (tvec.size() == 0) {
+    if (tvec.empty()) {
       return make_zero(DataType::Int(32));
     }
@@ -802,7 +803,7 @@ private:
     return ret;
   }
   // The local buffer index.
-  PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index,
+  PrimExpr BufIndex(PrimExpr reduce_index, const PrimExpr &group_index,
                     int reduce_extent) {
     if (!is_zero(group_index)) {
       return analyzer_.Simplify(group_index * reduce_extent + reduce_index);
@@ -817,8 +818,8 @@ private:
   }
   // Emit warp shuffle calls.
-  PrimExpr WarpShuffle(const Op &op, Optional<Buffer> mask_buffer, PrimExpr val,
-                       PrimExpr delta_or_lane) {
+  PrimExpr WarpShuffle(const Op &op, const Optional<Buffer> &mask_buffer,
+                       const PrimExpr &val, PrimExpr delta_or_lane) {
     Array<PrimExpr> indices = {0};
     PrimExpr mask;
     if (mask_buffer.defined()) {
@@ -827,7 +828,7 @@ private:
       mask = IntImm(DataType::Int(32), 0);
     }
     PrimExpr width = IntImm(DataType::Int(32), warp_size_);
-    Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
+    Array<PrimExpr> args{mask, val, std::move(delta_or_lane), width, width};
     return Call(val.dtype(), op, args);
   }
@@ -904,7 +905,7 @@ private:
   // The maximum number of threads of the device. "-1" denotes unknown.
   int max_num_threads_{-1};
   // A boolean indicating if the target supports warp-level masking.
-  bool need_warp_shuffle_mask_;
+  bool need_warp_shuffle_mask_{};
   // surrounding scope of thread extent.
   std::vector<const AttrStmtNode *> thread_extents_;
@@ -925,7 +926,7 @@ namespace transform {
 using namespace tir::transform;
 tvm::transform::Pass LowerThreadAllreduce() {
-  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
     AllocateCollector collector;
     collector(f->body);
     bool is_dynamic = collector.dyn_shmem_allocs_.size() > 1;
......
@@ -79,7 +79,7 @@ public:
   void Clear() { buffer_var_gemm_.clear(); }
-  void Collect(Stmt stmt) { VisitStmt(stmt); }
+  void Collect(const Stmt &stmt) { VisitStmt(stmt); }
   Array<Var> GetBufferVarGemm() { return buffer_var_gemm_; }
@@ -133,7 +133,7 @@ public:
    * remapping. \param stmt The statement to rewrite. \param buffer_remap A map
    * from old buffers to new buffers. \return The rewritten statement.
    */
-  static Stmt Substitute(Stmt stmt, Map<Buffer, Buffer> buffer_remap) {
+  static Stmt Substitute(const Stmt &stmt, Map<Buffer, Buffer> buffer_remap) {
     arith::Analyzer analyzer;
     RemapBufferRewriter substituter(&analyzer);
     substituter.buffer_remap_ = std::move(buffer_remap);
@@ -279,7 +279,7 @@ private:
     return block;
   }
-  int CheckAndGetBufferRowSize(Buffer buffer) {
+  int CheckAndGetBufferRowSize(const Buffer &buffer) {
     CHECK(buffer->shape.size() >= 2)
         << "The dimension of Buffer \"" << buffer->name << "\" with shape "
         << buffer->shape << " should be at least 2";
@@ -289,9 +289,10 @@ private:
     return buffer_row_size;
   }
-  PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr,
-                                    Optional<PrimExpr> offset = std::nullopt,
-                                    DataType dtype = DataType::Int(32)) {
+  PrimExpr
+  HandleAccessPtrAndOffset(const PrimExpr &access_ptr,
+                           const Optional<PrimExpr> &offset = std::nullopt,
+                           DataType dtype = DataType::Int(32)) {
     // The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and
     // accumulate it to smem_offset
     CHECK(access_ptr->IsInstance<CallNode>())
@@ -569,7 +570,7 @@ namespace transform {
 using namespace tir::transform;
 tvm::transform::Pass LowerTileOp() {
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
     return LowerTileOpPass::Substitute(std::move(f));
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
......
@@ -48,7 +48,7 @@ namespace {
 class ReturnRewriter : public StmtMutator {
 public:
   explicit ReturnRewriter(Var ret_var, Var ret_tcode)
-      : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
+      : ret_var_(std::move(ret_var)), ret_tcode_(std::move(ret_tcode)) {}
   Stmt VisitStmt_(const ForNode *node) override {
     if (node->kind == ForKind::kParallel)
@@ -82,7 +82,7 @@ private:
   Buffer dummy_tcode_buffer;
 };
-ConvertedInfo ConvertForFFI(PrimExpr val) {
+ConvertedInfo ConvertForFFI(const PrimExpr &val) {
   ConvertedInfo info;
   // convert val's data type to FFI data type, return type code
@@ -124,7 +124,7 @@ private:
   return info;
 }
-Stmt WriteToOut(PrimExpr val) {
+Stmt WriteToOut(const PrimExpr &val) {
   auto info = ConvertForFFI(val);
   Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0});
   Stmt store_tcode =
@@ -142,8 +142,8 @@ private:
 };
 Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
-  ReturnRewriter rewriter(ret_var, ret_tcode);
-  return rewriter(body);
+  ReturnRewriter rewriter(std::move(ret_var), std::move(ret_tcode));
+  return rewriter(std::move(body));
 }
 class SubroutineCallRewriter : public StmtExprMutator {
@@ -151,7 +151,7 @@ public:
   static Optional<Stmt> Apply(const Map<GlobalVar, String> &packed_func_methods,
                               Stmt stmt) {
     SubroutineCallRewriter rewriter(packed_func_methods);
-    stmt = rewriter.VisitStmt(std::move(stmt));
+    stmt = rewriter.VisitStmt(stmt);
     if (rewriter.made_change_) {
       return stmt;
     } else {
@@ -192,12 +192,13 @@ private:
 } // namespace
-inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
-  return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
+inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, const std::string &msg) {
+  return AssertStmt(std::move(lhs) == std::move(rhs), tvm::tir::StringImm(msg),
+                    Evaluate(0));
 }
-inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
-  Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr});
+inline Stmt MakeAssertNotNull(PrimExpr ptr, const std::string &msg) {
+  Call isnull(DataType::Bool(), builtin::isnullptr(), {std::move(ptr)});
   return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0));
 }
@@ -472,7 +473,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
 tvm::transform::Pass MakePackedAPI() {
   using tvm::transform::Pass;
-  auto pass_func = [](IRModule mod, tvm::transform::PassContext ctx) {
+  auto pass_func = [](IRModule mod, const tvm::transform::PassContext &ctx) {
     Map<GlobalVar, String> packed_func_methods;
     for (const auto &[gvar, base_func] : mod->functions) {
       if (auto opt = base_func.as<PrimFunc>()) {
@@ -504,7 +505,7 @@ tvm::transform::Pass MakePackedAPI() {
     }
   }
-  if (updates->functions.size()) {
+  if (!updates->functions.empty()) {
     mod.CopyOnWrite()->Update(updates);
   }
   return mod;
...
@@ -92,7 +92,7 @@ private:
 using namespace tir::transform;
 tvm::transform::Pass MergeIfStmt() {
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
     return MergeIfStmtRewriter::Substitute(f);
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {});
...
@@ -33,6 +33,7 @@
 #include <unordered_map>
 #include <unordered_set>
+#include <utility>
 #include "../op/builtin.h"
 #include "../target/utils.h"
@@ -51,16 +52,16 @@ using runtime::StorageScope;
 static bool IsDynamicSharedMemory(Var buffer_var) {
   StorageScope storage_scope =
-      runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+      runtime::StorageScope::Create(GetPtrStorageScope(std::move(buffer_var)));
   return storage_scope.rank == runtime::StorageRank::kShared &&
          storage_scope.tag == ".dyn";
 }
 static bool IsStaticSharedMemory(Var buffer_var) {
   StorageScope storage_scope =
-      runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+      runtime::StorageScope::Create(GetPtrStorageScope(std::move(buffer_var)));
   return storage_scope.rank == runtime::StorageRank::kShared &&
-         storage_scope.tag == "";
+         storage_scope.tag.empty();
 }
 /*!
@@ -106,7 +107,7 @@ public:
 /*! \brief record the touch list of statement. */
 struct StmtEntry {
   // The statement
-  const Object *stmt;
+  const Object *stmt{};
   // The index in the linear_seq_ to point to end of the nested scope.
   // This is only set to non-zero if stmt is a nested scope.
   // if offset > 0, means this is the begin, the end entry is current_index +
@@ -167,7 +168,7 @@ public:
     StmtEntry e = scope_.back();
     scope_.pop_back();
-    if (e.touched.size() != 0) {
+    if (!e.touched.empty()) {
       e.stmt = op;
       UpdateStmtAttr(op, scope_level_);
       linear_seq_.push_back(e);
@@ -180,7 +181,7 @@ public:
     StmtExprVisitor::VisitStmt_(op);
     StmtEntry e = scope_.back();
     scope_.pop_back();
-    if (e.touched.size() != 0) {
+    if (!e.touched.empty()) {
       e.stmt = op;
       UpdateStmtAttr(op, scope_level_);
       linear_seq_.push_back(e);
@@ -602,7 +603,7 @@ private:
     }
   }
-  PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
+  PrimExpr GetBufferOffset(const Var &buffer_var, DataType dtype) {
     auto it = buffer_byte_offsets_.find(buffer_var.get());
     ICHECK(it != buffer_byte_offsets_.end())
         << "buffer_var = " << buffer_var->name_hint << ", dtype = " << dtype;
@@ -750,8 +751,8 @@ private:
     std::vector<StmtEntry> gen_kill_seq;
     for (const auto &stmt_entry : seq) {
       // if has gen and kill, add to gen_kill_seq
-      if (event_map_[stmt_entry.stmt].gen.size() > 0 ||
-          event_map_[stmt_entry.stmt].kill.size() > 0) {
+      if (!event_map_[stmt_entry.stmt].gen.empty() ||
+          !event_map_[stmt_entry.stmt].kill.empty()) {
        gen_kill_seq.push_back(stmt_entry);
      }
    }
@@ -1124,8 +1125,8 @@ namespace transform {
 Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false,
                                   int align_bytes = 16) {
-  auto pass_func = [enable_aggressive_merge,
-                    align_bytes](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [enable_aggressive_merge, align_bytes](
+                       PrimFunc f, const IRModule &m, PassContext ctx) {
     bool merge_static_smem =
         ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
     bool debug_merge_shared_memory_allocations =
...
@@ -10,6 +10,8 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
+#include <utility>
 #include "../op/builtin.h"
 namespace tvm {
@@ -17,12 +19,12 @@ namespace tl {
 using namespace tir;
-enum class Role { kConsumer, kProducer, kBoth };
+enum class Role : uint8_t { kConsumer, kProducer, kBoth };
 class WarpSpecializedRoleMarker_ : public StmtVisitor {
 public:
   WarpSpecializedRoleMarker_(Map<Var, Buffer> buffer_data_to_buffer)
-      : buffer_data_to_buffer_(buffer_data_to_buffer) {}
+      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {}
   Role GetRole(const StmtNode *stmt) const {
     auto it = map_.find(stmt);
@@ -135,8 +137,8 @@ public:
 private:
   MultiVersionBufferRewriter() = default;
-  Array<Buffer> GetVersionedBuffers(Array<Stmt> seq_stmt,
-                                    Array<Buffer> scoped_buffers) {
+  Array<Buffer> GetVersionedBuffers(const Array<Stmt> &seq_stmt,
+                                    const Array<Buffer> &scoped_buffers) {
     std::vector<Role> roles;
     Array<Array<BufferRegion>> reads, writes;
     auto marker = WarpSpecializedRoleMarker_(buffer_data_to_buffer_);
@@ -145,8 +147,8 @@ private:
       Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                   /*name_hint=*/"", /*body*/ stmt);
       auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
-      reads.push_back(std::move(access[0]));
-      writes.push_back(std::move(access[1]));
+      reads.push_back(access[0]);
+      writes.push_back(access[1]);
       roles.push_back(marker.GetRole(stmt));
     }
@@ -173,7 +175,7 @@ private:
   static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
     ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
     new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
-    if (new_buffer->strides.size()) {
+    if (!new_buffer->strides.empty()) {
       ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
       PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
       new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
@@ -277,10 +279,12 @@ private:
   }
   PrimExpr RewriteBufferAccess(const Call &call,
-                               const std::vector<int> arg_indices) {
+                               const std::vector<int> &arg_indices) {
     auto product = [](const Array<PrimExpr> &input) {
       return foldl(
-          [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
+          [](PrimExpr a, PrimExpr b, Span span) {
            return mul(std::move(a), std::move(b), std::move(span));
          },
           make_const(DataType::Int(32), 1), input);
     };
     Array<PrimExpr> new_args = call->args;
@@ -316,7 +320,7 @@ private:
 using namespace tir::transform;
 tvm::transform::Pass MultiVersionBuffer() {
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
     return MultiVersionBufferRewriter::Substitute(f);
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {});
...
@@ -53,7 +53,7 @@ private:
 using namespace tir::transform;
 tvm::transform::Pass PersistThreadblock() {
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
     return PersistThreadblock::Substitute(f);
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {});
...
@@ -5,6 +5,8 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
+#include <utility>
 #include "../target/utils.h"
 #include "tvm/ir/expr.h"
@@ -19,7 +21,7 @@ using namespace tir;
  * \param region2 The second region.
  * \return Whether region1 and region2 have intersections.
  */
-bool MayConflict(Region region1, Region region2) {
+bool MayConflict(const Region &region1, const Region &region2) {
   ICHECK(region1.size() == region2.size());
   for (size_t i = 0; i < region1.size(); i++) {
     Range dim1 = region1[i];
@@ -42,7 +44,7 @@ bool MayConflict(Region region1, Region region2) {
 class BufferRegionCollector : public StmtExprVisitor {
 public:
   BufferRegionCollector(Map<Var, Buffer> buffer_data_to_buffer)
-      : buffer_data_to_buffer_(buffer_data_to_buffer) {}
+      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {}
   Array<BufferRegion> GetReads() const { return reads_; }
@@ -182,7 +184,7 @@ private:
  */
 struct PipelineStageInfo {
   Array<BufferRegion> reads, writes;
-  int original_stmt_index;
+  int original_stmt_index{};
   int order = -1, stage = -1;
   bool copy_stage = false;
   bool producer_for_copy = false;
@@ -200,7 +202,7 @@ private:
   PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
     Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
-                /*body*/ stmt);
+                /*body*/ std::move(stmt));
     Array<Array<BufferRegion>> access =
         GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
     auto collector = BufferRegionCollector(buffer_data_to_buffer_);
@@ -555,12 +557,12 @@ private:
   Map<Var, Buffer> buffer_data_to_buffer_;
   Target target_;
-  bool use_async_copy_;
+  bool use_async_copy_{};
 };
 tvm::transform::Pass PipelinePlanning() {
   using namespace tir::transform;
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
     bool use_async_copy =
         ctx->GetConfig<Bool>("tir.use_async_copy", Bool(true)).value();
     PrimFuncNode *fptr = f.CopyOnWrite();
......
@@ -11,6 +11,8 @@
 #include <tvm/tir/transform.h>
 #include <tvm/tir/utils.h>
+#include <utility>
 #include "arith/ir_mutator_with_analyzer.h"
 #include "tir/analysis/control_flow_graph.h"
 #include "tir/analysis/var_use_def_analysis.h"
@@ -22,11 +24,11 @@ using namespace tir;
 using namespace arith;
 struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
-  bool transitively_prove_inequalities;
-  bool propagate_knowns_to_prove_conditional;
-  bool propagate_knowns_to_simplify_expressions;
-  bool convert_boolean_to_and_of_ors;
-  bool apply_constraints_to_boolean_branches;
+  bool transitively_prove_inequalities{};
+  bool propagate_knowns_to_prove_conditional{};
+  bool propagate_knowns_to_simplify_expressions{};
+  bool convert_boolean_to_and_of_ors{};
+  bool apply_constraints_to_boolean_branches{};
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -85,7 +87,7 @@ CollectUsedBuffers(const PrimFunc &func) {
     using StmtExprVisitor::VisitExpr_;
     using StmtExprVisitor::VisitStmt_;
-    Visitor(PrimFunc func) : func(func) {}
+    Visitor(PrimFunc func) : func(std::move(func)) {}
     void VisitExpr_(const CallNode *op) override {
       for (const auto &arg : op->args) {
@@ -215,9 +217,10 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
 class StmtSimplifier : public IRMutatorWithAnalyzer {
 public:
-  static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
-                        Optional<SimplifyConfig> config_opt = std::nullopt,
-                        bool simplify_arguments = false) {
+  static PrimFunc
+  Apply(PrimFunc func, Analyzer *analyzer,
+        const Optional<SimplifyConfig> &config_opt = std::nullopt,
+        bool simplify_arguments = false) {
     auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
     analyzer->rewrite_simplify.SetEnabledExtensions(
         config->GetEnabledExtensions());
@@ -273,9 +276,9 @@ private:
                 Analyzer *analyzer, SimplifyConfig config,
                 std::optional<ControlFlowGraph> touch_pattern,
                 std::unordered_set<const VarNode *> used_in_buffer_def)
-      : IRMutatorWithAnalyzer(analyzer), config_(config),
-        touch_pattern_(touch_pattern), used_in_buffer_def_(used_in_buffer_def) {
-  }
+      : IRMutatorWithAnalyzer(analyzer), config_(std::move(config)),
+        touch_pattern_(std::move(touch_pattern)),
+        used_in_buffer_def_(std::move(used_in_buffer_def)) {}
   using Parent = IRMutatorWithAnalyzer;
   using Parent::VisitExpr_;
@@ -476,10 +479,11 @@ private:
 using namespace tir::transform;
 tvm::transform::Pass Simplify(bool simplify_arguments = true) {
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
     arith::Analyzer analyzer;
     auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
-    return StmtSimplifier::Apply(f, &analyzer, cfg, simplify_arguments);
+    return StmtSimplifier::Apply(std::move(f), &analyzer, cfg,
+                                 simplify_arguments);
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
 }
...
@@ -96,7 +96,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const EvaluateNode *op) {
   curr_stmt_.stmt = op;
   IRVisitorWithAnalyzer::VisitStmt_(op);
   // push to the scope
-  if (curr_stmt_.access.size() != 0) {
+  if (!curr_stmt_.access.empty()) {
     scope_.back().push_back(curr_stmt_);
     curr_stmt_.access.clear();
   }
@@ -185,14 +185,14 @@ void TileLangStorageAccessVisitor::VisitStmt_(const ForNode *op) {
   s.stmt = op;
   s.access = Summarize(std::move(scope_.back()), op);
   scope_.pop_back();
-  if (s.access.size() != 0) {
+  if (!s.access.empty()) {
     // relax the touched set to contain all ranges in the loop.
     std::unordered_map<const VarNode *, arith::IntSet> relax_map;
     relax_map[op->loop_var.get()] =
         arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
     for (AccessEntry &e : s.access) {
       if (e.buffer.defined()) {
-        ICHECK(e.touched.size());
+        ICHECK(!e.touched.empty());
         Array<arith::IntSet> new_touched;
         for (const auto &touched : e.touched) {
           new_touched.push_back(arith::EvalSet(touched, relax_map));
@@ -312,9 +312,12 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
     Array<Range> buffer_ranges;
     // from indices to buffer indices
     ICHECK(buffer->shape.size() == load->indices.size());
+    // Use buffer shape and indices to compute the buffer_ranges for each
+    // dimension.
     for (size_t i = 0; i < buffer->shape.size(); ++i) {
-      buffer_ranges.push_back(
-          Range::FromMinExtent(load->indices[i], buffer->shape[i]));
+      PrimExpr min = load->indices[i];
+      PrimExpr extent = make_const(buffer->shape[i].dtype(), 1);
+      buffer_ranges.push_back(Range::FromMinExtent(min, extent));
     }
     if (Enabled(buffer_var, scope)) {
       ICHECK(allow_append_);
@@ -359,7 +362,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
     auto linear_to_indices = [this](PrimExpr offset,
                                     const Array<PrimExpr> &shape) {
       Array<PrimExpr> indices;
-      PrimExpr remaining = offset;
+      PrimExpr remaining = std::move(offset);
       for (size_t i = 0; i < shape.size(); ++i) {
         PrimExpr stride = make_const(DataType::Int(32), 1);
         for (size_t j = i + 1; j < shape.size(); ++j) {
@@ -417,8 +420,8 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
   }
 }
-Map<Var, Range>
-TileLangStorageAccessVisitor::ComputeThreadRange(Array<IterVar> threads) {
+Map<Var, Range> TileLangStorageAccessVisitor::ComputeThreadRange(
+    const Array<IterVar> &threads) {
   Map<Var, Range> thread_range;
   for (const auto &th : threads) {
     auto thread_tag = th->thread_tag;
@@ -436,7 +439,8 @@ TileLangStorageAccessVisitor::ComputeThreadRange(Array<IterVar> threads) {
   return thread_range;
 }
-StorageScope TileLangStorageAccessVisitor::GetScope(Var buffer_var) const {
+StorageScope
+TileLangStorageAccessVisitor::GetScope(const Var &buffer_var) const {
   if (buffer_var->type_annotation.as<PointerTypeNode>()) {
     return StorageScope::Create(GetPtrStorageScope(buffer_var));
   }
...
......
@@ -49,7 +49,7 @@ using runtime::StorageScope;
 class TileLangStorageAccessVisitor : public IRVisitorWithAnalyzer {
 public:
   /*! \brief Storage access type */
-  enum AccessType {
+  enum AccessType : uint8_t {
     kRead,
     kWrite,
     kSync,
@@ -88,7 +88,7 @@ public:
   /*! \brief Access pattern about a single statement */
   struct StmtEntry {
     /*! \brief The statement */
-    const Object *stmt;
+    const Object *stmt{};
     /*! \brief access patterns in the statement */
     std::vector<AccessEntry> access;
   };
@@ -144,13 +144,13 @@ protected:
   * \param threads The threads to compute the range for.
   * \return The thread range.
   */
-  Map<Var, Range> ComputeThreadRange(Array<IterVar> threads);
+  Map<Var, Range> ComputeThreadRange(const Array<IterVar> &threads);
  /*!
   * \brief Get the scope of the buffer array.
   * \return The scope of the final buffer array.
   */
-  StorageScope GetScope(Var buffer_var) const;
+  StorageScope GetScope(const Var &buffer_var) const;
   // access scope
   std::vector<std::vector<StmtEntry>> scope_;
...
@@ -36,6 +36,7 @@
 #include <map>
 #include <unordered_map>
 #include <unordered_set>
+#include <utility>
 #include "arith/int_operator.h"
 #include "runtime/thread_storage_scope.h"
@@ -95,17 +96,17 @@ static void LegalizeBufferLoadDType(BufferLoadNode *n) {
 class AllocateCollector : public StmtExprVisitor {
 private:
   bool IsDynamicSharedMemory(Var buffer_var) {
-    StorageScope storage_scope =
-        runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+    StorageScope storage_scope = runtime::StorageScope::Create(
+        GetPtrStorageScope(std::move(buffer_var)));
     return storage_scope.rank == runtime::StorageRank::kShared &&
            storage_scope.tag == ".dyn";
   }
   bool IsStaticSharedMemory(Var buffer_var) {
-    StorageScope storage_scope =
-        runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+    StorageScope storage_scope = runtime::StorageScope::Create(
+        GetPtrStorageScope(std::move(buffer_var)));
     return storage_scope.rank == runtime::StorageRank::kShared &&
-           storage_scope.tag == "";
+           storage_scope.tag.empty();
   }
 public:
@@ -143,7 +144,7 @@ public:
 /*! \brief record the touch hist of statement. */
 struct StmtEntry {
   // The statement
-  const Object *stmt;
+  const Object *stmt{};
   // The index in the linear_seq_ to point to end of the nested scope.
   // This is only set to non-zero if stmt is a nested scope.
   // if offset > 0, means this is the begin, the end entry is current_index +
@@ -198,11 +199,11 @@ public:
           << it->second.num_physical_dimensions
           << " physical dimensions, but is accessed as having "
           << op->buffer->axis_separators.size() + 1 << " physical dimensions"
-          << std::endl;
+          << '\n';
     }
     StmtEntry e = scope_.back();
     scope_.pop_back();
-    if (e.touched.size() != 0) {
+    if (!e.touched.empty()) {
       e.stmt = op;
       linear_seq_.push_back(e);
     }
@@ -227,7 +228,7 @@ public:
           << it->second.num_physical_dimensions
           << " physical dimensions, but is accessed as having "
           << op->buffer->axis_separators.size() + 1 << " physical dimensions"
-          << std::endl;
+          << '\n';
     }
   }
@@ -237,7 +238,7 @@ public:
     StmtExprVisitor::VisitStmt_(op);
     StmtEntry e = scope_.back();
     scope_.pop_back();
-    if (e.touched.size() != 0) {
+    if (!e.touched.empty()) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
@@ -345,15 +346,15 @@ public:
     src_ = src;
     result_ = true;
     if (stmt->IsInstance<AttrStmtNode>()) {
-      VisitStmt_(static_cast<const AttrStmtNode *>(stmt));
+      VisitStmt_(reinterpret_cast<const AttrStmtNode *>(stmt));
     } else if (stmt->IsInstance<ForNode>()) {
-      VisitStmt_(static_cast<const ForNode *>(stmt));
+      VisitStmt_(reinterpret_cast<const ForNode *>(stmt));
     } else if (stmt->IsInstance<IfThenElseNode>()) {
-      VisitStmt_(static_cast<const IfThenElseNode *>(stmt));
+      VisitStmt_(reinterpret_cast<const IfThenElseNode *>(stmt));
     } else if (stmt->IsInstance<WhileNode>()) {
-      VisitStmt_(static_cast<const WhileNode *>(stmt));
+      VisitStmt_(reinterpret_cast<const WhileNode *>(stmt));
     } else if (stmt->IsInstance<BufferStoreNode>()) {
-      VisitStmt_(static_cast<const BufferStoreNode *>(stmt));
+      VisitStmt_(reinterpret_cast<const BufferStoreNode *>(stmt));
     } else {
       return false;
     }
...@@ -442,9 +443,9 @@ private: ...@@ -442,9 +443,9 @@ private:
// result of the check // result of the check
bool result_{true}; bool result_{true};
// destination memory // destination memory
const VarNode *dst_; const VarNode *dst_{};
// source variable // source variable
const VarNode *src_; const VarNode *src_{};
// counter of load, // counter of load,
// it is not safe to inplace when there is nested load like A[B[i]] // it is not safe to inplace when there is nested load like A[B[i]]
int mem_nest_{0}; int mem_nest_{0};
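The `{}` initializers added to `dst_` and `src_` address clang-tidy's member-initialization checks: a raw-pointer member left uninitialized by a defaulted constructor is an easy source of undefined reads. A generic sketch of the same pattern (the struct and member names here mirror the diff but are illustrative, not the actual class):

```cpp
#include <cassert>

// In-class {} initializers guarantee the pointers start as nullptr
// even when the struct is default-constructed.
struct InplaceChecker {
  bool result_{true};
  const int *dst_{};  // value-initialized to nullptr
  const int *src_{};  // value-initialized to nullptr
  int mem_nest_{0};
};
```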
@@ -501,7 +502,7 @@ public:
 return node;
 }
-Buffer RemapBuffer(Buffer buf, Var new_backing_array) {
+Buffer RemapBuffer(const Buffer &buf, const Var &new_backing_array) {
 auto key = buf.get();
 auto it = buffer_remap_.find(key);
 if (it != buffer_remap_.end()) {
@@ -641,7 +642,7 @@ private:
 // The physical dimensionality of the allocations. Since
 // StorageRewrite is applied after StorageFlatten/FlattenBuffer,
 // this is size of `AllocateNode::extents`. If moved
-size_t ndim;
+size_t ndim{};
 // Allocs that shares this entry.
 std::vector<const AllocateNode *> allocs;
 // The children of this entry, not including itself.
@@ -671,7 +672,7 @@ private:
 // Checks whether the storage_scope is especially tagged for a specific
 // memory. Special memory is all combined into a single allocation.
 bool IsSpecialTaggedMemory(const StorageScope &scope) {
-return scope.tag.length() != 0 && scope.tag != ".dyn" &&
+return !scope.tag.empty() && scope.tag != ".dyn" &&
 scope.tag != ".barrier" && scope.tag != ".workspace" &&
 scope.tag != ".vtcm";
 }
@@ -729,7 +730,7 @@ private:
 // already merged
 if (e->bits_offset != 0)
 continue;
-if (e->merged_children.size() != 0) {
+if (!e->merged_children.empty()) {
 NewAllocTagMerged(e);
 continue;
 }
@@ -993,7 +994,7 @@ private:
 }
 // enter/exit new scope
 if (s.stmt->IsInstance<AttrStmtNode>()) {
-const auto *op = static_cast<const AttrStmtNode *>(s.stmt);
+const auto *op = reinterpret_cast<const AttrStmtNode *>(s.stmt);
 if (op->attr_key == tir::attr::thread_extent ||
 op->attr_key == tir::attr::virtual_thread ||
 tir::attr::IsPragmaKey(op->attr_key)) {
@@ -1002,7 +1003,7 @@ private:
 ICHECK(op->attr_key == tir::attr::extern_scope);
 }
 } else if (s.stmt->IsInstance<ForNode>()) {
-const auto *op = static_cast<const ForNode *>(s.stmt);
+const auto *op = reinterpret_cast<const ForNode *>(s.stmt);
 if (op->kind == ForKind::kParallel) {
 if (thread_scope_ == nullptr || thread_scope_ == op) {
 PlanNewScope(op);
@@ -1062,7 +1063,7 @@ private:
 // disable reuse of small arrays, they will be lowered to registers in LLVM
 // This rules only apply if we are using non special memory
 bool is_small_array =
-(scope.tag.length() == 0) &&
+(scope.tag.empty()) &&
 (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
 (is_known_size && const_nbits <= 32));
@@ -1134,7 +1135,7 @@ private:
 // disable reuse of small arrays, they will be lowered to registers in LLVM
 // This rules only apply if we are using non special memory
-if (e->scope.tag.length() == 0) {
+if (e->scope.tag.empty()) {
 // Disable sharing of local memory.
 if (e->scope.rank >= StorageRank::kWarp ||
 e->allocs[0]->dtype.is_handle())
@@ -1182,7 +1183,7 @@ private:
 *
 */
 struct BufferVarInfo {
-enum DeclarationLocation {
+enum DeclarationLocation : uint8_t {
 kPrimFuncParam = (1 << 0),
 kPrimFuncBufferMap = (1 << 1),
 kAllocateNode = (1 << 2),
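Giving `DeclarationLocation` an explicit `uint8_t` underlying type (clang-tidy's `performance-enum-size`) shrinks the enum's storage to one byte while the `(1 << n)` enumerators still combine as a bitmask. A standalone sketch with hypothetical names:

```cpp
#include <cassert>
#include <cstdint>

// Fixed uint8_t underlying type keeps the flag enum one byte wide;
// the power-of-two values still OR together as a bitmask.
enum DeclLoc : uint8_t {
  kParam = (1 << 0),
  kBufferMap = (1 << 1),
  kAllocate = (1 << 2),
};

static_assert(sizeof(DeclLoc) == 1, "flag enum fits in one byte");
```

Without the `: uint8_t`, the underlying type is implementation-chosen (typically `int`), so every stored flag field would occupy four bytes.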
@@ -1293,7 +1294,7 @@ public:
 Var buffer_var = buffer->data;
 DataType dtype = buffer->dtype;
 PrimExpr extent =
-buffer->shape.size() ? buffer->shape[buffer->shape.size() - 1] : 0;
+!buffer->shape.empty() ? buffer->shape[buffer->shape.size() - 1] : 0;
 OnArrayDeclaration(buffer_var, dtype, extent,
 BufferVarInfo::kPrimFuncParam);
 }
@@ -1350,7 +1351,7 @@ public:
 void VisitStmt_(const AllocateConstNode *op) final {
 const Array<PrimExpr> &extents = op->extents;
 PrimExpr extent =
-extents.size() ? extents[extents.size() - 1] : NullValue<PrimExpr>();
+!extents.empty() ? extents[extents.size() - 1] : NullValue<PrimExpr>();
 OnArrayDeclaration(op->buffer_var, op->dtype, extent,
 BufferVarInfo::kAllocateConstNode);
@@ -1367,7 +1368,7 @@ public:
 StmtExprVisitor::VisitStmt_(op);
 }
-void HandleLetNode(Var let_var) {
+void HandleLetNode(const Var &let_var) {
 if (let_var->dtype.is_handle()) {
 auto pointer_type = GetPointerType(let_var->type_annotation);
 if (pointer_type.has_value()) {
@@ -1397,7 +1398,7 @@ public:
 * some locations can be rewritten without others.
 */
 void
-OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent,
+OnArrayDeclaration(const Var &buffer, DataType element_dtype, PrimExpr extent,
 BufferVarInfo::DeclarationLocation declaration_location) {
 ICHECK(info_map_.find(buffer.get()) == info_map_.end())
 << "Array declaration of " << buffer->name_hint
@@ -1406,8 +1407,8 @@ public:
 if (element_dtype == DataType::Bool()) {
 element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
 }
-info_map_[buffer.get()] =
-BufferVarInfo{buffer, element_dtype, extent, declaration_location};
+info_map_[buffer.get()] = BufferVarInfo{
+buffer, element_dtype, std::move(extent), declaration_location};
 }
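The `std::move(extent)` above is the standard treatment for a by-value sink parameter: `extent` is taken by value and never used after being stored, so moving it into the aggregate avoids an extra copy (for TVM's `PrimExpr`, a reference-count bump). A generic sketch of the idiom, with hypothetical types and names:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>

struct Info {
  std::string name;
  int dims = 0;
};

// `name` is a by-value sink parameter; std::move steals its buffer
// into the map entry instead of copying it.
void RegisterInfo(std::unordered_map<int, Info> &m, int key,
                  std::string name, int dims) {
  m[key] = Info{std::move(name), dims};
}
```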
 /* Update the type map for a buffer based on its usage
@@ -1452,7 +1453,7 @@ public:
 ICHECK(indices[i].dtype().is_scalar())
 << "Only the last index of a buffer access may be a vector type.";
 }
-int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;
+int index_lanes = !indices.empty() ? indices.back().dtype().lanes() : 1;
 DataType access_dtype = value_dtype;
@@ -1488,7 +1489,7 @@ public:
 // divisible by the number of number of lanes, and the predicate
 // does not apply any masking, then this array access could be
 // vectorized.
-if (indices.size()) {
+if (!indices.empty()) {
 const RampNode *ramp_index = indices[indices.size() - 1].as<RampNode>();
 if (ramp_index && is_one(ramp_index->stride)) {
 if (ramp_index->lanes->IsInstance<IntImmNode>()) {
@@ -1502,7 +1503,7 @@ public:
 }
 }
-if (detect_scalar_read_patterns_ && is_buffer_load && indices.size()) {
+if (detect_scalar_read_patterns_ && is_buffer_load && !indices.empty()) {
 const PrimExpr last_dim_index = indices[indices.size() - 1];
 if (last_dim_index.dtype().lanes() == 1) {
 arith::ModularSet me = analyzer_.modular_set(last_dim_index);
@@ -1910,7 +1911,7 @@ PrimFunc PointerValueTypeRewrite(
 using namespace tir::transform;
 namespace transform {
 Pass StorageRewrite() {
-auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+auto pass_func = [](PrimFunc f, const IRModule &m, PassContext ctx) {
 bool enable_reuse = true;
 bool reuse_require_exact_matched_dtype = false;
 bool merge_static_smem =
@@ -1957,7 +1958,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
 });
 Pass PointerValueTypeRewrite() {
-auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
 return tl::PointerValueTypeRewrite(std::move(f));
 };
 return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {});
...
@@ -30,6 +30,7 @@
 #include <unordered_map>
 #include <unordered_set>
+#include <utility>
 #include "./common/thread_sync_types.h"
 #include "./storage_access.h"
@@ -46,7 +47,7 @@ using arith::IRMutatorWithAnalyzer;
 class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
 public:
 explicit TileLangThreadSyncPlanner(StorageScope sync_scope)
-: sync_scope_(sync_scope) {}
+: sync_scope_(std::move(sync_scope)) {}
 // The syncs inserted before each statement
 std::unordered_set<const Object *> syncs_inserted_;
@@ -404,7 +405,7 @@ private:
 class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator {
 public:
 explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope)
-: sync_scope_(sync_scope) {}
+: sync_scope_(std::move(sync_scope)) {}
 Stmt VisitStmt_(const AttrStmtNode *op) final {
 if (op->attr_key == tvm::tir::attr::async_wait_queue_scope) {
@@ -430,10 +431,10 @@ class ThreadSyncInserter : public StmtExprMutator {
 public:
 ThreadSyncInserter(StorageScope sync_scope,
 const std::unordered_set<const Object *> &syncs)
-: sync_scope_(sync_scope), syncs_(syncs) {}
+: sync_scope_(std::move(sync_scope)), syncs_(syncs) {}
 Stmt VisitStmt(const Stmt &stmt) final {
-if (syncs_.size() == 0)
+if (syncs_.empty())
 return stmt;
 if (syncs_.count(stmt.get())) {
 Stmt barrier;
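The constructors above (`TileLangThreadSyncPlanner`, `ThreadSyncAfterWaitQueueInserter`, `ThreadSyncInserter`) now move their by-value `StorageScope` parameter into the member, which is why the patch also adds `#include <utility>` for `std::move`. A self-contained sketch of the idiom, using `std::string` in place of `StorageScope`:

```cpp
#include <cassert>
#include <string>
#include <utility>

// A by-value constructor parameter that is only stored should be moved,
// not copied, into the member (clang-tidy performance-unnecessary-value-param).
class SyncPlanner {
public:
  explicit SyncPlanner(std::string sync_scope)
      : sync_scope_(std::move(sync_scope)) {}
  const std::string &scope() const { return sync_scope_; }

private:
  std::string sync_scope_;
};
```

Take-by-value-then-move lets a caller passing an rvalue pay for zero copies and a caller passing an lvalue pay for exactly one, without writing two overloads.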
@@ -535,7 +536,7 @@ private:
 // Get current storage scope.
 StorageScope GetScope(Var buffer_var) const {
-return StorageScope::Create(GetPtrStorageScope(buffer_var));
+return StorageScope::Create(GetPtrStorageScope(std::move(buffer_var)));
 }
 // private functions.
@@ -612,10 +613,10 @@ private:
 Stmt VisitStmt_(const EvaluateNode *op) final {
 const CallNode *call = nullptr;
 if (op->value->IsInstance<CallNode>()) {
-call = static_cast<const CallNode *>(op->value.get());
+call = op->value.as<CallNode>();
 if (call->op.same_as(builtin::tvm_storage_sync())) {
 const auto &args = call->args;
-ICHECK(args.size() > 0);
+ICHECK(!args.empty());
 const auto *scope_node = args[0].as<StringImmNode>();
 ICHECK(scope_node != nullptr);
 const std::string &scope = scope_node->value;
@@ -741,11 +742,11 @@ private:
 std::unordered_map<ThreadBoundKey, size_t> thread_count_map_;
 };
-PrimFunc TileLangThreadSync(PrimFunc func, std::string storage_scope) {
+PrimFunc TileLangThreadSync(PrimFunc func, const std::string &storage_scope) {
 StorageScope sync_scope = StorageScope::Create(storage_scope);
 auto *n = func.CopyOnWrite();
 auto stmt = n->body;
-if (sync_scope.rank == StorageRank::kShared && sync_scope.tag == "") {
+if (sync_scope.rank == StorageRank::kShared && sync_scope.tag.empty()) {
 stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
 }
 TileLangThreadSyncPlanner planner(sync_scope);
@@ -764,8 +765,9 @@ using namespace tir::transform;
 namespace transform {
-tvm::transform::Pass ThreadSync(String storage_scope) {
+tvm::transform::Pass ThreadSync(const String &storage_scope) {
-auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
+auto pass_func = [storage_scope](PrimFunc f, const IRModule &m,
+const PassContext &ctx) {
 auto *n = f.CopyOnWrite();
 return tl::TileLangThreadSync(std::move(f), storage_scope);
 ;
...