Unverified commit cdc5d8d3, authored by Lei Wang, committed by GitHub

[Lint] Introduce clang-tidy into format.sh (#777)

* [Refactor] Update Clang-Tidy Checks and Improve Code Consistency

- Enhanced .clang-tidy configuration by adding specific checks for better bug detection and performance optimization.
- Refactored function signatures across multiple files to use `const` references for parameters, improving performance and code clarity.
- Updated various methods to ensure consistent handling of parameters, particularly in `AddPredicate`, `Substitute`, and `PlanLoopPartition` functions.
- Improved readability by replacing size checks with `empty()` method calls in several locations, ensuring clearer intent in the code.
- General code cleanup and adherence to best practices for better maintainability.
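The first two bullets above can be sketched in a minimal, hypothetical helper (the name `HasPredicates` and the `std::vector<std::string>` type are illustrative stand-ins, not code from this repository): the parameter is taken by `const` reference so no copy is made, and emptiness is expressed with `empty()` rather than a size comparison, which is what clang-tidy's `readability-container-size-empty` check suggests.

```cpp
#include <string>
#include <vector>

// Hypothetical helper illustrating the two changes above: the parameter is a
// const reference (no copy of the vector at the call site), and emptiness is
// checked with empty() rather than size() == 0, stating the intent directly.
bool HasPredicates(const std::vector<std::string> &predicates) {
  // Before this style of refactor, this would read `predicates.size() != 0`.
  return !predicates.empty();
}
```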

* [Refactor] Enhance Code Consistency and Clang-Tidy Configuration

- Updated .clang-tidy configuration to include additional checks for improved code quality and performance.
- Refactored function signatures across multiple files to use `const` references, enhancing performance and clarity.
- Replaced size checks with `empty()` method calls in various locations for clearer intent.
- Improved handling of parameters in several functions, ensuring consistent usage of `std::move` where applicable.
- General code cleanup to adhere to best practices and improve maintainability.
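The `std::move` usage mentioned above follows the pass-by-value-and-move idiom that clang-tidy's `performance-unnecessary-value-param` and `modernize-pass-by-value` checks encourage. A minimal sketch (the `PredicateNode` class is hypothetical, not from this codebase): a constructor takes its argument by value and moves it into the member, so a caller passing a temporary pays for one move instead of a copy.

```cpp
#include <string>
#include <utility>

// Hypothetical node type sketching the sink-parameter pattern: take the
// argument by value, then move it into the member. Callers passing an rvalue
// incur a move; callers passing an lvalue incur exactly one copy.
class PredicateNode {
public:
  explicit PredicateNode(std::string name) : name_(std::move(name)) {}
  const std::string &name() const { return name_; }

private:
  std::string name_;
};
```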

* [Refactor] Integrate Clang-Tidy Checks and Enhance Code Consistency

- Added clang-tidy checks to the format script for improved code quality assurance.
- Refactored function signatures across multiple files to consistently use `const` references, enhancing performance and clarity.
- Updated the requirements-lint.txt file to include clang-tidy as a dependency.
- General code cleanup to adhere to best practices and improve maintainability.

* [CI] Update AMD CI Workflow to Include Build Directory Creation

- Added steps to create a build directory and configure CMake with ROCm support during the format check process.
- Ensured cleanup of the build directory after the format check to maintain a clean workspace.

* [Refactor] Remove Unused Member Variables in AtomicAddNode and CopyNode

- Removed the `args_` member variable from both `AtomicAddNode` and `CopyNode` classes to streamline the code and eliminate unnecessary data members.
- This change enhances code clarity and maintainability by focusing on relevant attributes for each class.

* [Refactor] Update Clang-Tidy Integration and Code Improvements

- Modified the format script to include the `-fix` option in the clang-tidy command for automatic code fixes.
- Refactored the `AtomicAddVectorizePlanner` class to improve variable handling and consistency, including changes to member variable types and function signatures.
- Enhanced code clarity by removing unnecessary `std::move` calls and ensuring consistent usage of types across the class.
- General code cleanup to adhere to best practices and improve maintainability.

* [Refactor] Improve Parameter Handling and Consistency in AtomicAddVectorize

- Updated function signatures in `AtomicAddVectorizePlanResult` and `AtomicAddVectorizeRewriter` to use `const` references and `std::move` for better performance and clarity.
- Enhanced the `UpdateVectorSize` method to accept `const Array<PrimExpr>&` for improved efficiency.
- General code cleanup to maintain consistency and adhere to best practices.

* [CI] Add Git Submodule Initialization to CI Workflow

- Included a step to initialize and update git submodules recursively in the CI workflow.
- This change ensures that all necessary submodules are available during the format check process, improving build reliability.

* [CI] Add Git Submodule Update Step to Format Check

- Included a command to initialize and update git submodules recursively in the CI workflow during the format check process.
- This enhancement ensures that all required submodules are available, contributing to improved build reliability.

* [Refactor] Update Function Signatures in AtomicAddVectorize

- Modified the `VectorizeAtomicAdd` function signature to use `const` references for `thread_var` and `thread_bounds`, enhancing performance and code clarity.
- This change aligns with previous refactoring efforts to improve parameter handling and consistency across the codebase.
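The signature change described above (read-only parameters of non-trivial types taken by `const` reference) can be sketched with hypothetical stand-ins; the real parameters are TVM's `Var` and `Range`, which are replaced here by a plain struct purely for illustration.

```cpp
#include <cstdint>

// Hypothetical stand-in for a thread-bounds value; the real code uses TVM
// handle types. The point is the signature: a read-only parameter of a
// non-trivial type is taken by const reference, so no copy is made per call.
struct ThreadBounds {
  int64_t min;
  int64_t extent;
};

int64_t ClampToBounds(int64_t index, const ThreadBounds &bounds) {
  if (index < bounds.min)
    return bounds.min;
  const int64_t last = bounds.min + bounds.extent - 1;
  return index > last ? last : index;
}
```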
parent 471cc7f8
......@@ -33,6 +33,7 @@
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <utility>
#include <vector>
#include "arith/scalable_expression.h"
......@@ -127,7 +128,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
class TLVecAllocAccess : public StmtExprMutator {
public:
TLVecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {}
: buf_(buf), var_(std::move(var)), var_lanes_(std::move(var_lanes)) {}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
......@@ -207,7 +208,8 @@ public:
using ExprFunctor::VisitExpr;
using StmtMutator::operator();
TLVectorizer(Var var, PrimExpr var_lanes) : var_(var), var_lanes_(var_lanes) {
TLVectorizer(const Var &var, const PrimExpr &var_lanes)
: var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
}
......@@ -227,11 +229,13 @@ public:
}
PrimExpr VisitExpr_(const AddNode *op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
return AddSubVec(
op, [](PrimExpr a, PrimExpr b) { return std::move(a) + std::move(b); });
}
PrimExpr VisitExpr_(const SubNode *op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; });
return AddSubVec(
op, [](PrimExpr a, PrimExpr b) { return std::move(a) - std::move(b); });
}
PrimExpr VisitExpr_(const MulNode *op) final {
......@@ -712,7 +716,7 @@ private:
// mutate array, with given lane requirement
// when finished, p_lane updates the lane requirement.
Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int *p_lanes) {
if (arr.size() == 0)
if (arr.empty())
return arr;
int &lanes = *p_lanes;
bool changed = false;
......@@ -826,7 +830,7 @@ Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); }
tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
auto *n = f.CopyOnWrite();
if (enable_vectorize) {
n->body = tvm::tl::LoopVectorizer()(std::move(n->body));
......
......@@ -12,6 +12,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <utility>
#include "../op/builtin.h"
#include "./common/collector.h"
#include "runtime/thread_storage_scope.h"
......@@ -30,13 +32,13 @@ struct LoopInfo {
PrimExpr min;
};
enum class Role { kConsumer, kProducer, kBoth };
enum class Role : uint8_t { kConsumer, kProducer, kBoth };
class ProducerBufferDetector : public StmtExprVisitor {
public:
ProducerBufferDetector(
std::unordered_set<const BufferNode *> cur_producer_buffers)
: cur_producer_buffers_(cur_producer_buffers) {}
: cur_producer_buffers_(std::move(cur_producer_buffers)) {}
void clear() { has_producer_buffer_ = false; }
......@@ -60,7 +62,7 @@ public:
class ProducerUsedBufferFinder : public StmtExprVisitor {
public:
auto FindProducerusedBuffer(Stmt stmt) {
auto FindProducerusedBuffer(const Stmt &stmt) {
producer_buffers_.clear();
std::unordered_set<const BufferNode *> last_producer_buffers_;
for (;;) {
......@@ -128,7 +130,7 @@ private:
class WarpSpecializedRoleMarker : public StmtVisitor {
public:
WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {}
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {}
void Prepare(const Stmt &stmt) {
ProducerUsedBufferFinder finder;
......@@ -248,12 +250,12 @@ private:
};
static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
return Call(DataType::Handle(), get_mbarrier(), {barrier_id});
return Call(DataType::Handle(), get_mbarrier(), {std::move(barrier_id)});
}
static Stmt makeArriveBarrier(PrimExpr barrier_id, int cta_id = -1,
PrimExpr pred = 1) {
Array<PrimExpr> args = {makeGetBarrier(barrier_id)};
const PrimExpr &pred = 1) {
Array<PrimExpr> args = {makeGetBarrier(std::move(barrier_id))};
if (cta_id != -1) {
args.push_back(cta_id);
args.push_back(pred);
......@@ -264,13 +266,13 @@ static Stmt makeArriveBarrier(PrimExpr barrier_id, int cta_id = -1,
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
auto call = Call(DataType::Handle(), builtin::ptx_cp_async_barrier(),
{makeGetBarrier(barrier_id)});
{makeGetBarrier(std::move(barrier_id))});
return Evaluate(call);
}
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
auto call = Call(DataType::Handle(), mbarrier_wait_parity(),
{makeGetBarrier(barrier_id), parity});
{makeGetBarrier(std::move(barrier_id)), std::move(parity)});
return Evaluate(call);
}
......@@ -280,7 +282,7 @@ public:
void Clear() { has_simt_copy = false; }
void Collect(Stmt stmt) { VisitStmt(stmt); }
void Collect(const Stmt &stmt) { VisitStmt(stmt); }
bool HasSimtCopy() { return has_simt_copy; }
......@@ -304,7 +306,7 @@ private:
StmtExprVisitor::VisitExpr_(op);
}
bool has_simt_copy;
bool has_simt_copy{};
bool in_if_cond_ = false;
};
......@@ -313,8 +315,8 @@ class MbarrierRewriter : public StmtExprMutator {
public:
static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) {
MbarrierRewriter rewriter;
rewriter.producer_barrier_idx_ = barrier_id;
return rewriter(stmt);
rewriter.producer_barrier_idx_ = std::move(barrier_id);
return rewriter(std::move(stmt));
}
private:
......@@ -345,15 +347,16 @@ public:
static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced,
PrimExpr thread_extent, bool do_shuffle = false) {
auto rewriter =
ThreadIdxRewriter(thread_var, replaced, thread_extent, do_shuffle);
return rewriter(stmt);
ThreadIdxRewriter(std::move(thread_var), std::move(replaced),
std::move(thread_extent), do_shuffle);
return rewriter(std::move(stmt));
}
private:
ThreadIdxRewriter(Var thread_var, PrimExpr replaced, PrimExpr thread_extent,
bool do_shuffle)
: thread_var_(thread_var), replaced_(replaced),
thread_extent_(thread_extent), do_shuffle_(do_shuffle) {}
: thread_var_(std::move(thread_var)), replaced_(std::move(replaced)),
thread_extent_(std::move(thread_extent)), do_shuffle_(do_shuffle) {}
PrimExpr VisitExpr_(const VarNode *var) final {
if (var == thread_var_.get()) {
......@@ -415,15 +418,16 @@ Block MakeGroupBlock(const Stmt &stmt,
}
struct OpInfo {
int group_size, order, stage;
int group_size{}, order{}, stage{};
std::vector<int> group;
};
struct PipelineInfo {
std::vector<OpInfo> op_infos;
PipelineInfo() = default;
PipelineInfo(Array<Array<Integer>> group_info, Array<Integer> order_info,
Array<Integer> stage_info) {
PipelineInfo(const Array<Array<Integer>> &group_info,
const Array<Integer> &order_info,
const Array<Integer> &stage_info) {
int n = static_cast<int>(group_info.size());
ICHECK(n == static_cast<int>(order_info.size()));
ICHECK(n == static_cast<int>(stage_info.size()));
......@@ -441,7 +445,7 @@ struct PipelineInfo {
}
PipelineInfo(const PipelineInfo &other) {
for (auto op_info : other.op_infos) {
for (const auto &op_info : other.op_infos) {
op_infos.push_back(op_info);
}
}
......@@ -501,18 +505,19 @@ struct PipelineInfo {
}
void PrintPipelineInfo() {
std::cout << "Print op_infos:" << std::endl;
std::cout << "Print op_infos:" << '\n';
for (size_t i = 0; i < op_infos.size(); i++) {
std::cout << i << " " << op_infos[i].group_size << " "
<< op_infos[i].order << " " << op_infos[i].stage << std::endl;
<< op_infos[i].order << " " << op_infos[i].stage << '\n';
}
std::cout << "End of print" << std::endl;
std::cout << "End of print" << '\n';
}
};
class GroupOpRewriter : public StmtExprMutator {
public:
GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {}
GroupOpRewriter(const PipelineInfo &pipeline_info)
: pipeline_info_(pipeline_info) {}
private:
Stmt VisitStmt_(const ForNode *op) final {
......@@ -546,7 +551,7 @@ private:
}
Array<Integer> order_anno;
Array<Integer> stage_anno;
for (auto op_info : pipeline_info_.op_infos) {
for (const auto &op_info : pipeline_info_.op_infos) {
order_anno.push_back(Integer(op_info.order));
stage_anno.push_back(Integer(op_info.stage));
}
......@@ -588,7 +593,7 @@ public:
in_if_scope_ = false;
}
static bool HasWgMMA(Stmt stmt) {
static bool HasWgMMA(const Stmt &stmt) {
auto collector = WgMMACollector();
collector(stmt);
return collector.has_wgmma_;
......@@ -629,14 +634,14 @@ public:
* @param only_has_wgmma If true, adjust emission and barrier-thread-count
* logic for blocks that contain WgMMA operations.
*/
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
WSCodeEmitter(bool is_emitting_producer, const IterVar &thread_iv,
Map<Var, Buffer> buffer_data_to_buffer,
const WarpSpecializedRoleMarker &marker,
bool mbarrier_only = false, bool only_has_wgmma = false)
: is_emitting_producer_(is_emitting_producer),
buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only),
only_has_wgmma_(only_has_wgmma) {}
buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
marker_(marker), thread_var_(thread_iv->var),
mbarrier_only_(mbarrier_only), only_has_wgmma_(only_has_wgmma) {}
/**
* @brief Whether a SIMT-style bulk copy was detected.
......@@ -757,7 +762,7 @@ private:
return FilterByRole(op);
auto seq_transformed =
op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
op->seq.Map([&](const Stmt &stmt) { return VisitStmt(stmt); });
auto map = ExtractSyncPattern(op->seq);
......@@ -804,7 +809,7 @@ private:
: parity_;
block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
}
ICHECK(map.release[i].size() > 0);
ICHECK(!map.release[i].empty());
for (size_t j = 0; j < map.release[i].size(); j++) {
int pattern_idx = map.release[i][j];
PrimExpr release_barrier_id =
......@@ -890,7 +895,7 @@ private:
num_barriers_ += map.patterns.size() * num_stages_;
ICHECK(new_body.size() > 0);
ICHECK(!new_body.empty());
return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
}
......@@ -923,8 +928,8 @@ private:
PipelineInfo pipeline_info(group_info_array, order_info_array,
stage_info_array);
if (pipeline_info.op_infos.size() > 0) {
ICHECK(pipeline_info_.op_infos.size() == 0)
if (!pipeline_info.op_infos.empty()) {
ICHECK(pipeline_info_.op_infos.empty())
<< "Nested pipeline not supported.";
}
......@@ -946,7 +951,7 @@ private:
auto result = FilterByRole(op);
Stmt grouped_for_node;
if (result.as<ForNode>() && group_anno && group_info_array.size() > 0 &&
if (result.as<ForNode>() && group_anno && !group_info_array.empty() &&
!is_emitting_producer_) {
GroupOpRewriter group_op_rewriter(pipeline_info_);
auto for_node = Downcast<For>(result);
......@@ -963,12 +968,11 @@ private:
if (result.as<ForNode>()) {
auto for_node = Downcast<For>(result);
for_node.CopyOnWrite()->annotations.erase("num_stages");
if (is_emitting_producer_ || group_info_array.size() == 0) {
if (is_emitting_producer_ || group_info_array.empty()) {
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
}
if (is_emitting_producer_ || !group_anno ||
group_info_array.size() == 0) {
if (is_emitting_producer_ || !group_anno || group_info_array.empty()) {
loop_stack_.pop_back();
return for_node;
}
......@@ -1017,7 +1021,7 @@ private:
};
std::vector<SyncPattern>
CreateBaseSyncPairs(Array<Stmt> seq_stmt,
CreateBaseSyncPairs(const Array<Stmt> &seq_stmt,
const std::vector<bool> &is_producer) {
const int n = seq_stmt.size();
std::vector<std::set<const BufferNode *>> reads, writes;
......@@ -1132,7 +1136,7 @@ private:
return sync_pattern_cleaned;
}
SyncPatternMap ExtractSyncPattern(Array<Stmt> seq_stmt) {
SyncPatternMap ExtractSyncPattern(const Array<Stmt> &seq_stmt) {
size_t num_stmts = seq_stmt.size();
std::vector<bool> is_producer;
is_producer.reserve(num_stmts);
......@@ -1165,7 +1169,7 @@ private:
std::vector<int> cur_consumer_barrier, cur_producer_barrier;
for (int i = num_stmts - 1; i >= 0; i--) {
if (is_producer[i]) {
if (map.release[i].size() == 0) {
if (map.release[i].empty()) {
for (auto pattern_idx : cur_producer_barrier) {
map.release[i].push_back(pattern_idx);
map.release_after[i].push_back(false);
......@@ -1176,7 +1180,7 @@ private:
}
}
} else {
if (map.release[i].size() == 0) {
if (map.release[i].empty()) {
for (auto pattern_idx : cur_consumer_barrier) {
map.release[i].push_back(pattern_idx);
map.release_after[i].push_back(false);
......@@ -1405,7 +1409,7 @@ private:
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public:
// return true means this aws will be disabled
static bool Detect(Stmt stmt, bool skip_thread_partition = false) {
static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
WarpSpecializedDetector detector;
detector.VisitStmt(stmt);
if (detector.has_warp_specialization_) {
......@@ -1472,7 +1476,7 @@ private:
using namespace tir::transform;
tvm::transform::Pass WarpSpecialized() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
bool disable_warp_specialized =
ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value();
bool disable_shuffle_elect =
......
......@@ -10,6 +10,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <utility>
#include "../op/builtin.h"
namespace tvm {
......@@ -17,7 +19,7 @@ namespace tl {
using namespace tir;
bool isGemm(Stmt stmt) {
bool isGemm(const Stmt &stmt) {
bool is_gemm = false;
if (stmt.as<EvaluateNode>()) {
auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
......@@ -33,7 +35,7 @@ bool isGemm(Stmt stmt) {
return is_gemm;
}
bool isGemmSync(Stmt stmt) {
bool isGemmSync(const Stmt &stmt) {
bool is_gemm_sync = false;
if (stmt.as<EvaluateNode>()) {
auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
......@@ -49,7 +51,7 @@ bool isGemmSync(Stmt stmt) {
return is_gemm_sync;
}
bool isArriveBarrier(Stmt stmt) {
bool isArriveBarrier(const Stmt &stmt) {
bool is_arrive_barrier = false;
if (stmt.as<EvaluateNode>()) {
auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
......@@ -216,7 +218,8 @@ private:
gemm_count++;
} else if (isGemmSync(new_seq[i])) {
auto call = Downcast<Evaluate>(new_seq[i])->value.as<CallNode>();
auto sync_index = Downcast<IntImm>(call->args[1])->value;
auto sync_index =
static_cast<int>(Downcast<IntImm>(call->args[1])->value);
auto wait_count = gemm_count - sync_index - 1;
if (sync_index > max_sync_index)
max_sync_index = sync_index;
......@@ -257,8 +260,8 @@ private:
using namespace tir::transform;
tvm::transform::Pass RewriteWgmmaSync() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return WgmmaSyncRewriter::Substitute(f);
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return WgmmaSyncRewriter::Substitute(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {});
}
......