Unverified Commit cdc5d8d3, authored by Lei Wang, committed by GitHub

[Lint] Introduce clang-tidy into format.sh (#777)

* [Refactor] Update Clang-Tidy Checks and Improve Code Consistency

- Enhanced .clang-tidy configuration by adding specific checks for better bug detection and performance optimization.
- Refactored function signatures across multiple files to use `const` references for parameters, improving performance and code clarity.
- Updated various methods to ensure consistent handling of parameters, particularly in `AddPredicate`, `Substitute`, and `PlanLoopPartition` functions.
- Improved readability by replacing size checks with `empty()` method calls in several locations, ensuring clearer intent in the code.
- General code cleanup and adherence to best practices for better maintainability.
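
As a rough illustration of the two mechanical changes listed above (passing handle types by `const` reference and calling `empty()` instead of comparing a size to zero), here is a minimal sketch. It uses `std::shared_ptr` as a stand-in for TVM's reference-counted handles such as `Stmt`; the names `Node`, `NodeHandle`, `UseCountByValue`, `UseCountByConstRef`, and `HasNoChildren` are hypothetical and do not appear in the commit.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in for a reference-counted IR node (not a TVM type).
struct Node {
  std::vector<int> children;
};
using NodeHandle = std::shared_ptr<Node>;

// By value: the handle is copied, so the reference count is bumped
// for the duration of the call.
long UseCountByValue(NodeHandle h) { return h.use_count(); }

// By const reference: no copy is made and the count is unchanged.
long UseCountByConstRef(const NodeHandle &h) { return h.use_count(); }

// `empty()` states the intent directly, unlike `size() == 0`.
bool HasNoChildren(const NodeHandle &h) {
  assert(h != nullptr); // precondition: the handle must be set
  return h->children.empty();
}
```

The copy made by the by-value version is cheap but not free, because it touches the shared reference count; a read-only parameter avoids that by binding a `const` reference instead.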

* [Refactor] Enhance Code Consistency and Clang-Tidy Configuration

- Updated .clang-tidy configuration to include additional checks for improved code quality and performance.
- Refactored function signatures across multiple files to use `const` references, enhancing performance and clarity.
- Replaced size checks with `empty()` method calls in various locations for clearer intent.
- Improved handling of parameters in several functions, ensuring consistent usage of `std::move` where applicable.
- General code cleanup to adhere to best practices and improve maintainability.
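
The `std::move` usage mentioned above typically applies to "sink" parameters, meaning arguments that a constructor or function stores rather than only reads. The sketch below shows that pattern under the assumption that the stored type behaves like a reference-counted handle; `IntHandle` and `Holder` are hypothetical names, with `std::shared_ptr` standing in for types such as `Var` or `PrimExpr`.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical stand-in for a reference-counted handle.
using IntHandle = std::shared_ptr<int>;

// Hypothetical class that stores its constructor argument.
class Holder {
public:
  // Take the sink parameter by value, then move it into the member.
  // A caller passing an lvalue pays for one copy; a caller passing an
  // rvalue pays for no copy at all.
  explicit Holder(IntHandle h) : h_(std::move(h)) { assert(h_ != nullptr); }

  long use_count() const { return h_.use_count(); }

private:
  IntHandle h_;
};
```

Without the `std::move` in the initializer list, the member would be copy-constructed from the parameter, adding a second reference-count increment that is immediately undone when the parameter is destroyed.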

* [Refactor] Integrate Clang-Tidy Checks and Enhance Code Consistency

- Added clang-tidy checks to the format script for improved code quality assurance.
- Refactored function signatures across multiple files to consistently use `const` references, enhancing performance and clarity.
- Updated the requirements-lint.txt file to include clang-tidy as a dependency.
- General code cleanup to adhere to best practices and improve maintainability.

* [CI] Update AMD CI Workflow to Include Build Directory Creation

- Added steps to create a build directory and configure CMake with ROCm support during the format check process.
- Ensured cleanup of the build directory after the format check to maintain a clean workspace.

* [Refactor] Remove Unused Member Variables in AtomicAddNode and CopyNode

- Removed the `args_` member variable from both `AtomicAddNode` and `CopyNode` classes to streamline the code and eliminate unnecessary data members.
- This change enhances code clarity and maintainability by focusing on relevant attributes for each class.

* [Refactor] Update Clang-Tidy Integration and Code Improvements

- Modified the format script to include the `-fix` option in the clang-tidy command for automatic code fixes.
- Refactored the `AtomicAddVectorizePlanner` class to improve variable handling and consistency, including changes to member variable types and function signatures.
- Enhanced code clarity by removing unnecessary `std::move` calls and ensuring consistent usage of types across the class.
- General code cleanup to adhere to best practices and improve maintainability.
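
One plausible reason a `std::move` becomes unnecessary, consistent with the call sites changed in this diff, is that the callee now takes a `const` reference: `std::move` only casts its argument to an rvalue, and a `const` reference parameter never moves from it. The following sketch demonstrates this with hypothetical names (`BodyHandle`, `RewriteBody`, `RewriteWithRedundantMove`, `RewriteWithoutMove`) and `std::shared_ptr` as a stand-in for `Stmt`.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical stand-in for a reference-counted statement handle.
using BodyHandle = std::shared_ptr<int>;

// Hypothetical rewrite step that only reads its argument.
BodyHandle RewriteBody(const BodyHandle &body) {
  assert(body != nullptr); // precondition: the handle must be set
  return std::make_shared<int>(*body + 1);
}

// The std::move below is a no-op in practice: the rvalue binds to the
// const reference parameter and `body` is left untouched.
BodyHandle RewriteWithRedundantMove(BodyHandle &body) {
  return RewriteBody(std::move(body));
}

// Equivalent call without the misleading cast.
BodyHandle RewriteWithoutMove(BodyHandle &body) { return RewriteBody(body); }
```

Both wrappers behave identically, so dropping the cast removes a hint of ownership transfer that never actually happens.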

* [Refactor] Improve Parameter Handling and Consistency in AtomicAddVectorize

- Updated function signatures in `AtomicAddVectorizePlanResult` and `AtomicAddVectorizeRewriter` to use `const` references and `std::move` for better performance and clarity.
- Enhanced the `UpdateVectorSize` method to accept `const Array<PrimExpr>&` for improved efficiency.
- General code cleanup to maintain consistency and adhere to best practices.

* [CI] Add Git Submodule Initialization to CI Workflow

- Included a step to initialize and update git submodules recursively in the CI workflow.
- This change ensures that all necessary submodules are available during the format check process, improving build reliability.

* [CI] Add Git Submodule Update Step to Format Check

- Included a command to initialize and update git submodules recursively in the CI workflow during the format check process.
- This enhancement ensures that all required submodules are available, contributing to improved build reliability.

* [Refactor] Update Function Signatures in AtomicAddVectorize

- Modified the `VectorizeAtomicAdd` function signature to use `const` references for `thread_var` and `thread_bounds`, enhancing performance and code clarity.
- This change aligns with previous refactoring efforts to improve parameter handling and consistency across the codebase.
parent 471cc7f8
...@@ -45,7 +45,7 @@ class FragmentAccessDetector : public StmtExprVisitor { ...@@ -45,7 +45,7 @@ class FragmentAccessDetector : public StmtExprVisitor {
public: public:
FragmentAccessDetector() = default; FragmentAccessDetector() = default;
void Collect(Stmt stmt) { VisitStmt(stmt); } void Collect(const Stmt &stmt) { VisitStmt(stmt); }
bool HasFragmentAccess() { return has_fragment_access_; } bool HasFragmentAccess() { return has_fragment_access_; }
...@@ -91,7 +91,7 @@ private: ...@@ -91,7 +91,7 @@ private:
*/ */
class ParallelLoopFuser : public IRMutatorWithAnalyzer { class ParallelLoopFuser : public IRMutatorWithAnalyzer {
public: public:
static Stmt Fuse(Stmt stmt) { static Stmt Fuse(const Stmt &stmt) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
ParallelLoopFuser substituter(&analyzer); ParallelLoopFuser substituter(&analyzer);
return substituter.VisitStmt(stmt); return substituter.VisitStmt(stmt);
......
...@@ -26,7 +26,7 @@ using arith::IRVisitorWithAnalyzer; ...@@ -26,7 +26,7 @@ using arith::IRVisitorWithAnalyzer;
class ParallelLoopTransformer : public IRMutatorWithAnalyzer { class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
public: public:
static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) { static Stmt Substitute(const Stmt &stmt, bool skip_thread_partition = false) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
ParallelLoopTransformer transformer(&analyzer); ParallelLoopTransformer transformer(&analyzer);
return transformer.VisitStmt(stmt); return transformer.VisitStmt(stmt);
...@@ -75,8 +75,6 @@ public: ...@@ -75,8 +75,6 @@ public:
for (size_t i = 0; i < indices.size(); ++i) { for (size_t i = 0; i < indices.size(); ++i) {
auto index = indices[i]; auto index = indices[i];
auto bound = analyzer_->const_int_bound(index); auto bound = analyzer_->const_int_bound(index);
int64_t upper_bound = bound->max_value + 1;
int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
// Collect the variables that used in the index // Collect the variables that used in the index
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars; std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
...@@ -86,7 +84,7 @@ public: ...@@ -86,7 +84,7 @@ public:
used_vars.insert(GetRef<Var>(v)); used_vars.insert(GetRef<Var>(v));
} }
}); });
if (used_vars.size() == 0) { if (used_vars.empty()) {
continue; continue;
} }
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <tvm/tir/utils.h> #include <tvm/tir/utils.h>
#include <queue> #include <queue>
#include <utility>
#include "../../op/parallel.h" #include "../../op/parallel.h"
#include "../loop_partition.h" #include "../loop_partition.h"
...@@ -86,7 +87,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { ...@@ -86,7 +87,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
class VecAllocAccess : public StmtExprMutator { class VecAllocAccess : public StmtExprMutator {
public: public:
VecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes) VecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {} : buf_(buf), var_(std::move(var)), var_lanes_(std::move(var_lanes)) {}
PrimExpr VisitExpr_(const BufferLoadNode *op) final { PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
...@@ -176,7 +177,8 @@ public: ...@@ -176,7 +177,8 @@ public:
using ExprFunctor::VisitExpr; using ExprFunctor::VisitExpr;
using StmtMutator::operator(); using StmtMutator::operator();
Vectorizer(Var var, PrimExpr var_lanes) : var_(var), var_lanes_(var_lanes) { Vectorizer(const Var &var, const PrimExpr &var_lanes)
: var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes); ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
} }
...@@ -196,11 +198,13 @@ public: ...@@ -196,11 +198,13 @@ public:
} }
PrimExpr VisitExpr_(const AddNode *op) final { PrimExpr VisitExpr_(const AddNode *op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; }); return AddSubVec(
op, [](PrimExpr a, PrimExpr b) { return std::move(a) + std::move(b); });
} }
PrimExpr VisitExpr_(const SubNode *op) final { PrimExpr VisitExpr_(const SubNode *op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; }); return AddSubVec(
op, [](PrimExpr a, PrimExpr b) { return std::move(a) - std::move(b); });
} }
PrimExpr VisitExpr_(const MulNode *op) final { PrimExpr VisitExpr_(const MulNode *op) final {
...@@ -704,7 +708,7 @@ private: ...@@ -704,7 +708,7 @@ private:
// mutate array, with given lane requirement // mutate array, with given lane requirement
// when finished, p_lane updates the lane requirement. // when finished, p_lane updates the lane requirement.
Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int *p_lanes) { Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int *p_lanes) {
if (arr.size() == 0) if (arr.empty())
return arr; return arr;
int &lanes = *p_lanes; int &lanes = *p_lanes;
bool changed = false; bool changed = false;
......
...@@ -24,7 +24,7 @@ struct ThreadBoundKey { ...@@ -24,7 +24,7 @@ struct ThreadBoundKey {
// Number of threads syncing using the barrier must be a multiple of warp-size // Number of threads syncing using the barrier must be a multiple of warp-size
// ID 0 should not be used for safety, as other driver APIs (i.e. __syncthreads) // ID 0 should not be used for safety, as other driver APIs (i.e. __syncthreads)
// may use it and conflict with other uses. // may use it and conflict with other uses.
enum class ReservedNamedBarriers { enum class ReservedNamedBarriers : uint8_t {
kSyncThreads = 0, kSyncThreads = 0,
kReduce_0 = 1, kReduce_0 = 1,
kReduce_1 = 2, kReduce_1 = 2,
......
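
The `ReservedNamedBarriers` change above gives the enum an explicit `uint8_t` underlying type, presumably in response to a clang-tidy check such as `performance-enum-size` (an assumption; the commit message does not name the check). A minimal sketch of the effect, using hypothetical enums `DefaultSized` and `ByteSized` and a hypothetical helper `FromIndex`:

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// A scoped enum without an explicit base uses `int` as its underlying type.
enum class DefaultSized { kA, kB, kC };

// With an explicit base, each value occupies a single byte.
enum class ByteSized : std::uint8_t { kA, kB, kC };

static_assert(
    std::is_same<std::underlying_type<DefaultSized>::type, int>::value,
    "scoped enums default to int");
static_assert(sizeof(ByteSized) == 1, "uint8_t-backed enum is one byte");

// Hypothetical helper converting a small index into the enum.
ByteSized FromIndex(int i) {
  assert(i >= 0 && i <= 2); // precondition: index must name an enumerator
  return static_cast<ByteSized>(i);
}
```

The narrower type only matters when many values are stored, for example in arrays or struct fields, but it costs nothing as long as every enumerator fits in eight bits.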
...@@ -18,7 +18,7 @@ public: ...@@ -18,7 +18,7 @@ public:
ConfigIndexBitwidthRewriter(int index_bitwidth) ConfigIndexBitwidthRewriter(int index_bitwidth)
: _index_bitwidth_(index_bitwidth) {} : _index_bitwidth_(index_bitwidth) {}
Stmt operator()(Stmt s) { return VisitStmt(s); } Stmt operator()(const Stmt &s) { return VisitStmt(s); }
protected: protected:
using Parent::VisitExpr_; using Parent::VisitExpr_;
...@@ -73,7 +73,7 @@ protected: ...@@ -73,7 +73,7 @@ protected:
class IndexLegalizer : public IRMutatorWithAnalyzer { class IndexLegalizer : public IRMutatorWithAnalyzer {
public: public:
static Stmt Rewrite(Stmt stmt) { static Stmt Rewrite(const Stmt &stmt) {
Analyzer ana; Analyzer ana;
auto pass = IndexLegalizer(&ana); auto pass = IndexLegalizer(&ana);
return pass.VisitStmt(stmt); return pass.VisitStmt(stmt);
...@@ -158,7 +158,7 @@ private: ...@@ -158,7 +158,7 @@ private:
tvm::transform::Pass ConfigIndexBitwidth() { tvm::transform::Pass ConfigIndexBitwidth() {
using namespace tir::transform; using namespace tir::transform;
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
auto *n = f.CopyOnWrite(); auto *n = f.CopyOnWrite();
// Get pass config `tl.config_index_bitwidth` // Get pass config `tl.config_index_bitwidth`
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
...@@ -166,11 +166,10 @@ tvm::transform::Pass ConfigIndexBitwidth() { ...@@ -166,11 +166,10 @@ tvm::transform::Pass ConfigIndexBitwidth() {
ctxt->GetConfig(kConfigIndexBitwidth, Optional<Integer>()); ctxt->GetConfig(kConfigIndexBitwidth, Optional<Integer>());
if (opt_config_index_bitwidth.defined()) { if (opt_config_index_bitwidth.defined()) {
int config_index_bitwidth = opt_config_index_bitwidth.value()->value; int config_index_bitwidth = opt_config_index_bitwidth.value()->value;
n->body = ConfigIndexBitwidthRewriter(config_index_bitwidth)( n->body = ConfigIndexBitwidthRewriter(config_index_bitwidth)(n->body);
std::move(n->body));
} }
// Legalize out-of-bound indices to be int64 // Legalize out-of-bound indices to be int64
n->body = IndexLegalizer::Rewrite(std::move(n->body)); n->body = IndexLegalizer::Rewrite(n->body);
return f; return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {}); return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
......
...@@ -22,7 +22,7 @@ using arith::IRVisitorWithAnalyzer; ...@@ -22,7 +22,7 @@ using arith::IRVisitorWithAnalyzer;
class Eliminator : public IRMutatorWithAnalyzer { class Eliminator : public IRMutatorWithAnalyzer {
public: public:
static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) { static Stmt Substitute(const Stmt &stmt, bool skip_thread_partition = false) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
Eliminator transformer(&analyzer); Eliminator transformer(&analyzer);
return transformer.VisitStmt(stmt); return transformer.VisitStmt(stmt);
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
if (op->attr_key == "thread_extent") { if (op->attr_key == "thread_extent") {
const VarNode *var = nullptr; const VarNode *var = nullptr;
if (op->node->IsInstance<VarNode>()) { if (op->node->IsInstance<VarNode>()) {
var = static_cast<const VarNode *>(op->node.get()); var = op->node.as<VarNode>();
if (var->name_hint == "threadIdx.x") { if (var->name_hint == "threadIdx.x") {
thread_extent_ = op; thread_extent_ = op;
} }
...@@ -49,7 +49,7 @@ public: ...@@ -49,7 +49,7 @@ public:
Stmt VisitStmt_(const EvaluateNode *op) final { Stmt VisitStmt_(const EvaluateNode *op) final {
const CallNode *call = nullptr; const CallNode *call = nullptr;
if (op->value->IsInstance<CallNode>()) { if (op->value->IsInstance<CallNode>()) {
call = static_cast<const CallNode *>(op->value.get()); call = op->value.as<CallNode>();
if (call->op.same_as(builtin::tvm_storage_sync())) { if (call->op.same_as(builtin::tvm_storage_sync())) {
// Skip storage sync if we're in a region with mbarrier operations // Skip storage sync if we're in a region with mbarrier operations
// and we're not in a for loop with mbarrier operations // and we're not in a for loop with mbarrier operations
...@@ -107,9 +107,9 @@ using namespace tir::transform; ...@@ -107,9 +107,9 @@ using namespace tir::transform;
namespace transform { namespace transform {
tvm::transform::Pass EliminateStorageSyncForMBarrier() { tvm::transform::Pass EliminateStorageSyncForMBarrier() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
auto *n = f.CopyOnWrite(); auto *n = f.CopyOnWrite();
n->body = Eliminator::Substitute(std::move(n->body)); n->body = Eliminator::Substitute(n->body);
return f; return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.EliminateStorageSyncForMBarrier", return CreatePrimFuncPass(pass_func, 0, "tl.EliminateStorageSyncForMBarrier",
......
...@@ -30,6 +30,8 @@ ...@@ -30,6 +30,8 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <utility>
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -73,21 +75,23 @@ private: ...@@ -73,21 +75,23 @@ private:
Array<Buffer> alloc_buffers = op->alloc_buffers; Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply( alloc_buffers.MutateByApply(
[this](Buffer buf) { return GetFlattenedBuffer(buf); }); [this](const Buffer &buf) { return GetFlattenedBuffer(buf); });
if (!alloc_buffers.same_as(op->alloc_buffers)) { if (!alloc_buffers.same_as(op->alloc_buffers)) {
block.CopyOnWrite()->alloc_buffers = alloc_buffers; block.CopyOnWrite()->alloc_buffers = alloc_buffers;
} }
Array<BufferRegion> reads = op->reads; Array<BufferRegion> reads = op->reads;
reads.MutateByApply( reads.MutateByApply([this](BufferRegion region) {
[this](BufferRegion region) { return MutateBufferRegion(region); }); return MutateBufferRegion(std::move(region));
});
if (!reads.same_as(op->reads)) { if (!reads.same_as(op->reads)) {
block.CopyOnWrite()->reads = reads; block.CopyOnWrite()->reads = reads;
} }
Array<BufferRegion> writes = op->writes; Array<BufferRegion> writes = op->writes;
writes.MutateByApply( writes.MutateByApply([this](BufferRegion region) {
[this](BufferRegion region) { return MutateBufferRegion(region); }); return MutateBufferRegion(std::move(region));
});
if (!writes.same_as(op->writes)) { if (!writes.same_as(op->writes)) {
block.CopyOnWrite()->writes = writes; block.CopyOnWrite()->writes = writes;
} }
...@@ -169,7 +173,7 @@ private: ...@@ -169,7 +173,7 @@ private:
return VisitStmt(op->body); return VisitStmt(op->body);
} }
Buffer GetFlattenedBuffer(Buffer buf) { Buffer GetFlattenedBuffer(const Buffer &buf) {
auto it = buffer_remap_.find(buf); auto it = buffer_remap_.find(buf);
if (it != buffer_remap_.end()) { if (it != buffer_remap_.end()) {
return it->second; return it->second;
...@@ -294,12 +298,12 @@ private: ...@@ -294,12 +298,12 @@ private:
}; };
PrimFunc FlattenBufferRewriter(PrimFunc f) { PrimFunc FlattenBufferRewriter(PrimFunc f) {
return BufferFlattener::Flatten(f); return BufferFlattener::Flatten(std::move(f));
} }
using namespace tir::transform; using namespace tir::transform;
tvm::transform::Pass FlattenBuffer() { tvm::transform::Pass FlattenBuffer() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return FlattenBufferRewriter(std::move(f)); return FlattenBufferRewriter(std::move(f));
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {}); return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {});
......
...@@ -83,7 +83,7 @@ private: ...@@ -83,7 +83,7 @@ private:
using namespace tir::transform; using namespace tir::transform;
Pass FrontendLegalize() { Pass FrontendLegalize() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return FrontendLegalizer::Substitute(std::move(f)); return FrontendLegalizer::Substitute(std::move(f));
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.FrontendLegalize", {}); return CreatePrimFuncPass(pass_func, 0, "tl.FrontendLegalize", {});
......
...@@ -38,8 +38,8 @@ private: ...@@ -38,8 +38,8 @@ private:
ICHECK(then_case.defined()) << "then_case must be defined"; ICHECK(then_case.defined()) << "then_case must be defined";
ICHECK(!else_case.defined()) << "else_case must be undefined"; ICHECK(!else_case.defined()) << "else_case must be undefined";
auto bind_if_stmt = [](Optional<Stmt> body, auto bind_if_stmt = [](const Optional<Stmt> &body,
const PrimExpr condition) -> Stmt { const PrimExpr &condition) -> Stmt {
if (body.defined()) { if (body.defined()) {
auto stmt = body.value(); auto stmt = body.value();
if (auto seq_stmt = stmt.as<SeqStmtNode>()) { if (auto seq_stmt = stmt.as<SeqStmtNode>()) {
...@@ -75,7 +75,7 @@ private: ...@@ -75,7 +75,7 @@ private:
using namespace tir::transform; using namespace tir::transform;
tvm::transform::Pass IfStmtBinding() { tvm::transform::Pass IfStmtBinding() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return IfStmtBindingRewriter::Substitute(f); return IfStmtBindingRewriter::Substitute(f);
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {}); return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {});
......
...@@ -36,7 +36,7 @@ namespace tl { ...@@ -36,7 +36,7 @@ namespace tl {
using namespace tir; using namespace tir;
enum class Proxy { kGeneric, kAsync, kBoth }; enum class Proxy : uint8_t { kGeneric, kAsync, kBoth };
class ProxyMarker : public StmtVisitor { class ProxyMarker : public StmtVisitor {
public: public:
...@@ -155,7 +155,7 @@ private: ...@@ -155,7 +155,7 @@ private:
} }
Stmt VisitStmt_(const SeqStmtNode *op) final { Stmt VisitStmt_(const SeqStmtNode *op) final {
ICHECK(op->seq.size() > 0); ICHECK(!op->seq.empty());
Array<Stmt> new_body; Array<Stmt> new_body;
Proxy cur_proxy, prev_proxy; Proxy cur_proxy, prev_proxy;
auto fence_stmt = auto fence_stmt =
...@@ -172,7 +172,7 @@ private: ...@@ -172,7 +172,7 @@ private:
prev_proxy = cur_proxy; prev_proxy = cur_proxy;
} }
} }
ICHECK(new_body.size() > 0); ICHECK(!new_body.empty());
return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)); return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
} }
...@@ -187,7 +187,7 @@ private: ...@@ -187,7 +187,7 @@ private:
using namespace tir::transform; using namespace tir::transform;
tvm::transform::Pass InjectFenceProxy() { tvm::transform::Pass InjectFenceProxy() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
f = TMAStoreSyncInjector::Substitute(f); f = TMAStoreSyncInjector::Substitute(f);
return InjectFenceProxy::Substitute(f); return InjectFenceProxy::Substitute(f);
}; };
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include "support/utils.h" #include "support/utils.h"
#include "tir/schedule/utils.h" #include "tir/schedule/utils.h"
...@@ -104,7 +105,7 @@ public: ...@@ -104,7 +105,7 @@ public:
const Map<Buffer, Buffer> &buffer_remap, const Map<Buffer, Buffer> &buffer_remap,
For pipeline_loop, bool access_all_versions) For pipeline_loop, bool access_all_versions)
: buffer_data_to_buffer_(buffer_data_to_buffer), : buffer_data_to_buffer_(buffer_data_to_buffer),
buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop), buffer_remap_(buffer_remap), pipeline_loop_(std::move(pipeline_loop)),
access_all_versions_(access_all_versions) {} access_all_versions_(access_all_versions) {}
private: private:
...@@ -130,10 +131,12 @@ private: ...@@ -130,10 +131,12 @@ private:
} }
PrimExpr RewriteBufferAccess(const Call &call, PrimExpr RewriteBufferAccess(const Call &call,
const std::vector<int> arg_indices) { const std::vector<int> &arg_indices) {
auto product = [](const Array<PrimExpr> &input) { auto product = [](const Array<PrimExpr> &input) {
return foldl( return foldl(
[](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, [](PrimExpr a, PrimExpr b, Span span) {
return mul(std::move(a), std::move(b), std::move(span));
},
make_const(DataType::Int(32), 1), input); make_const(DataType::Int(32), 1), input);
}; };
Array<PrimExpr> new_args = call->args; Array<PrimExpr> new_args = call->args;
...@@ -363,7 +366,7 @@ private: ...@@ -363,7 +366,7 @@ private:
* \param region2 The second region. * \param region2 The second region.
* \return Whether region1 and region2 have intersections. * \return Whether region1 and region2 have intersections.
*/ */
bool MayConflict(Region region1, Region region2) { bool MayConflict(const Region &region1, const Region &region2) {
ICHECK(region1.size() == region2.size()); ICHECK(region1.size() == region2.size());
for (size_t i = 0; i < region1.size(); i++) { for (size_t i = 0; i < region1.size(); i++) {
Range dim1 = region1[i]; Range dim1 = region1[i];
...@@ -458,7 +461,7 @@ private: ...@@ -458,7 +461,7 @@ private:
Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) { Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get())); ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (new_buffer->strides.size()) { if (!new_buffer->strides.empty()) {
ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
...@@ -480,7 +483,9 @@ private: ...@@ -480,7 +483,9 @@ private:
PrimExpr producer_head; PrimExpr producer_head;
std::vector<std::vector<int>> commit_groups; std::vector<std::vector<int>> commit_groups;
std::unordered_map<const BufferNode *, int> buffer_to_commit_group_; std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; } bool writes(const Buffer &buf) const {
return dst_buffers.count(buf.get()) > 0;
}
}; };
// Per-stage states that are local to each of pipeline prologue, body, and // Per-stage states that are local to each of pipeline prologue, body, and
...@@ -616,7 +621,7 @@ private: ...@@ -616,7 +621,7 @@ private:
* \param unroll_loop Whether the loop should be unrolled. * \param unroll_loop Whether the loop should be unrolled.
* \return The result loop. * \return The result loop.
*/ */
Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop, Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop,
bool need_bound_check) { bool need_bound_check) {
PrimExpr new_loop_var; PrimExpr new_loop_var;
PrimExpr extent = end - start; PrimExpr extent = end - start;
...@@ -719,7 +724,7 @@ private: ...@@ -719,7 +724,7 @@ private:
} }
return BlockRealize({}, Bool(true), return BlockRealize({}, Bool(true),
MakeBlock(std::move(new_loop), buffer_data_to_buffer_)); MakeBlock(new_loop, buffer_data_to_buffer_));
} }
arith::Analyzer analyzer_; arith::Analyzer analyzer_;
...@@ -782,7 +787,7 @@ public: ...@@ -782,7 +787,7 @@ public:
private: private:
explicit PipelineInjector(Optional<String> global_symbol) explicit PipelineInjector(Optional<String> global_symbol)
: global_symbol_(global_symbol) {} : global_symbol_(std::move(global_symbol)) {}
/*! /*!
* \brief Check the pipeline satisfies the following conditions: * \brief Check the pipeline satisfies the following conditions:
...@@ -982,7 +987,7 @@ private: ...@@ -982,7 +987,7 @@ private:
*/ */
tir::transform::Pass InjectSoftwarePipeline() { tir::transform::Pass InjectSoftwarePipeline() {
using namespace tir::transform; using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
auto *fptr = f.CopyOnWrite(); auto *fptr = f.CopyOnWrite();
fptr->body = software_pipeline::PipelineInjector::Inject(f); fptr->body = software_pipeline::PipelineInjector::Inject(f);
fptr->body = ConvertSSA(std::move(fptr->body)); fptr->body = ConvertSSA(std::move(fptr->body));
......
...@@ -53,7 +53,7 @@ public: ...@@ -53,7 +53,7 @@ public:
Stmt InjectPTX(const BufferLoadNode *load, const BufferStoreNode *store, Stmt InjectPTX(const BufferLoadNode *load, const BufferStoreNode *store,
bool predicated = false, bool predicated = false,
PrimExpr predicate_value = PrimExpr()) { const PrimExpr &predicate_value = PrimExpr()) {
if (load->buffer.scope() == "global") { if (load->buffer.scope() == "global") {
ICHECK(load->indices.size() == 1 && store->indices.size() == 1); ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
ICHECK(load->indices[0]->dtype.lanes() == ICHECK(load->indices[0]->dtype.lanes() ==
...@@ -224,7 +224,7 @@ private: ...@@ -224,7 +224,7 @@ private:
using namespace tir::transform; using namespace tir::transform;
tvm::transform::Pass InjectPTXAsyncCopy() { tvm::transform::Pass InjectPTXAsyncCopy() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
auto *n = f.CopyOnWrite(); auto *n = f.CopyOnWrite();
n->body = PTXAsyncCopyInjector()(n->body); n->body = PTXAsyncCopyInjector()(n->body);
return f; return f;
......
...@@ -32,6 +32,8 @@ ...@@ -32,6 +32,8 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <utility>
#include "../op/builtin.h" #include "../op/builtin.h"
#include "./common/attr.h" #include "./common/attr.h"
#include "./common/collector.h" #include "./common/collector.h"
...@@ -55,7 +57,7 @@ public: ...@@ -55,7 +57,7 @@ public:
loop_extents = 1; loop_extents = 1;
} }
void Collect(Stmt stmt) { VisitStmt(stmt); } void Collect(const Stmt &stmt) { VisitStmt(stmt); }
PrimExpr BulkCopyBytes() { return bulk_copy_bytes; } PrimExpr BulkCopyBytes() { return bulk_copy_bytes; }
...@@ -103,12 +105,12 @@ private: ...@@ -103,12 +105,12 @@ private:
IterVarType::kDataPar); IterVarType::kDataPar);
PrimExpr makeGetBarrier(PrimExpr barrier_id) { PrimExpr makeGetBarrier(PrimExpr barrier_id) {
return Call(DataType::Handle(), get_mbarrier(), {barrier_id}); return Call(DataType::Handle(), get_mbarrier(), {std::move(barrier_id)});
} }
Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) { Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
auto call = Call(DataType::Handle(), mbarrier_expect_tx(), auto call = Call(DataType::Handle(), mbarrier_expect_tx(),
{makeGetBarrier(barrier_id), bytes}); {makeGetBarrier(std::move(barrier_id)), std::move(bytes)});
return Evaluate(call); return Evaluate(call);
} }
...@@ -188,7 +190,7 @@ public: ...@@ -188,7 +190,7 @@ public:
Map<PrimExpr, IntImm> barrier_id_to_range() { return barrier_id_to_range_; } Map<PrimExpr, IntImm> barrier_id_to_range() { return barrier_id_to_range_; }
private: private:
void UpdateBarrierRange(PrimExpr barrier_id, IntImm extent) { void UpdateBarrierRange(const PrimExpr &barrier_id, const IntImm &extent) {
if (barrier_id_to_range_.count(barrier_id)) { if (barrier_id_to_range_.count(barrier_id)) {
auto old_extent = barrier_id_to_range_[barrier_id]; auto old_extent = barrier_id_to_range_[barrier_id];
ICHECK_EQ(old_extent->value, extent->value) ICHECK_EQ(old_extent->value, extent->value)
...@@ -207,7 +209,7 @@ private: ...@@ -207,7 +209,7 @@ private:
pending_tma_ops_.push_back(GetRef<Call>(call)); pending_tma_ops_.push_back(GetRef<Call>(call));
} else if (call->op.same_as(builtin::ptx_arrive_barrier())) { } else if (call->op.same_as(builtin::ptx_arrive_barrier())) {
PrimExpr barrier_id = call->args[0]; PrimExpr barrier_id = call->args[0];
for (auto tma_call : pending_tma_ops_) { for (const auto &tma_call : pending_tma_ops_) {
tma_op_to_barrier_id_.Set(tma_call, barrier_id); tma_op_to_barrier_id_.Set(tma_call, barrier_id);
} }
auto const_int_bound = analyzer_.const_int_bound(thread_var_); auto const_int_bound = analyzer_.const_int_bound(thread_var_);
...@@ -326,7 +328,7 @@ public: ...@@ -326,7 +328,7 @@ public:
std::vector<int> restore_barrier_ids_; std::vector<int> restore_barrier_ids_;
int if_depth_{0}; int if_depth_{0};
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_; Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
arith::Analyzer *analyzer_; arith::Analyzer *analyzer_{};
Map<Var, arith::IntSet> var_int_set_; Map<Var, arith::IntSet> var_int_set_;
std::vector<arith::IntSet> int_sets_; std::vector<arith::IntSet> int_sets_;
}; };
...@@ -336,7 +338,7 @@ public: ...@@ -336,7 +338,7 @@ public:
BarrierCreationRewriter(std::vector<int> restore_barrier_ids, BarrierCreationRewriter(std::vector<int> restore_barrier_ids,
PrimExpr producer_thread_extent) PrimExpr producer_thread_extent)
: restore_barrier_ids_(std::move(restore_barrier_ids)), : restore_barrier_ids_(std::move(restore_barrier_ids)),
producer_thread_extent_(producer_thread_extent) {} producer_thread_extent_(std::move(producer_thread_extent)) {}
PrimExpr VisitExpr_(const CallNode *op) { PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(create_list_of_mbarrier())) { if (op->op.same_as(create_list_of_mbarrier())) {
...@@ -370,8 +372,8 @@ public: ...@@ -370,8 +372,8 @@ public:
Map<PrimExpr, IntImm> barrier_id_to_range, Map<PrimExpr, IntImm> barrier_id_to_range,
bool has_create_list_of_mbarrier) bool has_create_list_of_mbarrier)
: IRMutatorWithAnalyzer(analyzer), : IRMutatorWithAnalyzer(analyzer),
tma_op_to_barrier_id_(tma_op_to_barrier_id), tma_op_to_barrier_id_(std::move(tma_op_to_barrier_id)),
barrier_id_to_range_(barrier_id_to_range), barrier_id_to_range_(std::move(barrier_id_to_range)),
has_create_list_of_mbarrier_(has_create_list_of_mbarrier) {} has_create_list_of_mbarrier_(has_create_list_of_mbarrier) {}
static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) { static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) {
...@@ -405,7 +407,7 @@ public: ...@@ -405,7 +407,7 @@ public:
private: private:
Stmt VisitStmt_(const BlockNode *op) { Stmt VisitStmt_(const BlockNode *op) {
auto block = GetRef<Block>(op); auto block = GetRef<Block>(op);
if (!has_create_list_of_mbarrier_ && barrier_id_to_range_.size() > 0 && if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() &&
op->name_hint == MainBlockName) { op->name_hint == MainBlockName) {
ICHECK(false) << "Please declare create_list_of_mbarrier."; ICHECK(false) << "Please declare create_list_of_mbarrier.";
} }
...@@ -503,7 +505,7 @@ private: ...@@ -503,7 +505,7 @@ private:
}; };
tvm::transform::Pass InjectTmaBarrier() { tvm::transform::Pass InjectTmaBarrier() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
// Check if function only uses threadIdx.x before proceeding // Check if function only uses threadIdx.x before proceeding
if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) { if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
LOG(WARNING) << "InjectTmaBarrier will be disabled because the program " LOG(WARNING) << "InjectTmaBarrier will be disabled because the program "
......
...@@ -551,7 +551,7 @@ public: ...@@ -551,7 +551,7 @@ public:
} }
private: private:
LayoutInferencer(const LayoutInferenceResult result, LayoutInferencer(const LayoutInferenceResult &result,
bool skip_thread_partition, arith::Analyzer *analyzer) bool skip_thread_partition, arith::Analyzer *analyzer)
: arith::IRMutatorWithAnalyzer(analyzer), result_(result), : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
skip_thread_partition_(skip_thread_partition){}; skip_thread_partition_(skip_thread_partition){};
...@@ -713,11 +713,11 @@ private: ...@@ -713,11 +713,11 @@ private:
tvm::transform::Pass LayoutInference() { tvm::transform::Pass LayoutInference() {
using namespace tir::transform; using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body); f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
ThreadBindingCollector collector; ThreadBindingCollector collector;
collector(f->body); collector(f->body);
bool has_thread_binding = collector.thread_binding_.size() > 0; bool has_thread_binding = !collector.thread_binding_.empty();
bool skip_thread_partition = !has_thread_binding; bool skip_thread_partition = !has_thread_binding;
return LayoutInferencer::Substitute(std::move(f), skip_thread_partition); return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
}; };
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <tvm/tir/utils.h> #include <tvm/tir/utils.h>
#include <utility>
#include "../op/builtin.h" #include "../op/builtin.h"
#include "../op/parallel.h" #include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_mutator_with_analyzer.h"
...@@ -140,7 +142,8 @@ class SafeMemorysRewriter : public StmtExprMutator { ...@@ -140,7 +142,8 @@ class SafeMemorysRewriter : public StmtExprMutator {
public: public:
explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_padding_map, explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_padding_map,
arith::Analyzer *analyzer) arith::Analyzer *analyzer)
: annotated_padding_map_(annotated_padding_map), analyzer_(analyzer) {} : annotated_padding_map_(std::move(annotated_padding_map)),
analyzer_(analyzer) {}
private: private:
Stmt VisitStmt_(const BufferStoreNode *op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
...@@ -153,7 +156,7 @@ private: ...@@ -153,7 +156,7 @@ private:
// Skip boundary check if the store value is an IfThenElse // Skip boundary check if the store value is an IfThenElse
if (const IfThenElseNode *if_node = store->value.as<IfThenElseNode>()) { if (const IfThenElseNode *if_node = store->value.as<IfThenElseNode>()) {
if (conditions.size() > 0) { if (!conditions.empty()) {
LOG(WARNING) LOG(WARNING)
<< "Skipping boundary check for store with IfThenElse value: " << "Skipping boundary check for store with IfThenElse value: "
<< store->value << store->value
...@@ -165,7 +168,7 @@ private: ...@@ -165,7 +168,7 @@ private:
return store; return store;
} }
if (conditions.size() == 0) { if (conditions.empty()) {
return store; return store;
} }
...@@ -215,7 +218,7 @@ private: ...@@ -215,7 +218,7 @@ private:
checker(call); checker(call);
Array<PrimExpr> conditions = checker.GetConditions(); Array<PrimExpr> conditions = checker.GetConditions();
if (conditions.size() == 0) { if (conditions.empty()) {
return evaluate; return evaluate;
} }
...@@ -330,7 +333,7 @@ private: ...@@ -330,7 +333,7 @@ private:
static bool HasInnerLoop(const Stmt &stmt) { static bool HasInnerLoop(const Stmt &stmt) {
LeafForFinder finder; LeafForFinder finder;
finder(stmt); finder(stmt);
return finder.leaf_for_nodes.size() > 0; return !finder.leaf_for_nodes.empty();
} }
Map<Var, Buffer> buffer_data_to_buffer_; Map<Var, Buffer> buffer_data_to_buffer_;
...@@ -341,7 +344,7 @@ private: ...@@ -341,7 +344,7 @@ private:
tvm::transform::Pass LegalizeSafeMemoryAccess() { tvm::transform::Pass LegalizeSafeMemoryAccess() {
using namespace tir::transform; using namespace tir::transform;
// Define the transformation function to be applied // Define the transformation function to be applied
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
bool disable_safe_memory_legalize = bool disable_safe_memory_legalize =
ctx->GetConfig<Bool>(kDisableSafeMemoryLegalize, Bool(false)).value(); ctx->GetConfig<Bool>(kDisableSafeMemoryLegalize, Bool(false)).value();
if (disable_safe_memory_legalize) { if (disable_safe_memory_legalize) {
......
...@@ -73,7 +73,7 @@ private: ...@@ -73,7 +73,7 @@ private:
// Change the loop kind from vectorized to serial // Change the loop kind from vectorized to serial
for_node.CopyOnWrite()->kind = ForKind::kSerial; for_node.CopyOnWrite()->kind = ForKind::kSerial;
// Apply vectorization transformation to the loop // Apply vectorization transformation to the loop
return VectorizeLoop(std::move(for_node)); return VectorizeLoop(for_node);
} }
}; };
...@@ -81,7 +81,7 @@ private: ...@@ -81,7 +81,7 @@ private:
tvm::transform::Pass LegalizeVectorizedLoop() { tvm::transform::Pass LegalizeVectorizedLoop() {
using namespace tir::transform; using namespace tir::transform;
// Define the transformation function to be applied // Define the transformation function to be applied
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
return LoopVectorizedLegalizer::Substitute(std::move(f)); return LoopVectorizedLegalizer::Substitute(std::move(f));
}; };
// Create and return a PrimFunc pass with the transformation function // Create and return a PrimFunc pass with the transformation function
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <utility>
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -57,7 +59,7 @@ private: ...@@ -57,7 +59,7 @@ private:
// Rewrite the parallel loop into a common loop, which is mapped to threads // Rewrite the parallel loop into a common loop, which is mapped to threads
For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
Fragment loop_layout) { const Fragment &loop_layout) {
ICHECK(loop_layout.defined()); ICHECK(loop_layout.defined());
ICHECK(thread_var.defined()); ICHECK(thread_var.defined());
int old_loop_depth = loop_layout->InputDim(); int old_loop_depth = loop_layout->InputDim();
...@@ -72,7 +74,7 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, ...@@ -72,7 +74,7 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
vars.push_back(thread_var); vars.push_back(thread_var);
// create the substitute map, and the loop body // create the substitute map, and the loop body
Map<Var, PrimExpr> vmap; Map<Var, PrimExpr> vmap;
Stmt body = op; Stmt body = std::move(op);
auto inv_loop = loop_layout->Inverse(); auto inv_loop = loop_layout->Inverse();
auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end())); auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));
for (int i = 0; i < old_loop_depth; i++) { for (int i = 0; i < old_loop_depth; i++) {
...@@ -123,7 +125,7 @@ class LoopPartitioner : public StmtExprVisitor { ...@@ -123,7 +125,7 @@ class LoopPartitioner : public StmtExprVisitor {
public: public:
LoopPartitioner() = default; LoopPartitioner() = default;
Fragment Partition(For op, int num_thread, int vectorize_size) { Fragment Partition(const For &op, int num_thread, int vectorize_size) {
this->VisitStmt(op); this->VisitStmt(op);
int loop_size_full = 1; int loop_size_full = 1;
PrimExpr flattened = 0; PrimExpr flattened = 0;
...@@ -182,12 +184,14 @@ private: ...@@ -182,12 +184,14 @@ private:
Array<IterVar> loop_vars_; Array<IterVar> loop_vars_;
}; };
Fragment PlanLoopPartition(For op, size_t num_thread, int vectorize_size) { Fragment PlanLoopPartition(const For &op, size_t num_thread,
int vectorize_size) {
LoopPartitioner partitioner; LoopPartitioner partitioner;
return partitioner.Partition(op, num_thread, vectorize_size); return partitioner.Partition(op, num_thread, vectorize_size);
} }
Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) { Fragment PlanLoopPartition(const For &op, int vectorize_size,
const Range &thread_range) {
size_t num_thread = *as_const_int(thread_range->extent); size_t num_thread = *as_const_int(thread_range->extent);
LoopPartitioner partitioner; LoopPartitioner partitioner;
Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size); Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size);
...@@ -196,7 +200,7 @@ Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) { ...@@ -196,7 +200,7 @@ Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) {
For LoopPragmaUnroll(For stmt) { For LoopPragmaUnroll(For stmt) {
LoopPramaUnroller unroller; LoopPramaUnroller unroller;
For unrolled = Downcast<For>(unroller(stmt)); For unrolled = Downcast<For>(unroller(std::move(stmt)));
return unrolled; return unrolled;
} }
......
...@@ -35,11 +35,13 @@ namespace tl { ...@@ -35,11 +35,13 @@ namespace tl {
using namespace tir; using namespace tir;
For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
Fragment loop_layout); const Fragment &loop_layout);
Fragment PlanLoopPartition(For op, size_t num_thread, int vectorize_size); Fragment PlanLoopPartition(const For &op, size_t num_thread,
int vectorize_size);
Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range); Fragment PlanLoopPartition(const For &op, int vectorize_size,
const Range &thread_range);
For LoopPragmaUnroll(For stmt); For LoopPragmaUnroll(For stmt);
......
...@@ -110,7 +110,7 @@ private: ...@@ -110,7 +110,7 @@ private:
// TODO: perform some checks here // TODO: perform some checks here
} }
void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) { void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
if (!inner_for_) if (!inner_for_)
return; return;
auto extent_ptr = inner_for_->extent.as<IntImmNode>(); auto extent_ptr = inner_for_->extent.as<IntImmNode>();
...@@ -139,7 +139,7 @@ private: ...@@ -139,7 +139,7 @@ private:
// Generate strides if not existed // Generate strides if not existed
auto strides = buffer->strides; auto strides = buffer->strides;
if (buffer->strides.size() == 0) { if (buffer->strides.empty()) {
PrimExpr stride = 1; PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) { for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride); strides.push_back(stride);
...@@ -169,7 +169,7 @@ private: ...@@ -169,7 +169,7 @@ private:
const int vector_load_bits_max_ = 128; const int vector_load_bits_max_ = 128;
const ForNode *inner_for_; const ForNode *inner_for_{};
Map<Var, Range> iter_map_; Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false; bool has_nonlocal_memory_access_ = false;
int vector_size_ = 128; int vector_size_ = 128;
...@@ -180,7 +180,7 @@ private: ...@@ -180,7 +180,7 @@ private:
class VectorizeRewriter : public StmtExprMutator { class VectorizeRewriter : public StmtExprMutator {
public: public:
VectorizeRewriter(VectorizePlanResult plan) VectorizeRewriter(const VectorizePlanResult &plan)
: vector_size_(plan.vector_size), condition_(plan.condition), : vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic) {} dynamic_(plan.dynamic) {}
...@@ -220,7 +220,7 @@ private: ...@@ -220,7 +220,7 @@ private:
} }
} }
const ForNode *inner_for_; const ForNode *inner_for_{};
const int vector_size_; const int vector_size_;
const PrimExpr condition_; const PrimExpr condition_;
const bool dynamic_; const bool dynamic_;
...@@ -236,7 +236,8 @@ VectorizePlanResult GetVectorizePlanResult(const For &loop) { ...@@ -236,7 +236,8 @@ VectorizePlanResult GetVectorizePlanResult(const For &loop) {
return {vector_size, dynamic, condition}; return {vector_size, dynamic, condition};
} }
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, bool IndiceCanVectorize(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer) { int target_vectorized_size, arith::Analyzer *analyzer) {
ICHECK(target_vectorized_size >= 1); ICHECK(target_vectorized_size >= 1);
if (target_vectorized_size == 1) if (target_vectorized_size == 1)
......
...@@ -37,7 +37,8 @@ int GetVectorizeSize(const For &loop); ...@@ -37,7 +37,8 @@ int GetVectorizeSize(const For &loop);
For VectorizeLoop(const For &loop, int vectorize_hint = -1); For VectorizeLoop(const For &loop, int vectorize_hint = -1);
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, bool IndiceCanVectorize(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer); int target_vectorized_size, arith::Analyzer *analyzer);
} // namespace tl } // namespace tl
......