"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "68ed24b7d7c5134d762dfb5f78b1670f0d7ea863"
Unverified Commit cdc5d8d3 authored by Lei Wang, committed by GitHub

[Lint] Introduce clang-tidy into format.sh (#777)

* [Refactor] Update Clang-Tidy Checks and Improve Code Consistency

- Enhanced .clang-tidy configuration by adding specific checks for better bug detection and performance optimization.
- Refactored function signatures across multiple files to use `const` references for parameters, improving performance and code clarity.
- Updated various methods to ensure consistent handling of parameters, particularly in `AddPredicate`, `Substitute`, and `PlanLoopPartition` functions.
- Improved readability by replacing size checks with `empty()` calls in several locations, making the intent clearer (see the sketch after this list).
- General code cleanup and adherence to best practices for better maintainability.
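
A minimal sketch of the two patterns this commit applies throughout (the function and variable names here are illustrative, not the actual TileLang code): pass read-only parameters by `const` reference, which quiets clang-tidy's `performance-unnecessary-value-param` check, and prefer `empty()` over `size()` comparisons per `readability-container-size-empty`.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Before: `std::vector<std::string> names` copied the argument on every call.
// After: a const reference avoids the copy; clang-tidy's
// performance-unnecessary-value-param check flags the by-value version.
void PrintNames(const std::vector<std::string> &names) {
  // readability-container-size-empty: `!names.empty()` states intent more
  // clearly than `names.size() > 0`.
  if (!names.empty()) {
    for (const std::string &name : names) {
      std::cout << name << '\n';
    }
  }
}

int main() {
  PrintNames({"AddPredicate", "Substitute", "PlanLoopPartition"});
  return 0;
}
```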

* [Refactor] Enhance Code Consistency and Clang-Tidy Configuration

- Updated .clang-tidy configuration to include additional checks for improved code quality and performance.
- Refactored function signatures across multiple files to use `const` references, enhancing performance and clarity.
- Replaced size checks with `empty()` method calls in various locations for clearer intent.
- Improved parameter handling in several functions, ensuring consistent usage of `std::move` where applicable (sketched below).
- General code cleanup to adhere to best practices and improve maintainability.
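
A minimal sketch of the `std::move` convention referenced above, assuming nothing beyond standard C++ (the `ConditionHolder` class is hypothetical): sink parameters are taken by value and moved into members, the shape clang-tidy's `modernize-pass-by-value` check suggests and the shape the mutator constructors in the diff below adopt.

```cpp
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

class ConditionHolder {
 public:
  // Pass-by-value + move: one move when the caller passes an rvalue,
  // one copy plus one move for lvalues, and no accidental double copy.
  ConditionHolder(std::string name, std::vector<int> conditions)
      : name_(std::move(name)), conditions_(std::move(conditions)) {}

  std::size_t NumConditions() const { return conditions_.size(); }

 private:
  std::string name_;
  std::vector<int> conditions_;
};

int main() {
  std::vector<int> conds{1, 2, 3};
  ConditionHolder holder("vectorize", std::move(conds));  // moves, no copy
  return holder.NumConditions() == 3 ? 0 : 1;
}
```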

* [Refactor] Integrate Clang-Tidy Checks and Enhance Code Consistency

- Added clang-tidy checks to the format script for improved code quality assurance.
- Refactored function signatures across multiple files to consistently use `const` references, enhancing performance and clarity.
- Updated the requirements-lint.txt file to include clang-tidy as a dependency.
- General code cleanup to adhere to best practices and improve maintainability.

* [CI] Update AMD CI Workflow to Include Build Directory Creation

- Added steps to create a build directory and configure CMake with ROCm support during the format check process.
- Ensured cleanup of the build directory after the format check to maintain a clean workspace.

* [Refactor] Remove Unused Member Variables in AtomicAddNode and CopyNode

- Removed the `args_` member variable from both `AtomicAddNode` and `CopyNode` classes to streamline the code and eliminate unnecessary data members.
- This change enhances code clarity and maintainability by focusing on relevant attributes for each class.

* [Refactor] Update Clang-Tidy Integration and Code Improvements

- Modified the format script to include the `-fix` option in the clang-tidy command for automatic code fixes.
- Refactored the `AtomicAddVectorizePlanner` class to improve variable handling and consistency, including changes to member variable types and function signatures.
- Enhanced code clarity by removing unnecessary `std::move` calls and ensuring consistent usage of types across the class (see the sketch after this list).
- General code cleanup to adhere to best practices and improve maintainability.
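
For contrast with the constructor example above, a hedged sketch of `std::move` calls that are unnecessary and get removed by this kind of cleanup (illustrative code, not taken from the planner itself); clang-tidy reports these via `performance-move-const-arg` and `readability-redundant-move`.

```cpp
#include <string>
#include <utility>

struct Plan {
  int vector_size = 1;
};

std::string MakeName(std::string base) {
  base += "_vectorized";
  // Returning a by-value parameter already performs an implicit move;
  // wrapping it in std::move is redundant (readability-redundant-move).
  return base;
}

int ReadSize(const Plan &plan) {
  // std::move(plan.vector_size) would still copy: int is trivially copyable
  // and plan is const, so the "move" is a no-op that only obscures intent
  // (performance-move-const-arg).
  return plan.vector_size;
}

int main() {
  Plan plan;
  return MakeName("inner").empty() ? 1 : ReadSize(plan) - 1;
}
```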

* [Refactor] Improve Parameter Handling and Consistency in AtomicAddVectorize

- Updated function signatures in `AtomicAddVectorizePlanResult` and `AtomicAddVectorizeRewriter` to use `const` references and `std::move` for better performance and clarity.
- Enhanced the `UpdateVectorSize` method to accept `const Array<PrimExpr>&` for improved efficiency.
- General code cleanup to maintain consistency and adhere to best practices.

* [CI] Add Git Submodule Initialization to CI Workflow

- Included a step to initialize and update git submodules recursively in the CI workflow.
- This change ensures that all necessary submodules are available during the format check process, improving build reliability.

* [CI] Add Git Submodule Update Step to Format Check

- Included a command to initialize and update git submodules recursively in the CI workflow during the format check process.
- This enhancement ensures that all required submodules are available, contributing to improved build reliability.

* [Refactor] Update Function Signatures in AtomicAddVectorize

- Modified the `VectorizeAtomicAdd` function signature to use `const` references for `thread_var` and `thread_bounds`, enhancing performance and code clarity.
- This change aligns with previous refactoring efforts to improve parameter handling and consistency across the codebase.
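
One more pattern that recurs in the diff below is value-initializing raw members with `{}` (e.g. `const ForNode *inner_for_{};`, `int extent{};`, `bool need_warp_shuffle_mask_{};`). A minimal sketch of why, with a hypothetical visitor class standing in for the real ones; without the initializer these members hold indeterminate values until the traversal assigns them, which clang-tidy's `cppcoreguidelines-pro-type-member-init` check reports.

```cpp
#include <cassert>

struct ForNode;  // opaque placeholder for the real TIR node type

class LoopVisitor {
 public:
  bool SawLoop() const { return inner_for_ != nullptr; }

 private:
  // `{}` value-initializes: nullptr / 0 / false instead of garbage,
  // so the members are safe to read before any statement is visited.
  const ForNode *inner_for_{};
  int vector_size_{};
  bool has_nonlocal_memory_access_{};
};

int main() {
  LoopVisitor visitor;
  assert(!visitor.SawLoop());  // well-defined even before any traversal
  return 0;
}
```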
parent 471cc7f8
......@@ -12,6 +12,7 @@
#include <tvm/tir/stmt_functor.h>
#include <numeric>
#include <utility>
#include "../layout/layout.h"
#include "../layout/utils.h"
......@@ -32,7 +33,8 @@ struct VectorizePlanResult {
PrimExpr condition;
};
bool IndiceCanVectorizeDynamic(PrimExpr expr, Var var, PrimExpr iter_var_size,
bool IndiceCanVectorizeDynamic(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size,
arith::Analyzer *analyzer) {
ICHECK(target_vectorized_size >= 1);
......@@ -136,7 +138,7 @@ private:
// TODO: may perform some checks here
}
void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
if (!inner_for_)
return;
auto extent_ptr = inner_for_->extent.as<IntImmNode>();
......@@ -198,7 +200,7 @@ private:
int vector_size_;
const ForNode *inner_for_;
const ForNode *inner_for_{};
Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false;
// conditionally vectorize
......@@ -210,8 +212,8 @@ class VectorizedBodyMutator : public StmtExprMutator {
public:
VectorizedBodyMutator(Var inner_var, int vector_size,
std::vector<PrimExpr> conditions)
: inner_var_(inner_var), vector_size_(vector_size),
conditions_(conditions) {}
: inner_var_(std::move(inner_var)), vector_size_(vector_size),
conditions_(std::move(conditions)) {}
private:
PrimExpr VisitExpr_(const CallNode *op) final {
......@@ -244,7 +246,7 @@ private:
class VectorizedConditionExtracter : public StmtExprVisitor {
public:
VectorizedConditionExtracter() = default;
std::vector<PrimExpr> GetConditions(Stmt body) {
std::vector<PrimExpr> GetConditions(const Stmt &body) {
this->VisitStmt(body);
return conditions_;
}
......@@ -269,7 +271,7 @@ private:
class NestedLoopChecker : public StmtExprVisitor {
public:
NestedLoopChecker() : loop_num_(0) {}
int GetNestLoopNum(Stmt body) {
int GetNestLoopNum(const Stmt &body) {
this->VisitStmt(body);
return loop_num_;
}
......@@ -286,7 +288,7 @@ private:
class VectorizedConditionMutator : public StmtExprMutator {
public:
VectorizedConditionMutator(Var inner_var, int extent)
: inner_var_(inner_var), vector_size_(extent) {}
: inner_var_(std::move(inner_var)), vector_size_(extent) {}
private:
PrimExpr VisitExpr_(const GENode *node) final {
......@@ -343,7 +345,7 @@ private:
class VectorizeRewriterDynamic : public StmtExprMutator {
public:
VectorizeRewriterDynamic(VectorizePlanResult plan,
VectorizeRewriterDynamic(const VectorizePlanResult &plan,
bool disable_dynamic_tail_split)
: vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic),
......@@ -396,7 +398,7 @@ private:
// Adaptively set vectorized variable to the min/max value of the extent
PrimExpr condition_bound;
if (conditions.size() > 0) {
if (!conditions.empty()) {
condition_bound = condition_mutator(conditions[0]);
for (int i = 1; i < conditions.size(); ++i) {
condition_bound = condition_bound && condition_mutator(conditions[i]);
......@@ -413,7 +415,7 @@ private:
For vectorize_for =
For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body);
For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body);
if (conditions.size() > 0) {
if (!conditions.empty()) {
body = IfThenElse(condition_bound, vectorize_for, serial_for);
} else {
body = vectorize_for;
......@@ -436,7 +438,7 @@ private:
}
}
const ForNode *inner_for_;
const ForNode *inner_for_{};
int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
......@@ -484,7 +486,6 @@ private:
// non-vectorized loop
return for_node;
}
int vectorize_hint = res.vector_size;
auto rewriter = VectorizeRewriterDynamic(res, disable_dynamic_tail_split_);
return Downcast<For>(rewriter(for_node));
}
......@@ -509,7 +510,7 @@ public:
tvm::transform::Pass LoopVectorizeDynamic() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
bool disable_dynamic_tail_split =
ctx->GetConfig<Bool>(kDisableDynamicTailSplit, Bool(true)).value();
int dynamic_alignment =
......
......@@ -356,7 +356,7 @@ namespace transform {
tvm::transform::Pass LowerDeviceKernelLaunch() {
auto pass_func = [](IRModule mod,
tir::transform::PassContext ctx) -> IRModule {
const tir::transform::PassContext &ctx) -> IRModule {
auto mutator = [&mod]() {
std::unordered_map<const GlobalVarNode *, KernelInfo> device_info_map;
for (const auto &[gvar, base_func] : mod->functions) {
......@@ -380,7 +380,7 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
}
}
if (updates->functions.size()) {
if (!updates->functions.empty()) {
mod.CopyOnWrite()->Update(updates);
}
}
......@@ -396,7 +396,7 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
}
}
if (updates->functions.size()) {
if (!updates->functions.empty()) {
mod.CopyOnWrite()->Update(updates);
}
}
......
......@@ -44,7 +44,7 @@ class StorageAccessInfoLower : public StmtExprMutator {
public:
Stmt VisitStmt_(const AllocateNode *op) final {
auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var));
if (scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".var" &&
if (!scope.tag.empty() && scope.tag != ".dyn" && scope.tag != ".var" &&
scope.tag != ".barrier") {
auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var));
ICHECK(info.defined())
......@@ -105,8 +105,8 @@ private:
return AddressOffset(buffer_var, dtype, offset);
}
PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var,
DataType dtype, PrimExpr offset,
PrimExpr MakeTaggedAccessPtr(DataType ptr_type, const Var &buffer_var,
DataType dtype, const PrimExpr &offset,
const MemoryInfo &info) {
if (ptr_type.is_handle()) {
ICHECK(info->head_address.defined())
......@@ -134,7 +134,7 @@ namespace transform {
using namespace tir::transform;
Pass LowerDeviceStorageAccessInfo() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
auto *n = f.CopyOnWrite();
n->body = StorageAccessInfoLower()(std::move(n->body));
return f;
......
......@@ -26,7 +26,7 @@ public:
LowerHopperIntrin substituter(disable_shuffle_elect);
fptr->body = substituter.VisitStmt(f->body);
Map<String, Array<PrimExpr>> init_desc_arg_map;
for (auto [call, var] : substituter.desc_map_) {
for (const auto &[call, var] : substituter.desc_map_) {
// Should allocate 128 bytes for TensorMap on stack
Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
{StringImm("arg_value"), 16});
......@@ -117,7 +117,7 @@ public:
}
return var;
} else if (call->op.same_as(create_list_of_mbarrier())) {
ICHECK(init_mbarrier_calls_.size() == 0);
ICHECK(init_mbarrier_calls_.empty());
int num_barriers = static_cast<int>(call->args.size());
for (int i = 0; i < num_barriers; i++) {
PrimExpr mbarrier = Call(DataType::Handle(), get_mbarrier(), {i});
......@@ -143,7 +143,7 @@ private:
using namespace tir::transform;
tvm::transform::Pass LowerHopperIntrin() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
bool disable_shuffle_elect =
ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();
return LowerHopperIntrin::Substitute(f, disable_shuffle_elect);
......
......@@ -47,7 +47,7 @@ public:
l2_persistent_arguments.push_back(size_in_bytes);
init_l2_persistent_map.Set(buffer->name, l2_persistent_arguments);
}
if (init_l2_persistent_map.size() > 0) {
if (!init_l2_persistent_map.empty()) {
f = WithAttr(std::move(f), attr::kL2PersistentMap,
init_l2_persistent_map);
}
......@@ -92,7 +92,7 @@ private:
using namespace tir::transform;
tvm::transform::Pass LowerL2Persistent() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return LowerL2Persistent::Substitute(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {});
......
......@@ -25,6 +25,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <utility>
#include "tir/transforms/ir_utils.h"
namespace tvm {
......@@ -144,8 +146,8 @@ private:
}
static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var,
String thread_tag, Stmt body) {
IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent),
const String &thread_tag, Stmt body) {
IterVar iter_var(/*dom=*/Range::FromMinExtent(std::move(min), extent),
/*var=*/std::move(var),
/*iter_type=*/IterVarType::kThreadIndex,
/*thread_tag=*/thread_tag);
......@@ -223,7 +225,7 @@ PrimFunc TLLowerOpaqueBlock(PrimFunc f) {
tir::transform::Pass LowerOpaqueBlock() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return TLLowerOpaqueBlock(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {});
......
......@@ -13,6 +13,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <utility>
namespace tvm {
namespace tl {
......@@ -22,7 +24,7 @@ class SharedBarrierRewriter : public StmtExprMutator {
public:
static Stmt Rewrite(Stmt body, bool disable_shuffle_elect = false) {
SharedBarrierRewriter rewriter(disable_shuffle_elect);
return rewriter(body);
return rewriter(std::move(body));
}
private:
......@@ -43,7 +45,7 @@ private:
Array<Buffer> barrier_buffers;
for (auto [data, buffer] : buffer_map_) {
for (const auto &[data, buffer] : buffer_map_) {
const auto *ptr_type =
buffer->data->type_annotation.as<PointerTypeNode>();
auto storage_scope = ptr_type->storage_scope;
......@@ -53,7 +55,7 @@ private:
}
}
if (barrier_buffers.size() == 0) {
if (barrier_buffers.empty()) {
return StmtExprMutator::VisitStmt_(op);
}
......@@ -189,7 +191,7 @@ namespace transform {
using namespace tir::transform;
tvm::transform::Pass LowerSharedBarrier() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
bool disable_shuffle_elect =
ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();
return tl::LowerSharedBarrier(std::move(f), disable_shuffle_elect);
......
......@@ -30,6 +30,7 @@
#include <tvm/tir/transform.h>
#include <unordered_set>
#include <utility>
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
......@@ -49,17 +50,17 @@ class AllocateCollector : public StmtExprVisitor {
private:
bool IsDynamicSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
StorageScope storage_scope = runtime::StorageScope::Create(
GetPtrStorageScope(std::move(buffer_var)));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn";
}
bool IsStaticSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
StorageScope storage_scope = runtime::StorageScope::Create(
GetPtrStorageScope(std::move(buffer_var)));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == "";
storage_scope.tag.empty();
}
public:
......@@ -175,7 +176,6 @@ public:
}
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
op = load.get();
if (auto opt = GetRemappedBuffer(load->buffer)) {
load.CopyOnWrite()->buffer = opt.value();
......@@ -197,7 +197,7 @@ private:
struct ThreadEntry {
runtime::ThreadScope scope;
IterVar iv;
int extent;
int extent{};
// comparator
bool operator<(const ThreadEntry &other) const {
return scope.dim_index < other.scope.dim_index;
......@@ -532,7 +532,7 @@ private:
// Fix all local allocations as all statements are built.
Stmt body = SeqStmt::Flatten(seq);
for (Buffer buf : new_alloc_bufs) {
for (const Buffer &buf : new_alloc_bufs) {
body = DeclBuffer(buf, body);
body = Allocate(buf->data, buf->dtype, buf->shape,
const_true(buf->dtype.lanes()), body);
......@@ -542,12 +542,13 @@ private:
}
std::pair<std::vector<PrimExpr>, std::vector<Buffer>>
MakeWarpAllreduce(std::vector<PrimExpr> src_values, //
std::vector<DataType> dtypes, //
const CommReducerNode *combiner, //
PrimExpr reduce_index, int reduce_extent, //
PrimExpr group_index, //
PrimExpr mask, Optional<PrimExpr> predicate, //
MakeWarpAllreduce(std::vector<PrimExpr> src_values, //
std::vector<DataType> dtypes, //
const CommReducerNode *combiner, //
const PrimExpr &reduce_index, int reduce_extent, //
const PrimExpr &group_index, //
const PrimExpr &mask,
const Optional<PrimExpr> &predicate, //
std::vector<Stmt> *seq) {
int n_buffers = src_values.size();
......@@ -785,7 +786,7 @@ private:
int *out_total_extent) {
int &total_extent = *out_total_extent;
total_extent = 1;
if (tvec.size() == 0) {
if (tvec.empty()) {
return make_zero(DataType::Int(32));
}
......@@ -802,7 +803,7 @@ private:
return ret;
}
// The local buffer index.
PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index,
PrimExpr BufIndex(PrimExpr reduce_index, const PrimExpr &group_index,
int reduce_extent) {
if (!is_zero(group_index)) {
return analyzer_.Simplify(group_index * reduce_extent + reduce_index);
......@@ -817,8 +818,8 @@ private:
}
// Emit warp shuffle calls.
PrimExpr WarpShuffle(const Op &op, Optional<Buffer> mask_buffer, PrimExpr val,
PrimExpr delta_or_lane) {
PrimExpr WarpShuffle(const Op &op, const Optional<Buffer> &mask_buffer,
const PrimExpr &val, PrimExpr delta_or_lane) {
Array<PrimExpr> indices = {0};
PrimExpr mask;
if (mask_buffer.defined()) {
......@@ -827,7 +828,7 @@ private:
mask = IntImm(DataType::Int(32), 0);
}
PrimExpr width = IntImm(DataType::Int(32), warp_size_);
Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
Array<PrimExpr> args{mask, val, std::move(delta_or_lane), width, width};
return Call(val.dtype(), op, args);
}
......@@ -904,7 +905,7 @@ private:
// The maximum number of threads of the device. "-1" denotes unknown.
int max_num_threads_{-1};
// A boolean indicating if the target supports warp-level masking.
bool need_warp_shuffle_mask_;
bool need_warp_shuffle_mask_{};
// surrounding scope of thread extent.
std::vector<const AttrStmtNode *> thread_extents_;
......@@ -925,7 +926,7 @@ namespace transform {
using namespace tir::transform;
tvm::transform::Pass LowerThreadAllreduce() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
AllocateCollector collector;
collector(f->body);
bool is_dynamic = collector.dyn_shmem_allocs_.size() > 1;
......
......@@ -79,7 +79,7 @@ public:
void Clear() { buffer_var_gemm_.clear(); }
void Collect(Stmt stmt) { VisitStmt(stmt); }
void Collect(const Stmt &stmt) { VisitStmt(stmt); }
Array<Var> GetBufferVarGemm() { return buffer_var_gemm_; }
......@@ -133,7 +133,7 @@ public:
* remapping. \param stmt The statement to rewrite. \param buffer_remap A map
* from old buffers to new buffers. \return The rewritten statement.
*/
static Stmt Substitute(Stmt stmt, Map<Buffer, Buffer> buffer_remap) {
static Stmt Substitute(const Stmt &stmt, Map<Buffer, Buffer> buffer_remap) {
arith::Analyzer analyzer;
RemapBufferRewriter substituter(&analyzer);
substituter.buffer_remap_ = std::move(buffer_remap);
......@@ -279,7 +279,7 @@ private:
return block;
}
int CheckAndGetBufferRowSize(Buffer buffer) {
int CheckAndGetBufferRowSize(const Buffer &buffer) {
CHECK(buffer->shape.size() >= 2)
<< "The dimension of Buffer \"" << buffer->name << "\" with shape "
<< buffer->shape << " should be at least 2";
......@@ -289,9 +289,10 @@ private:
return buffer_row_size;
}
PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr,
Optional<PrimExpr> offset = std::nullopt,
DataType dtype = DataType::Int(32)) {
PrimExpr
HandleAccessPtrAndOffset(const PrimExpr &access_ptr,
const Optional<PrimExpr> &offset = std::nullopt,
DataType dtype = DataType::Int(32)) {
// The 2nd arg of the T.tvm_access_ptr call is the offset; we set it to 0
// and accumulate it into smem_offset
CHECK(access_ptr->IsInstance<CallNode>())
......@@ -569,7 +570,7 @@ namespace transform {
using namespace tir::transform;
tvm::transform::Pass LowerTileOp() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return LowerTileOpPass::Substitute(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
......
......@@ -48,7 +48,7 @@ namespace {
class ReturnRewriter : public StmtMutator {
public:
explicit ReturnRewriter(Var ret_var, Var ret_tcode)
: ret_var_(ret_var), ret_tcode_(ret_tcode) {}
: ret_var_(std::move(ret_var)), ret_tcode_(std::move(ret_tcode)) {}
Stmt VisitStmt_(const ForNode *node) override {
if (node->kind == ForKind::kParallel)
......@@ -82,7 +82,7 @@ private:
Buffer dummy_tcode_buffer;
};
ConvertedInfo ConvertForFFI(PrimExpr val) {
ConvertedInfo ConvertForFFI(const PrimExpr &val) {
ConvertedInfo info;
// convert val's data type to FFI data type, return type code
......@@ -124,7 +124,7 @@ private:
return info;
}
Stmt WriteToOut(PrimExpr val) {
Stmt WriteToOut(const PrimExpr &val) {
auto info = ConvertForFFI(val);
Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0});
Stmt store_tcode =
......@@ -142,8 +142,8 @@ private:
};
Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
ReturnRewriter rewriter(ret_var, ret_tcode);
return rewriter(body);
ReturnRewriter rewriter(std::move(ret_var), std::move(ret_tcode));
return rewriter(std::move(body));
}
class SubroutineCallRewriter : public StmtExprMutator {
......@@ -151,7 +151,7 @@ public:
static Optional<Stmt> Apply(const Map<GlobalVar, String> &packed_func_methods,
Stmt stmt) {
SubroutineCallRewriter rewriter(packed_func_methods);
stmt = rewriter.VisitStmt(std::move(stmt));
stmt = rewriter.VisitStmt(stmt);
if (rewriter.made_change_) {
return stmt;
} else {
......@@ -192,12 +192,13 @@ private:
} // namespace
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, const std::string &msg) {
return AssertStmt(std::move(lhs) == std::move(rhs), tvm::tir::StringImm(msg),
Evaluate(0));
}
inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr});
inline Stmt MakeAssertNotNull(PrimExpr ptr, const std::string &msg) {
Call isnull(DataType::Bool(), builtin::isnullptr(), {std::move(ptr)});
return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0));
}
......@@ -472,7 +473,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
tvm::transform::Pass MakePackedAPI() {
using tvm::transform::Pass;
auto pass_func = [](IRModule mod, tvm::transform::PassContext ctx) {
auto pass_func = [](IRModule mod, const tvm::transform::PassContext &ctx) {
Map<GlobalVar, String> packed_func_methods;
for (const auto &[gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
......@@ -504,7 +505,7 @@ tvm::transform::Pass MakePackedAPI() {
}
}
if (updates->functions.size()) {
if (!updates->functions.empty()) {
mod.CopyOnWrite()->Update(updates);
}
return mod;
......
......@@ -92,7 +92,7 @@ private:
using namespace tir::transform;
tvm::transform::Pass MergeIfStmt() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return MergeIfStmtRewriter::Substitute(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {});
......
......@@ -33,6 +33,7 @@
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "../op/builtin.h"
#include "../target/utils.h"
......@@ -51,16 +52,16 @@ using runtime::StorageScope;
static bool IsDynamicSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
runtime::StorageScope::Create(GetPtrStorageScope(std::move(buffer_var)));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn";
}
static bool IsStaticSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
runtime::StorageScope::Create(GetPtrStorageScope(std::move(buffer_var)));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == "";
storage_scope.tag.empty();
}
/*!
......@@ -106,7 +107,7 @@ public:
/*! \brief record the touch list of statement. */
struct StmtEntry {
// The statement
const Object *stmt;
const Object *stmt{};
// The index in the linear_seq_ to point to end of the nested scope.
// This is only set to non-zero if stmt is a nested scope.
// if offset > 0, means this is the begin, the end entry is current_index +
......@@ -167,7 +168,7 @@ public:
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.touched.size() != 0) {
if (!e.touched.empty()) {
e.stmt = op;
UpdateStmtAttr(op, scope_level_);
linear_seq_.push_back(e);
......@@ -180,7 +181,7 @@ public:
StmtExprVisitor::VisitStmt_(op);
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.touched.size() != 0) {
if (!e.touched.empty()) {
e.stmt = op;
UpdateStmtAttr(op, scope_level_);
linear_seq_.push_back(e);
......@@ -602,7 +603,7 @@ private:
}
}
PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
PrimExpr GetBufferOffset(const Var &buffer_var, DataType dtype) {
auto it = buffer_byte_offsets_.find(buffer_var.get());
ICHECK(it != buffer_byte_offsets_.end())
<< "buffer_var = " << buffer_var->name_hint << ", dtype = " << dtype;
......@@ -750,8 +751,8 @@ private:
std::vector<StmtEntry> gen_kill_seq;
for (const auto &stmt_entry : seq) {
// if has gen and kill, add to gen_kill_seq
if (event_map_[stmt_entry.stmt].gen.size() > 0 ||
event_map_[stmt_entry.stmt].kill.size() > 0) {
if (!event_map_[stmt_entry.stmt].gen.empty() ||
!event_map_[stmt_entry.stmt].kill.empty()) {
gen_kill_seq.push_back(stmt_entry);
}
}
......@@ -1124,8 +1125,8 @@ namespace transform {
Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false,
int align_bytes = 16) {
auto pass_func = [enable_aggressive_merge,
align_bytes](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [enable_aggressive_merge, align_bytes](
PrimFunc f, const IRModule &m, PassContext ctx) {
bool merge_static_smem =
ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
bool debug_merge_shared_memory_allocations =
......
......@@ -10,6 +10,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <utility>
#include "../op/builtin.h"
namespace tvm {
......@@ -17,12 +19,12 @@ namespace tl {
using namespace tir;
enum class Role { kConsumer, kProducer, kBoth };
enum class Role : uint8_t { kConsumer, kProducer, kBoth };
class WarpSpecializedRoleMarker_ : public StmtVisitor {
public:
WarpSpecializedRoleMarker_(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {}
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {}
Role GetRole(const StmtNode *stmt) const {
auto it = map_.find(stmt);
......@@ -135,8 +137,8 @@ public:
private:
MultiVersionBufferRewriter() = default;
Array<Buffer> GetVersionedBuffers(Array<Stmt> seq_stmt,
Array<Buffer> scoped_buffers) {
Array<Buffer> GetVersionedBuffers(const Array<Stmt> &seq_stmt,
const Array<Buffer> &scoped_buffers) {
std::vector<Role> roles;
Array<Array<BufferRegion>> reads, writes;
auto marker = WarpSpecializedRoleMarker_(buffer_data_to_buffer_);
......@@ -145,8 +147,8 @@ private:
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/"", /*body*/ stmt);
auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
reads.push_back(std::move(access[0]));
writes.push_back(std::move(access[1]));
reads.push_back(access[0]);
writes.push_back(access[1]);
roles.push_back(marker.GetRole(stmt));
}
......@@ -173,7 +175,7 @@ private:
static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (new_buffer->strides.size()) {
if (!new_buffer->strides.empty()) {
ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
......@@ -277,10 +279,12 @@ private:
}
PrimExpr RewriteBufferAccess(const Call &call,
const std::vector<int> arg_indices) {
const std::vector<int> &arg_indices) {
auto product = [](const Array<PrimExpr> &input) {
return foldl(
[](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
[](PrimExpr a, PrimExpr b, Span span) {
return mul(std::move(a), std::move(b), std::move(span));
},
make_const(DataType::Int(32), 1), input);
};
Array<PrimExpr> new_args = call->args;
......@@ -316,7 +320,7 @@ private:
using namespace tir::transform;
tvm::transform::Pass MultiVersionBuffer() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return MultiVersionBufferRewriter::Substitute(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {});
......
......@@ -53,7 +53,7 @@ private:
using namespace tir::transform;
tvm::transform::Pass PersistThreadblock() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return PersistThreadblock::Substitute(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {});
......
......@@ -5,6 +5,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <utility>
#include "../target/utils.h"
#include "tvm/ir/expr.h"
......@@ -19,7 +21,7 @@ using namespace tir;
* \param region2 The second region.
* \return Whether region1 and region2 have intersections.
*/
bool MayConflict(Region region1, Region region2) {
bool MayConflict(const Region &region1, const Region &region2) {
ICHECK(region1.size() == region2.size());
for (size_t i = 0; i < region1.size(); i++) {
Range dim1 = region1[i];
......@@ -42,7 +44,7 @@ bool MayConflict(Region region1, Region region2) {
class BufferRegionCollector : public StmtExprVisitor {
public:
BufferRegionCollector(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {}
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {}
Array<BufferRegion> GetReads() const { return reads_; }
......@@ -182,7 +184,7 @@ private:
*/
struct PipelineStageInfo {
Array<BufferRegion> reads, writes;
int original_stmt_index;
int original_stmt_index{};
int order = -1, stage = -1;
bool copy_stage = false;
bool producer_for_copy = false;
......@@ -200,7 +202,7 @@ private:
PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ stmt);
/*body*/ std::move(stmt));
Array<Array<BufferRegion>> access =
GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto collector = BufferRegionCollector(buffer_data_to_buffer_);
......@@ -555,12 +557,12 @@ private:
Map<Var, Buffer> buffer_data_to_buffer_;
Target target_;
bool use_async_copy_;
bool use_async_copy_{};
};
tvm::transform::Pass PipelinePlanning() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
bool use_async_copy =
ctx->GetConfig<Bool>("tir.use_async_copy", Bool(true)).value();
PrimFuncNode *fptr = f.CopyOnWrite();
......
......@@ -11,6 +11,8 @@
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <utility>
#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/control_flow_graph.h"
#include "tir/analysis/var_use_def_analysis.h"
......@@ -22,11 +24,11 @@ using namespace tir;
using namespace arith;
struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
bool transitively_prove_inequalities;
bool propagate_knowns_to_prove_conditional;
bool propagate_knowns_to_simplify_expressions;
bool convert_boolean_to_and_of_ors;
bool apply_constraints_to_boolean_branches;
bool transitively_prove_inequalities{};
bool propagate_knowns_to_prove_conditional{};
bool propagate_knowns_to_simplify_expressions{};
bool convert_boolean_to_and_of_ors{};
bool apply_constraints_to_boolean_branches{};
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -85,7 +87,7 @@ CollectUsedBuffers(const PrimFunc &func) {
using StmtExprVisitor::VisitExpr_;
using StmtExprVisitor::VisitStmt_;
Visitor(PrimFunc func) : func(func) {}
Visitor(PrimFunc func) : func(std::move(func)) {}
void VisitExpr_(const CallNode *op) override {
for (const auto &arg : op->args) {
......@@ -215,9 +217,10 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
Optional<SimplifyConfig> config_opt = std::nullopt,
bool simplify_arguments = false) {
static PrimFunc
Apply(PrimFunc func, Analyzer *analyzer,
const Optional<SimplifyConfig> &config_opt = std::nullopt,
bool simplify_arguments = false) {
auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
analyzer->rewrite_simplify.SetEnabledExtensions(
config->GetEnabledExtensions());
......@@ -273,9 +276,9 @@ private:
Analyzer *analyzer, SimplifyConfig config,
std::optional<ControlFlowGraph> touch_pattern,
std::unordered_set<const VarNode *> used_in_buffer_def)
: IRMutatorWithAnalyzer(analyzer), config_(config),
touch_pattern_(touch_pattern), used_in_buffer_def_(used_in_buffer_def) {
}
: IRMutatorWithAnalyzer(analyzer), config_(std::move(config)),
touch_pattern_(std::move(touch_pattern)),
used_in_buffer_def_(std::move(used_in_buffer_def)) {}
using Parent = IRMutatorWithAnalyzer;
using Parent::VisitExpr_;
......@@ -476,10 +479,11 @@ private:
using namespace tir::transform;
tvm::transform::Pass Simplify(bool simplify_arguments = true) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
arith::Analyzer analyzer;
auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
return StmtSimplifier::Apply(f, &analyzer, cfg, simplify_arguments);
return StmtSimplifier::Apply(std::move(f), &analyzer, cfg,
simplify_arguments);
};
return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
}
......
......@@ -96,7 +96,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const EvaluateNode *op) {
curr_stmt_.stmt = op;
IRVisitorWithAnalyzer::VisitStmt_(op);
// push to the scope
if (curr_stmt_.access.size() != 0) {
if (!curr_stmt_.access.empty()) {
scope_.back().push_back(curr_stmt_);
curr_stmt_.access.clear();
}
......@@ -185,14 +185,14 @@ void TileLangStorageAccessVisitor::VisitStmt_(const ForNode *op) {
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), op);
scope_.pop_back();
if (s.access.size() != 0) {
if (!s.access.empty()) {
// relax the touched set to contain all ranges in the loop.
std::unordered_map<const VarNode *, arith::IntSet> relax_map;
relax_map[op->loop_var.get()] =
arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
for (AccessEntry &e : s.access) {
if (e.buffer.defined()) {
ICHECK(e.touched.size());
ICHECK(!e.touched.empty());
Array<arith::IntSet> new_touched;
for (const auto &touched : e.touched) {
new_touched.push_back(arith::EvalSet(touched, relax_map));
......@@ -312,9 +312,12 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
Array<Range> buffer_ranges;
// from indices to buffer indices
ICHECK(buffer->shape.size() == load->indices.size());
// Use buffer shape and indices to compute the buffer_ranges for each
// dimension.
for (size_t i = 0; i < buffer->shape.size(); ++i) {
buffer_ranges.push_back(
Range::FromMinExtent(load->indices[i], buffer->shape[i]));
PrimExpr min = load->indices[i];
PrimExpr extent = make_const(buffer->shape[i].dtype(), 1);
buffer_ranges.push_back(Range::FromMinExtent(min, extent));
}
if (Enabled(buffer_var, scope)) {
ICHECK(allow_append_);
......@@ -359,7 +362,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
auto linear_to_indices = [this](PrimExpr offset,
const Array<PrimExpr> &shape) {
Array<PrimExpr> indices;
PrimExpr remaining = offset;
PrimExpr remaining = std::move(offset);
for (size_t i = 0; i < shape.size(); ++i) {
PrimExpr stride = make_const(DataType::Int(32), 1);
for (size_t j = i + 1; j < shape.size(); ++j) {
......@@ -417,8 +420,8 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
}
}
Map<Var, Range>
TileLangStorageAccessVisitor::ComputeThreadRange(Array<IterVar> threads) {
Map<Var, Range> TileLangStorageAccessVisitor::ComputeThreadRange(
const Array<IterVar> &threads) {
Map<Var, Range> thread_range;
for (const auto &th : threads) {
auto thread_tag = th->thread_tag;
......@@ -436,7 +439,8 @@ TileLangStorageAccessVisitor::ComputeThreadRange(Array<IterVar> threads) {
return thread_range;
}
StorageScope TileLangStorageAccessVisitor::GetScope(Var buffer_var) const {
StorageScope
TileLangStorageAccessVisitor::GetScope(const Var &buffer_var) const {
if (buffer_var->type_annotation.as<PointerTypeNode>()) {
return StorageScope::Create(GetPtrStorageScope(buffer_var));
}
......
......@@ -49,7 +49,7 @@ using runtime::StorageScope;
class TileLangStorageAccessVisitor : public IRVisitorWithAnalyzer {
public:
/*! \brief Storage access type */
enum AccessType {
enum AccessType : uint8_t {
kRead,
kWrite,
kSync,
......@@ -88,7 +88,7 @@ public:
/*! \brief Access pattern about a single statement */
struct StmtEntry {
/*! \brief The statement */
const Object *stmt;
const Object *stmt{};
/*! \brief access patterns in the statement */
std::vector<AccessEntry> access;
};
......@@ -144,13 +144,13 @@ protected:
* \param threads The threads to compute the range for.
* \return The thread range.
*/
Map<Var, Range> ComputeThreadRange(Array<IterVar> threads);
Map<Var, Range> ComputeThreadRange(const Array<IterVar> &threads);
/*!
* \brief Get the scope of the buffer array.
* \return The scope of the final buffer array.
*/
StorageScope GetScope(Var buffer_var) const;
StorageScope GetScope(const Var &buffer_var) const;
// access scope
std::vector<std::vector<StmtEntry>> scope_;
......
......@@ -36,6 +36,7 @@
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "arith/int_operator.h"
#include "runtime/thread_storage_scope.h"
......@@ -95,17 +96,17 @@ static void LegalizeBufferLoadDType(BufferLoadNode *n) {
class AllocateCollector : public StmtExprVisitor {
private:
bool IsDynamicSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
StorageScope storage_scope = runtime::StorageScope::Create(
GetPtrStorageScope(std::move(buffer_var)));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn";
}
bool IsStaticSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
StorageScope storage_scope = runtime::StorageScope::Create(
GetPtrStorageScope(std::move(buffer_var)));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == "";
storage_scope.tag.empty();
}
public:
......@@ -143,7 +144,7 @@ public:
/*! \brief record the touch hist of statement. */
struct StmtEntry {
// The statement
const Object *stmt;
const Object *stmt{};
// The index in the linear_seq_ to point to end of the nested scope.
// This is only set to non-zero if stmt is a nested scope.
// if offset > 0, means this is the begin, the end entry is current_index +
......@@ -198,11 +199,11 @@ public:
<< it->second.num_physical_dimensions
<< " physical dimensions, but is accessed as having "
<< op->buffer->axis_separators.size() + 1 << " physical dimensions"
<< std::endl;
<< '\n';
}
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.touched.size() != 0) {
if (!e.touched.empty()) {
e.stmt = op;
linear_seq_.push_back(e);
}
......@@ -227,7 +228,7 @@ public:
<< it->second.num_physical_dimensions
<< " physical dimensions, but is accessed as having "
<< op->buffer->axis_separators.size() + 1 << " physical dimensions"
<< std::endl;
<< '\n';
}
}
......@@ -237,7 +238,7 @@ public:
StmtExprVisitor::VisitStmt_(op);
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.touched.size() != 0) {
if (!e.touched.empty()) {
e.stmt = op;
linear_seq_.push_back(e);
}
......@@ -345,15 +346,15 @@ public:
src_ = src;
result_ = true;
if (stmt->IsInstance<AttrStmtNode>()) {
VisitStmt_(static_cast<const AttrStmtNode *>(stmt));
VisitStmt_(reinterpret_cast<const AttrStmtNode *>(stmt));
} else if (stmt->IsInstance<ForNode>()) {
VisitStmt_(static_cast<const ForNode *>(stmt));
VisitStmt_(reinterpret_cast<const ForNode *>(stmt));
} else if (stmt->IsInstance<IfThenElseNode>()) {
VisitStmt_(static_cast<const IfThenElseNode *>(stmt));
VisitStmt_(reinterpret_cast<const IfThenElseNode *>(stmt));
} else if (stmt->IsInstance<WhileNode>()) {
VisitStmt_(static_cast<const WhileNode *>(stmt));
VisitStmt_(reinterpret_cast<const WhileNode *>(stmt));
} else if (stmt->IsInstance<BufferStoreNode>()) {
VisitStmt_(static_cast<const BufferStoreNode *>(stmt));
VisitStmt_(reinterpret_cast<const BufferStoreNode *>(stmt));
} else {
return false;
}
......@@ -442,9 +443,9 @@ private:
// result of the check
bool result_{true};
// destination memory
const VarNode *dst_;
const VarNode *dst_{};
// source variable
const VarNode *src_;
const VarNode *src_{};
// counter of load,
// it is not safe to inplace when there is nested load like A[B[i]]
int mem_nest_{0};
......@@ -501,7 +502,7 @@ public:
return node;
}
Buffer RemapBuffer(Buffer buf, Var new_backing_array) {
Buffer RemapBuffer(const Buffer &buf, const Var &new_backing_array) {
auto key = buf.get();
auto it = buffer_remap_.find(key);
if (it != buffer_remap_.end()) {
......@@ -641,7 +642,7 @@ private:
// The physical dimensionality of the allocations. Since
// StorageRewrite is applied after StorageFlatten/FlattenBuffer,
// this is size of `AllocateNode::extents`. If moved
size_t ndim;
size_t ndim{};
// Allocs that shares this entry.
std::vector<const AllocateNode *> allocs;
// The children of this entry, not including itself.
......@@ -671,7 +672,7 @@ private:
// Checks whether the storage_scope is especially tagged for a specific
// memory. Special memory is all combined into a single allocation.
bool IsSpecialTaggedMemory(const StorageScope &scope) {
return scope.tag.length() != 0 && scope.tag != ".dyn" &&
return !scope.tag.empty() && scope.tag != ".dyn" &&
scope.tag != ".barrier" && scope.tag != ".workspace" &&
scope.tag != ".vtcm";
}
......@@ -729,7 +730,7 @@ private:
// already merged
if (e->bits_offset != 0)
continue;
if (e->merged_children.size() != 0) {
if (!e->merged_children.empty()) {
NewAllocTagMerged(e);
continue;
}
......@@ -993,7 +994,7 @@ private:
}
// enter/exit new scope
if (s.stmt->IsInstance<AttrStmtNode>()) {
const auto *op = static_cast<const AttrStmtNode *>(s.stmt);
const auto *op = reinterpret_cast<const AttrStmtNode *>(s.stmt);
if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == tir::attr::virtual_thread ||
tir::attr::IsPragmaKey(op->attr_key)) {
......@@ -1002,7 +1003,7 @@ private:
ICHECK(op->attr_key == tir::attr::extern_scope);
}
} else if (s.stmt->IsInstance<ForNode>()) {
const auto *op = static_cast<const ForNode *>(s.stmt);
const auto *op = reinterpret_cast<const ForNode *>(s.stmt);
if (op->kind == ForKind::kParallel) {
if (thread_scope_ == nullptr || thread_scope_ == op) {
PlanNewScope(op);
......@@ -1062,7 +1063,7 @@ private:
// disable reuse of small arrays, they will be lowered to registers in LLVM
// This rules only apply if we are using non special memory
bool is_small_array =
(scope.tag.length() == 0) &&
(scope.tag.empty()) &&
(scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
(is_known_size && const_nbits <= 32));
......@@ -1134,7 +1135,7 @@ private:
// disable reuse of small arrays, they will be lowered to registers in LLVM
// This rules only apply if we are using non special memory
if (e->scope.tag.length() == 0) {
if (e->scope.tag.empty()) {
// Disable sharing of local memory.
if (e->scope.rank >= StorageRank::kWarp ||
e->allocs[0]->dtype.is_handle())
......@@ -1182,7 +1183,7 @@ private:
*
*/
struct BufferVarInfo {
enum DeclarationLocation {
enum DeclarationLocation : uint8_t {
kPrimFuncParam = (1 << 0),
kPrimFuncBufferMap = (1 << 1),
kAllocateNode = (1 << 2),
......@@ -1293,7 +1294,7 @@ public:
Var buffer_var = buffer->data;
DataType dtype = buffer->dtype;
PrimExpr extent =
buffer->shape.size() ? buffer->shape[buffer->shape.size() - 1] : 0;
!buffer->shape.empty() ? buffer->shape[buffer->shape.size() - 1] : 0;
OnArrayDeclaration(buffer_var, dtype, extent,
BufferVarInfo::kPrimFuncParam);
}
......@@ -1350,7 +1351,7 @@ public:
void VisitStmt_(const AllocateConstNode *op) final {
const Array<PrimExpr> &extents = op->extents;
PrimExpr extent =
extents.size() ? extents[extents.size() - 1] : NullValue<PrimExpr>();
!extents.empty() ? extents[extents.size() - 1] : NullValue<PrimExpr>();
OnArrayDeclaration(op->buffer_var, op->dtype, extent,
BufferVarInfo::kAllocateConstNode);
......@@ -1367,7 +1368,7 @@ public:
StmtExprVisitor::VisitStmt_(op);
}
void HandleLetNode(Var let_var) {
void HandleLetNode(const Var &let_var) {
if (let_var->dtype.is_handle()) {
auto pointer_type = GetPointerType(let_var->type_annotation);
if (pointer_type.has_value()) {
......@@ -1397,7 +1398,7 @@ public:
* some locations can be rewritten without others.
*/
void
OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent,
OnArrayDeclaration(const Var &buffer, DataType element_dtype, PrimExpr extent,
BufferVarInfo::DeclarationLocation declaration_location) {
ICHECK(info_map_.find(buffer.get()) == info_map_.end())
<< "Array declaration of " << buffer->name_hint
......@@ -1406,8 +1407,8 @@ public:
if (element_dtype == DataType::Bool()) {
element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
}
info_map_[buffer.get()] =
BufferVarInfo{buffer, element_dtype, extent, declaration_location};
info_map_[buffer.get()] = BufferVarInfo{
buffer, element_dtype, std::move(extent), declaration_location};
}
/* Update the type map for a buffer based on its usage
......@@ -1452,7 +1453,7 @@ public:
ICHECK(indices[i].dtype().is_scalar())
<< "Only the last index of a buffer access may be a vector type.";
}
int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;
int index_lanes = !indices.empty() ? indices.back().dtype().lanes() : 1;
DataType access_dtype = value_dtype;
......@@ -1488,7 +1489,7 @@ public:
// divisible by the number of number of lanes, and the predicate
// does not apply any masking, then this array access could be
// vectorized.
if (indices.size()) {
if (!indices.empty()) {
const RampNode *ramp_index = indices[indices.size() - 1].as<RampNode>();
if (ramp_index && is_one(ramp_index->stride)) {
if (ramp_index->lanes->IsInstance<IntImmNode>()) {
......@@ -1502,7 +1503,7 @@ public:
}
}
if (detect_scalar_read_patterns_ && is_buffer_load && indices.size()) {
if (detect_scalar_read_patterns_ && is_buffer_load && !indices.empty()) {
const PrimExpr last_dim_index = indices[indices.size() - 1];
if (last_dim_index.dtype().lanes() == 1) {
arith::ModularSet me = analyzer_.modular_set(last_dim_index);
......@@ -1910,7 +1911,7 @@ PrimFunc PointerValueTypeRewrite(
using namespace tir::transform;
namespace transform {
Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [](PrimFunc f, const IRModule &m, PassContext ctx) {
bool enable_reuse = true;
bool reuse_require_exact_matched_dtype = false;
bool merge_static_smem =
......@@ -1957,7 +1958,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
});
Pass PointerValueTypeRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return tl::PointerValueTypeRewrite(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {});
......
......@@ -30,6 +30,7 @@
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "./common/thread_sync_types.h"
#include "./storage_access.h"
......@@ -46,7 +47,7 @@ using arith::IRMutatorWithAnalyzer;
class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
public:
explicit TileLangThreadSyncPlanner(StorageScope sync_scope)
: sync_scope_(sync_scope) {}
: sync_scope_(std::move(sync_scope)) {}
// The syncs inserted before each statement
std::unordered_set<const Object *> syncs_inserted_;
......@@ -404,7 +405,7 @@ private:
class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator {
public:
explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope)
: sync_scope_(sync_scope) {}
: sync_scope_(std::move(sync_scope)) {}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::tir::attr::async_wait_queue_scope) {
......@@ -430,10 +431,10 @@ class ThreadSyncInserter : public StmtExprMutator {
public:
ThreadSyncInserter(StorageScope sync_scope,
const std::unordered_set<const Object *> &syncs)
: sync_scope_(sync_scope), syncs_(syncs) {}
: sync_scope_(std::move(sync_scope)), syncs_(syncs) {}
Stmt VisitStmt(const Stmt &stmt) final {
if (syncs_.size() == 0)
if (syncs_.empty())
return stmt;
if (syncs_.count(stmt.get())) {
Stmt barrier;
......@@ -535,7 +536,7 @@ private:
// Get current storage scope.
StorageScope GetScope(Var buffer_var) const {
return StorageScope::Create(GetPtrStorageScope(buffer_var));
return StorageScope::Create(GetPtrStorageScope(std::move(buffer_var)));
}
// private functions.
......@@ -612,10 +613,10 @@ private:
Stmt VisitStmt_(const EvaluateNode *op) final {
const CallNode *call = nullptr;
if (op->value->IsInstance<CallNode>()) {
call = static_cast<const CallNode *>(op->value.get());
call = op->value.as<CallNode>();
if (call->op.same_as(builtin::tvm_storage_sync())) {
const auto &args = call->args;
ICHECK(args.size() > 0);
ICHECK(!args.empty());
const auto *scope_node = args[0].as<StringImmNode>();
ICHECK(scope_node != nullptr);
const std::string &scope = scope_node->value;
......@@ -741,11 +742,11 @@ private:
std::unordered_map<ThreadBoundKey, size_t> thread_count_map_;
};
PrimFunc TileLangThreadSync(PrimFunc func, std::string storage_scope) {
PrimFunc TileLangThreadSync(PrimFunc func, const std::string &storage_scope) {
StorageScope sync_scope = StorageScope::Create(storage_scope);
auto *n = func.CopyOnWrite();
auto stmt = n->body;
if (sync_scope.rank == StorageRank::kShared && sync_scope.tag == "") {
if (sync_scope.rank == StorageRank::kShared && sync_scope.tag.empty()) {
stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
}
TileLangThreadSyncPlanner planner(sync_scope);
......@@ -764,8 +765,9 @@ using namespace tir::transform;
namespace transform {
tvm::transform::Pass ThreadSync(String storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
tvm::transform::Pass ThreadSync(const String &storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, const IRModule &m,
const PassContext &ctx) {
auto *n = f.CopyOnWrite();
return tl::TileLangThreadSync(std::move(f), storage_scope);
......