"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "6c68e473cdc4b4abc59aea87e30072c45017a636"
Unverified Commit cdc5d8d3 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Lint] Introduce clang-tidy into format.sh (#777)

* [Refactor] Update Clang-Tidy Checks and Improve Code Consistency

- Enhanced .clang-tidy configuration by adding specific checks for better bug detection and performance optimization.
- Refactored function signatures across multiple files to use `const` references for parameters, improving performance and code clarity.
- Updated various methods to ensure consistent handling of parameters, particularly in `AddPredicate`, `Substitute`, and `PlanLoopPartition` functions.
- Improved readability by replacing size checks with `empty()` method calls in several locations, ensuring clearer intent in the code.
- General code cleanup and adherence to best practices for better maintainability.

* [Refactor] Enhance Code Consistency and Clang-Tidy Configuration

- Updated .clang-tidy configuration to include additional checks for improved code quality and performance.
- Refactored function signatures across multiple files to use `const` references, enhancing performance and clarity.
- Replaced size checks with `empty()` method calls in various locations for clearer intent.
- Improved handling of parameters in several functions, ensuring consistent usage of `std::move` where applicable.
- General code cleanup to adhere to best practices and improve maintainability.

* [Refactor] Integrate Clang-Tidy Checks and Enhance Code Consistency

- Added clang-tidy checks to the format script for improved code quality assurance.
- Refactored function signatures across multiple files to consistently use `const` references, enhancing performance and clarity.
- Updated the requirements-lint.txt file to include clang-tidy as a dependency.
- General code cleanup to adhere to best practices and improve maintainability.

* [CI] Update AMD CI Workflow to Include Build Directory Creation

- Added steps to create a build directory and configure CMake with ROCm support during the format check process.
- Ensured cleanup of the build directory after the format check to maintain a clean workspace.

* [Refactor] Remove Unused Member Variables in AtomicAddNode and CopyNode

- Removed the `args_` member variable from both `AtomicAddNode` and `CopyNode` classes to streamline the code and eliminate unnecessary data members.
- This change enhances code clarity and maintainability by focusing on relevant attributes for each class.

* [Refactor] Update Clang-Tidy Integration and Code Improvements

- Modified the format script to include the `-fix` option in the clang-tidy command for automatic code fixes.
- Refactored the `AtomicAddVectorizePlanner` class to improve variable handling and consistency, including changes to member variable types and function signatures.
- Enhanced code clarity by removing unnecessary `std::move` calls and ensuring consistent usage of types across the class.
- General code cleanup to adhere to best practices and improve maintainability.

* [Refactor] Improve Parameter Handling and Consistency in AtomicAddVectorize

- Updated function signatures in `AtomicAddVectorizePlanResult` and `AtomicAddVectorizeRewriter` to use `const` references and `std::move` for better performance and clarity.
- Enhanced the `UpdateVectorSize` method to accept `const Array<PrimExpr>&` for improved efficiency.
- General code cleanup to maintain consistency and adhere to best practices.

* [CI] Add Git Submodule Initialization to CI Workflow

- Included a step to initialize and update git submodules recursively in the CI workflow.
- This change ensures that all necessary submodules are available during the format check process, improving build reliability.

* [CI] Add Git Submodule Update Step to Format Check

- Included a command to initialize and update git submodules recursively in the CI workflow during the format check process.
- This enhancement ensures that all required submodules are available, contributing to improved build reliability.

* [Refactor] Update Function Signatures in AtomicAddVectorize

- Modified the `VectorizeAtomicAdd` function signature to use `const` references for `thread_var` and `thread_bounds`, enhancing performance and code clarity.
- This change aligns with previous refactoring efforts to improve parameter handling and consistency across the codebase.
parent 471cc7f8
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <unordered_map> #include <unordered_map>
#include <utility>
#include <vector> #include <vector>
#include "arith/scalable_expression.h" #include "arith/scalable_expression.h"
...@@ -127,7 +128,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { ...@@ -127,7 +128,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
class TLVecAllocAccess : public StmtExprMutator { class TLVecAllocAccess : public StmtExprMutator {
public: public:
TLVecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes) TLVecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {} : buf_(buf), var_(std::move(var)), var_lanes_(std::move(var_lanes)) {}
PrimExpr VisitExpr_(const BufferLoadNode *op) final { PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
...@@ -207,7 +208,8 @@ public: ...@@ -207,7 +208,8 @@ public:
using ExprFunctor::VisitExpr; using ExprFunctor::VisitExpr;
using StmtMutator::operator(); using StmtMutator::operator();
TLVectorizer(Var var, PrimExpr var_lanes) : var_(var), var_lanes_(var_lanes) { TLVectorizer(const Var &var, const PrimExpr &var_lanes)
: var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes); ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
} }
...@@ -227,11 +229,13 @@ public: ...@@ -227,11 +229,13 @@ public:
} }
PrimExpr VisitExpr_(const AddNode *op) final { PrimExpr VisitExpr_(const AddNode *op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; }); return AddSubVec(
op, [](PrimExpr a, PrimExpr b) { return std::move(a) + std::move(b); });
} }
PrimExpr VisitExpr_(const SubNode *op) final { PrimExpr VisitExpr_(const SubNode *op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; }); return AddSubVec(
op, [](PrimExpr a, PrimExpr b) { return std::move(a) - std::move(b); });
} }
PrimExpr VisitExpr_(const MulNode *op) final { PrimExpr VisitExpr_(const MulNode *op) final {
...@@ -712,7 +716,7 @@ private: ...@@ -712,7 +716,7 @@ private:
// mutate array, with given lane requirement // mutate array, with given lane requirement
// when finished, p_lane updates the lane requirement. // when finished, p_lane updates the lane requirement.
Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int *p_lanes) { Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int *p_lanes) {
if (arr.size() == 0) if (arr.empty())
return arr; return arr;
int &lanes = *p_lanes; int &lanes = *p_lanes;
bool changed = false; bool changed = false;
...@@ -826,7 +830,7 @@ Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); } ...@@ -826,7 +830,7 @@ Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); }
tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) { tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) {
using namespace tir::transform; using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
auto *n = f.CopyOnWrite(); auto *n = f.CopyOnWrite();
if (enable_vectorize) { if (enable_vectorize) {
n->body = tvm::tl::LoopVectorizer()(std::move(n->body)); n->body = tvm::tl::LoopVectorizer()(std::move(n->body));
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <utility>
#include "../op/builtin.h" #include "../op/builtin.h"
#include "./common/collector.h" #include "./common/collector.h"
#include "runtime/thread_storage_scope.h" #include "runtime/thread_storage_scope.h"
...@@ -30,13 +32,13 @@ struct LoopInfo { ...@@ -30,13 +32,13 @@ struct LoopInfo {
PrimExpr min; PrimExpr min;
}; };
enum class Role { kConsumer, kProducer, kBoth }; enum class Role : uint8_t { kConsumer, kProducer, kBoth };
class ProducerBufferDetector : public StmtExprVisitor { class ProducerBufferDetector : public StmtExprVisitor {
public: public:
ProducerBufferDetector( ProducerBufferDetector(
std::unordered_set<const BufferNode *> cur_producer_buffers) std::unordered_set<const BufferNode *> cur_producer_buffers)
: cur_producer_buffers_(cur_producer_buffers) {} : cur_producer_buffers_(std::move(cur_producer_buffers)) {}
void clear() { has_producer_buffer_ = false; } void clear() { has_producer_buffer_ = false; }
...@@ -60,7 +62,7 @@ public: ...@@ -60,7 +62,7 @@ public:
class ProducerUsedBufferFinder : public StmtExprVisitor { class ProducerUsedBufferFinder : public StmtExprVisitor {
public: public:
auto FindProducerusedBuffer(Stmt stmt) { auto FindProducerusedBuffer(const Stmt &stmt) {
producer_buffers_.clear(); producer_buffers_.clear();
std::unordered_set<const BufferNode *> last_producer_buffers_; std::unordered_set<const BufferNode *> last_producer_buffers_;
for (;;) { for (;;) {
...@@ -128,7 +130,7 @@ private: ...@@ -128,7 +130,7 @@ private:
class WarpSpecializedRoleMarker : public StmtVisitor { class WarpSpecializedRoleMarker : public StmtVisitor {
public: public:
WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer) WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {} : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {}
void Prepare(const Stmt &stmt) { void Prepare(const Stmt &stmt) {
ProducerUsedBufferFinder finder; ProducerUsedBufferFinder finder;
...@@ -248,12 +250,12 @@ private: ...@@ -248,12 +250,12 @@ private:
}; };
static PrimExpr makeGetBarrier(PrimExpr barrier_id) { static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
return Call(DataType::Handle(), get_mbarrier(), {barrier_id}); return Call(DataType::Handle(), get_mbarrier(), {std::move(barrier_id)});
} }
static Stmt makeArriveBarrier(PrimExpr barrier_id, int cta_id = -1, static Stmt makeArriveBarrier(PrimExpr barrier_id, int cta_id = -1,
PrimExpr pred = 1) { const PrimExpr &pred = 1) {
Array<PrimExpr> args = {makeGetBarrier(barrier_id)}; Array<PrimExpr> args = {makeGetBarrier(std::move(barrier_id))};
if (cta_id != -1) { if (cta_id != -1) {
args.push_back(cta_id); args.push_back(cta_id);
args.push_back(pred); args.push_back(pred);
...@@ -264,13 +266,13 @@ static Stmt makeArriveBarrier(PrimExpr barrier_id, int cta_id = -1, ...@@ -264,13 +266,13 @@ static Stmt makeArriveBarrier(PrimExpr barrier_id, int cta_id = -1,
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) { static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
auto call = Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), auto call = Call(DataType::Handle(), builtin::ptx_cp_async_barrier(),
{makeGetBarrier(barrier_id)}); {makeGetBarrier(std::move(barrier_id))});
return Evaluate(call); return Evaluate(call);
} }
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) { static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
auto call = Call(DataType::Handle(), mbarrier_wait_parity(), auto call = Call(DataType::Handle(), mbarrier_wait_parity(),
{makeGetBarrier(barrier_id), parity}); {makeGetBarrier(std::move(barrier_id)), std::move(parity)});
return Evaluate(call); return Evaluate(call);
} }
...@@ -280,7 +282,7 @@ public: ...@@ -280,7 +282,7 @@ public:
void Clear() { has_simt_copy = false; } void Clear() { has_simt_copy = false; }
void Collect(Stmt stmt) { VisitStmt(stmt); } void Collect(const Stmt &stmt) { VisitStmt(stmt); }
bool HasSimtCopy() { return has_simt_copy; } bool HasSimtCopy() { return has_simt_copy; }
...@@ -304,7 +306,7 @@ private: ...@@ -304,7 +306,7 @@ private:
StmtExprVisitor::VisitExpr_(op); StmtExprVisitor::VisitExpr_(op);
} }
bool has_simt_copy; bool has_simt_copy{};
bool in_if_cond_ = false; bool in_if_cond_ = false;
}; };
...@@ -313,8 +315,8 @@ class MbarrierRewriter : public StmtExprMutator { ...@@ -313,8 +315,8 @@ class MbarrierRewriter : public StmtExprMutator {
public: public:
static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) { static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) {
MbarrierRewriter rewriter; MbarrierRewriter rewriter;
rewriter.producer_barrier_idx_ = barrier_id; rewriter.producer_barrier_idx_ = std::move(barrier_id);
return rewriter(stmt); return rewriter(std::move(stmt));
} }
private: private:
...@@ -345,15 +347,16 @@ public: ...@@ -345,15 +347,16 @@ public:
static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced, static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced,
PrimExpr thread_extent, bool do_shuffle = false) { PrimExpr thread_extent, bool do_shuffle = false) {
auto rewriter = auto rewriter =
ThreadIdxRewriter(thread_var, replaced, thread_extent, do_shuffle); ThreadIdxRewriter(std::move(thread_var), std::move(replaced),
return rewriter(stmt); std::move(thread_extent), do_shuffle);
return rewriter(std::move(stmt));
} }
private: private:
ThreadIdxRewriter(Var thread_var, PrimExpr replaced, PrimExpr thread_extent, ThreadIdxRewriter(Var thread_var, PrimExpr replaced, PrimExpr thread_extent,
bool do_shuffle) bool do_shuffle)
: thread_var_(thread_var), replaced_(replaced), : thread_var_(std::move(thread_var)), replaced_(std::move(replaced)),
thread_extent_(thread_extent), do_shuffle_(do_shuffle) {} thread_extent_(std::move(thread_extent)), do_shuffle_(do_shuffle) {}
PrimExpr VisitExpr_(const VarNode *var) final { PrimExpr VisitExpr_(const VarNode *var) final {
if (var == thread_var_.get()) { if (var == thread_var_.get()) {
...@@ -415,15 +418,16 @@ Block MakeGroupBlock(const Stmt &stmt, ...@@ -415,15 +418,16 @@ Block MakeGroupBlock(const Stmt &stmt,
} }
struct OpInfo { struct OpInfo {
int group_size, order, stage; int group_size{}, order{}, stage{};
std::vector<int> group; std::vector<int> group;
}; };
struct PipelineInfo { struct PipelineInfo {
std::vector<OpInfo> op_infos; std::vector<OpInfo> op_infos;
PipelineInfo() = default; PipelineInfo() = default;
PipelineInfo(Array<Array<Integer>> group_info, Array<Integer> order_info, PipelineInfo(const Array<Array<Integer>> &group_info,
Array<Integer> stage_info) { const Array<Integer> &order_info,
const Array<Integer> &stage_info) {
int n = static_cast<int>(group_info.size()); int n = static_cast<int>(group_info.size());
ICHECK(n == static_cast<int>(order_info.size())); ICHECK(n == static_cast<int>(order_info.size()));
ICHECK(n == static_cast<int>(stage_info.size())); ICHECK(n == static_cast<int>(stage_info.size()));
...@@ -441,7 +445,7 @@ struct PipelineInfo { ...@@ -441,7 +445,7 @@ struct PipelineInfo {
} }
PipelineInfo(const PipelineInfo &other) { PipelineInfo(const PipelineInfo &other) {
for (auto op_info : other.op_infos) { for (const auto &op_info : other.op_infos) {
op_infos.push_back(op_info); op_infos.push_back(op_info);
} }
} }
...@@ -501,18 +505,19 @@ struct PipelineInfo { ...@@ -501,18 +505,19 @@ struct PipelineInfo {
} }
void PrintPipelineInfo() { void PrintPipelineInfo() {
std::cout << "Print op_infos:" << std::endl; std::cout << "Print op_infos:" << '\n';
for (size_t i = 0; i < op_infos.size(); i++) { for (size_t i = 0; i < op_infos.size(); i++) {
std::cout << i << " " << op_infos[i].group_size << " " std::cout << i << " " << op_infos[i].group_size << " "
<< op_infos[i].order << " " << op_infos[i].stage << std::endl; << op_infos[i].order << " " << op_infos[i].stage << '\n';
} }
std::cout << "End of print" << std::endl; std::cout << "End of print" << '\n';
} }
}; };
class GroupOpRewriter : public StmtExprMutator { class GroupOpRewriter : public StmtExprMutator {
public: public:
GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {} GroupOpRewriter(const PipelineInfo &pipeline_info)
: pipeline_info_(pipeline_info) {}
private: private:
Stmt VisitStmt_(const ForNode *op) final { Stmt VisitStmt_(const ForNode *op) final {
...@@ -546,7 +551,7 @@ private: ...@@ -546,7 +551,7 @@ private:
} }
Array<Integer> order_anno; Array<Integer> order_anno;
Array<Integer> stage_anno; Array<Integer> stage_anno;
for (auto op_info : pipeline_info_.op_infos) { for (const auto &op_info : pipeline_info_.op_infos) {
order_anno.push_back(Integer(op_info.order)); order_anno.push_back(Integer(op_info.order));
stage_anno.push_back(Integer(op_info.stage)); stage_anno.push_back(Integer(op_info.stage));
} }
...@@ -588,7 +593,7 @@ public: ...@@ -588,7 +593,7 @@ public:
in_if_scope_ = false; in_if_scope_ = false;
} }
static bool HasWgMMA(Stmt stmt) { static bool HasWgMMA(const Stmt &stmt) {
auto collector = WgMMACollector(); auto collector = WgMMACollector();
collector(stmt); collector(stmt);
return collector.has_wgmma_; return collector.has_wgmma_;
...@@ -629,14 +634,14 @@ public: ...@@ -629,14 +634,14 @@ public:
* @param only_has_wgmma If true, adjust emission and barrier-thread-count * @param only_has_wgmma If true, adjust emission and barrier-thread-count
* logic for blocks that contain WgMMA operations. * logic for blocks that contain WgMMA operations.
*/ */
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv, WSCodeEmitter(bool is_emitting_producer, const IterVar &thread_iv,
Map<Var, Buffer> buffer_data_to_buffer, Map<Var, Buffer> buffer_data_to_buffer,
const WarpSpecializedRoleMarker &marker, const WarpSpecializedRoleMarker &marker,
bool mbarrier_only = false, bool only_has_wgmma = false) bool mbarrier_only = false, bool only_has_wgmma = false)
: is_emitting_producer_(is_emitting_producer), : is_emitting_producer_(is_emitting_producer),
buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker), buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only), marker_(marker), thread_var_(thread_iv->var),
only_has_wgmma_(only_has_wgmma) {} mbarrier_only_(mbarrier_only), only_has_wgmma_(only_has_wgmma) {}
/** /**
* @brief Whether a SIMT-style bulk copy was detected. * @brief Whether a SIMT-style bulk copy was detected.
...@@ -757,7 +762,7 @@ private: ...@@ -757,7 +762,7 @@ private:
return FilterByRole(op); return FilterByRole(op);
auto seq_transformed = auto seq_transformed =
op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); }); op->seq.Map([&](const Stmt &stmt) { return VisitStmt(stmt); });
auto map = ExtractSyncPattern(op->seq); auto map = ExtractSyncPattern(op->seq);
...@@ -804,7 +809,7 @@ private: ...@@ -804,7 +809,7 @@ private:
: parity_; : parity_;
block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
} }
ICHECK(map.release[i].size() > 0); ICHECK(!map.release[i].empty());
for (size_t j = 0; j < map.release[i].size(); j++) { for (size_t j = 0; j < map.release[i].size(); j++) {
int pattern_idx = map.release[i][j]; int pattern_idx = map.release[i][j];
PrimExpr release_barrier_id = PrimExpr release_barrier_id =
...@@ -890,7 +895,7 @@ private: ...@@ -890,7 +895,7 @@ private:
num_barriers_ += map.patterns.size() * num_stages_; num_barriers_ += map.patterns.size() * num_stages_;
ICHECK(new_body.size() > 0); ICHECK(!new_body.empty());
return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)); return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
} }
...@@ -923,8 +928,8 @@ private: ...@@ -923,8 +928,8 @@ private:
PipelineInfo pipeline_info(group_info_array, order_info_array, PipelineInfo pipeline_info(group_info_array, order_info_array,
stage_info_array); stage_info_array);
if (pipeline_info.op_infos.size() > 0) { if (!pipeline_info.op_infos.empty()) {
ICHECK(pipeline_info_.op_infos.size() == 0) ICHECK(pipeline_info_.op_infos.empty())
<< "Nested pipeline not supported."; << "Nested pipeline not supported.";
} }
...@@ -946,7 +951,7 @@ private: ...@@ -946,7 +951,7 @@ private:
auto result = FilterByRole(op); auto result = FilterByRole(op);
Stmt grouped_for_node; Stmt grouped_for_node;
if (result.as<ForNode>() && group_anno && group_info_array.size() > 0 && if (result.as<ForNode>() && group_anno && !group_info_array.empty() &&
!is_emitting_producer_) { !is_emitting_producer_) {
GroupOpRewriter group_op_rewriter(pipeline_info_); GroupOpRewriter group_op_rewriter(pipeline_info_);
auto for_node = Downcast<For>(result); auto for_node = Downcast<For>(result);
...@@ -963,12 +968,11 @@ private: ...@@ -963,12 +968,11 @@ private:
if (result.as<ForNode>()) { if (result.as<ForNode>()) {
auto for_node = Downcast<For>(result); auto for_node = Downcast<For>(result);
for_node.CopyOnWrite()->annotations.erase("num_stages"); for_node.CopyOnWrite()->annotations.erase("num_stages");
if (is_emitting_producer_ || group_info_array.size() == 0) { if (is_emitting_producer_ || group_info_array.empty()) {
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order"); for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage"); for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
} }
if (is_emitting_producer_ || !group_anno || if (is_emitting_producer_ || !group_anno || group_info_array.empty()) {
group_info_array.size() == 0) {
loop_stack_.pop_back(); loop_stack_.pop_back();
return for_node; return for_node;
} }
...@@ -1017,7 +1021,7 @@ private: ...@@ -1017,7 +1021,7 @@ private:
}; };
std::vector<SyncPattern> std::vector<SyncPattern>
CreateBaseSyncPairs(Array<Stmt> seq_stmt, CreateBaseSyncPairs(const Array<Stmt> &seq_stmt,
const std::vector<bool> &is_producer) { const std::vector<bool> &is_producer) {
const int n = seq_stmt.size(); const int n = seq_stmt.size();
std::vector<std::set<const BufferNode *>> reads, writes; std::vector<std::set<const BufferNode *>> reads, writes;
...@@ -1132,7 +1136,7 @@ private: ...@@ -1132,7 +1136,7 @@ private:
return sync_pattern_cleaned; return sync_pattern_cleaned;
} }
SyncPatternMap ExtractSyncPattern(Array<Stmt> seq_stmt) { SyncPatternMap ExtractSyncPattern(const Array<Stmt> &seq_stmt) {
size_t num_stmts = seq_stmt.size(); size_t num_stmts = seq_stmt.size();
std::vector<bool> is_producer; std::vector<bool> is_producer;
is_producer.reserve(num_stmts); is_producer.reserve(num_stmts);
...@@ -1165,7 +1169,7 @@ private: ...@@ -1165,7 +1169,7 @@ private:
std::vector<int> cur_consumer_barrier, cur_producer_barrier; std::vector<int> cur_consumer_barrier, cur_producer_barrier;
for (int i = num_stmts - 1; i >= 0; i--) { for (int i = num_stmts - 1; i >= 0; i--) {
if (is_producer[i]) { if (is_producer[i]) {
if (map.release[i].size() == 0) { if (map.release[i].empty()) {
for (auto pattern_idx : cur_producer_barrier) { for (auto pattern_idx : cur_producer_barrier) {
map.release[i].push_back(pattern_idx); map.release[i].push_back(pattern_idx);
map.release_after[i].push_back(false); map.release_after[i].push_back(false);
...@@ -1176,7 +1180,7 @@ private: ...@@ -1176,7 +1180,7 @@ private:
} }
} }
} else { } else {
if (map.release[i].size() == 0) { if (map.release[i].empty()) {
for (auto pattern_idx : cur_consumer_barrier) { for (auto pattern_idx : cur_consumer_barrier) {
map.release[i].push_back(pattern_idx); map.release[i].push_back(pattern_idx);
map.release_after[i].push_back(false); map.release_after[i].push_back(false);
...@@ -1405,7 +1409,7 @@ private: ...@@ -1405,7 +1409,7 @@ private:
class WarpSpecializedDetector : public IRVisitorWithAnalyzer { class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public: public:
// return true means this aws will be disabled // return true means this aws will be disabled
static bool Detect(Stmt stmt, bool skip_thread_partition = false) { static bool Detect(const Stmt &stmt, bool skip_thread_partition = false) {
WarpSpecializedDetector detector; WarpSpecializedDetector detector;
detector.VisitStmt(stmt); detector.VisitStmt(stmt);
if (detector.has_warp_specialization_) { if (detector.has_warp_specialization_) {
...@@ -1472,7 +1476,7 @@ private: ...@@ -1472,7 +1476,7 @@ private:
using namespace tir::transform; using namespace tir::transform;
tvm::transform::Pass WarpSpecialized() { tvm::transform::Pass WarpSpecialized() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
bool disable_warp_specialized = bool disable_warp_specialized =
ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value(); ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value();
bool disable_shuffle_elect = bool disable_shuffle_elect =
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <utility>
#include "../op/builtin.h" #include "../op/builtin.h"
namespace tvm { namespace tvm {
...@@ -17,7 +19,7 @@ namespace tl { ...@@ -17,7 +19,7 @@ namespace tl {
using namespace tir; using namespace tir;
bool isGemm(Stmt stmt) { bool isGemm(const Stmt &stmt) {
bool is_gemm = false; bool is_gemm = false;
if (stmt.as<EvaluateNode>()) { if (stmt.as<EvaluateNode>()) {
auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>(); auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
...@@ -33,7 +35,7 @@ bool isGemm(Stmt stmt) { ...@@ -33,7 +35,7 @@ bool isGemm(Stmt stmt) {
return is_gemm; return is_gemm;
} }
bool isGemmSync(Stmt stmt) { bool isGemmSync(const Stmt &stmt) {
bool is_gemm_sync = false; bool is_gemm_sync = false;
if (stmt.as<EvaluateNode>()) { if (stmt.as<EvaluateNode>()) {
auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>(); auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
...@@ -49,7 +51,7 @@ bool isGemmSync(Stmt stmt) { ...@@ -49,7 +51,7 @@ bool isGemmSync(Stmt stmt) {
return is_gemm_sync; return is_gemm_sync;
} }
bool isArriveBarrier(Stmt stmt) { bool isArriveBarrier(const Stmt &stmt) {
bool is_arrive_barrier = false; bool is_arrive_barrier = false;
if (stmt.as<EvaluateNode>()) { if (stmt.as<EvaluateNode>()) {
auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>(); auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
...@@ -216,7 +218,8 @@ private: ...@@ -216,7 +218,8 @@ private:
gemm_count++; gemm_count++;
} else if (isGemmSync(new_seq[i])) { } else if (isGemmSync(new_seq[i])) {
auto call = Downcast<Evaluate>(new_seq[i])->value.as<CallNode>(); auto call = Downcast<Evaluate>(new_seq[i])->value.as<CallNode>();
auto sync_index = Downcast<IntImm>(call->args[1])->value; auto sync_index =
static_cast<int>(Downcast<IntImm>(call->args[1])->value);
auto wait_count = gemm_count - sync_index - 1; auto wait_count = gemm_count - sync_index - 1;
if (sync_index > max_sync_index) if (sync_index > max_sync_index)
max_sync_index = sync_index; max_sync_index = sync_index;
...@@ -257,8 +260,8 @@ private: ...@@ -257,8 +260,8 @@ private:
using namespace tir::transform; using namespace tir::transform;
tvm::transform::Pass RewriteWgmmaSync() { tvm::transform::Pass RewriteWgmmaSync() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return WgmmaSyncRewriter::Substitute(f); return WgmmaSyncRewriter::Substitute(std::move(f));
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {}); return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment