Unverified Commit 9f7bac4c authored by Lei Wang, committed by GitHub

[Refactor] Back up the Analyzer to get the appropriate arith information (#1311)

* [Refactor] Update Vectorization Functions to Accept Analyzer Parameter

- Modified `VectorizeLoop` and related functions to accept an `arith::Analyzer` parameter, enhancing their capability to perform analysis during vectorization.
- Updated multiple instances in `copy.cc`, `fill.cc`, `parallel.cc`, and layout inference files to utilize the new analyzer parameter for improved performance and correctness.
- Ensured consistency across vectorization logic by integrating the analyzer into existing workflows, facilitating better optimization opportunities.
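
For illustration, a minimal sketch of the call pattern this enables, assuming a caller that already owns `thread_var`, `par_op`, and `loop_layout` as in the lowering code below (the 128-thread bound is a made-up example, not from this commit):

// Sketch only: one analyzer is threaded through partitioning and
// vectorization, so facts proven earlier (e.g. thread bounds) stay
// visible to the vectorizer instead of being rediscovered from scratch.
arith::Analyzer analyzer;
analyzer.Bind(thread_var->var, Range::FromMinExtent(0, 128));  // assumed bound
auto thread_loop =
    PartitionLoop(par_op->GetRoot(), thread_var, &analyzer, loop_layout);
auto vectorized = VectorizeLoop(thread_loop, &analyzer);  // new overload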

* [Fix] Corrected PostOrderVisit call in loop_vectorize.cc

- Updated the PostOrderVisit function to analyze the body of the loop node instead of the node itself, ensuring proper handling of nested loops during vectorization analysis.
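
In essence, the corrected logic looks like this condensed sketch of the hunk in loop_vectorize.cc below (`tir::PostOrderVisit` is the stock TVM traversal helper):

// Walk the loop *body*, not the loop node itself: visiting `node` would
// match the outer ForNode immediately and flag every loop as nested,
// so vectorization would never be planned.
bool contains_nested_for = false;
tir::PostOrderVisit(node->body, [&](const ObjectRef &obj) {
  if (obj.as<tir::ForNode>()) contains_nested_for = true;
});
if (!contains_nested_for) {
  // `node` is an innermost loop: safe to plan a vector width for it.
}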

* fix

* lint fix

* fix
parent 721baedb
-Subproject commit bc31e7ad9f9fafd7659dfabafe359fd55a0ffc1e
+Subproject commit cd2b2b6013d155b5822300b0a0740fa65320dd9e
@@ -852,7 +852,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
   auto par_op = ParallelOp(transformed_loop);
   if (is_cpu_target) {
-    vectorized_thread_loop = VectorizeLoop(transformed_loop);
+    vectorized_thread_loop = VectorizeLoop(transformed_loop, analyzer);
   } else {
     std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
                                       InferLevel::kFree};
@@ -865,7 +865,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
     auto thread_var = T.thread_var;
     auto thread_loop =
         PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
-    vectorized_thread_loop = VectorizeLoop(thread_loop);
+    vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
   }
   if (par_op->GetPredicate(T.thread_var).defined()) {
...
@@ -207,7 +207,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
                          InferLevel::kFree);
     auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                      par_op->GetLoopLayout());
-    auto vectorized_thread_loop = VectorizeLoop(thread_loop);
+    auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
     if (par_op->GetPredicate(T.thread_var).defined()) {
       return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                         vectorized_thread_loop);
@@ -215,7 +215,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     return vectorized_thread_loop;
   } else if (dst.scope() == "local") {
     auto init_loop = MakeSIMTLoop(analyzer);
-    auto vectorized_thread_loop = VectorizeLoop(init_loop);
+    auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer);
     return vectorized_thread_loop;
   } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
              dst.scope() == "global") {
@@ -225,7 +225,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
                          InferLevel::kFree);
     auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                      par_op->GetLoopLayout());
-    auto vectorized_thread_loop = VectorizeLoop(thread_loop);
+    auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
     if (par_op->GetPredicate(T.thread_var).defined()) {
       return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                         vectorized_thread_loop);
...
@@ -452,8 +452,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
   // As the pass will do post processing to the layout
   auto maybe_remapped_root_ =
       IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
-  int vector_size = GetVectorizeSize(maybe_remapped_root_);
+  int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer);
   DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';
   PrimExpr loop_total_size = 1;
...
@@ -12,6 +12,7 @@
 #include <tvm/tir/utils.h>

 #include <algorithm>
+#include <memory>
 #include <queue>

 #include "../layout/utils.h"
@@ -85,6 +86,7 @@ public:
       auto &next = infer_list_[cur_infer_id];
       auto iter_var = thread_var_vec_[cur_infer_id];
       auto thread_bounds = thread_bounds_vec_[cur_infer_id];
+      arith::Analyzer *cur_analyzer = analyzer_vec_[cur_infer_id].get();
       auto buffer_oob = buffer_oob_vec_[cur_infer_id];
       // Double-check that 'next' is valid
       ICHECK(next.defined()) << "infer_list_[" << cur_infer_id
@@ -108,7 +110,7 @@ public:
       // Run InferLayout
       auto updates =
           next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
-                                            &analyzer_, buffer_oob},
+                                            cur_analyzer, buffer_oob},
                             level);
       // Process the returned updates
       for (const auto &[buffer, layout] : updates) {
@@ -266,6 +268,9 @@ public:
     ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
         << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
            "length.";
+    ICHECK_EQ(analyzer_vec_.size(), infer_list_.size())
+        << "Size mismatch: analyzer_vec_ and infer_list_ must match in "
+           "length.";
     ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size())
         << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
            "length.";
@@ -452,6 +457,7 @@ private:
     } else {
       thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
     }
+    analyzer_vec_.push_back(analyzer_.Clone());
     // Compute buffer oob for each buffer in the op
     if (const auto *copy = p.as<CopyNode>()) {
@@ -542,6 +548,7 @@ private:
       } else {
         thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
       }
+      analyzer_vec_.push_back(analyzer_.Clone());
       buffer_oob_vec_.push_back(false);
     } else {
       IRVisitorWithAnalyzer::VisitStmt(op->body);
@@ -683,6 +690,7 @@ private:
                        IterVarType::kDataPar);
   std::vector<IterVar> thread_var_vec_;
   std::vector<Range> thread_bounds_vec_;
+  std::vector<std::unique_ptr<arith::Analyzer>> analyzer_vec_;
   std::vector<bool> buffer_oob_vec_;
   Target target_;
   LayoutMap annotated_layout_map_;
@@ -1024,7 +1032,7 @@ private:
     });
     if ((has_non_local || has_cast_operations) && !has_reducer) {
-      for_node = VectorizeLoop(for_node);
+      for_node = VectorizeLoop(for_node, analyzer_);
     }
     if (result_.predicate_map.count(root) && parallel_loop) {
...
@@ -73,7 +73,7 @@ private:
     // Change the loop kind from vectorized to serial
     for_node.CopyOnWrite()->kind = ForKind::kSerial;
     // Apply vectorization transformation to the loop
-    return VectorizeLoop(for_node);
+    return VectorizeLoop(for_node, analyzer_);
   }
 };
...
@@ -45,7 +45,7 @@ struct VectorizePlanResult {
   PrimExpr condition;
 };

-class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer {
+class VectorizeFindGlobalAccess : public StmtExprVisitor {
 public:
   VectorizeFindGlobalAccess() = default;
@@ -60,19 +60,20 @@ private:
   void VisitStmt_(const BufferStoreNode *node) final {
     if (node->buffer.scope() == "global")
       has_global_access_ = true;
-    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
+    return StmtExprVisitor::VisitStmt_(node);
   }

   void VisitExpr_(const BufferLoadNode *node) final {
     if (node->buffer.scope() == "global")
       has_global_access_ = true;
-    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
+    return StmtExprVisitor::VisitExpr_(node);
   }
 };

-class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
+class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
 public:
-  VectorizePlanner() = default;
+  explicit VectorizePlanner(arith::Analyzer *analyzer)
+      : arith::IRMutatorWithAnalyzer(analyzer) {}

   int Plan(const For &node) {
     tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
@@ -92,21 +93,31 @@ public:
   }

 private:
-  void VisitStmt_(const ForNode *node) final {
+  Stmt VisitStmt_(const ForNode *node) final {
     inner_for_ = node;
-    auto extent_ptr = as_const_int(analyzer_.Simplify(node->extent));
+    bool contains_nested_for = false;
+    // Must analysis vectorization on the innermost loop
+    PostOrderVisit(Downcast<Stmt>(node->body), [&](const ObjectRef &obj) {
+      if (obj.as<ForNode>()) {
+        contains_nested_for = true;
+      }
+    });
+    if (!contains_nested_for) {
+      auto extent_ptr = as_const_int(analyzer_->Simplify(node->extent));
       // Here I disable dynamic shape completely,
       // In order to do it, the Planner should accept an analyzer with
       // arithmetic info outside to prove the dividiblity of vector size
       if (!extent_ptr) {
         vector_size_ = 1;
-        return;
+        return ffi::GetRef<Stmt>(node);
       }
       vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
-    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
+    }
+    return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
   }

-  void VisitExpr_(const BufferLoadNode *node) final {
+  PrimExpr VisitExpr_(const BufferLoadNode *node) final {
     if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
         node->buffer.scope() == "shared.dyn")
       has_nonlocal_memory_access_ = true;
@@ -115,43 +126,44 @@ private:
       // constant buffer that tl hack to use as local register.
       auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
       if (boundary_check && boundary_check->value == 1) {
-        return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
+        return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
       }
     }
     UpdateVectorSize(node->indices, node->buffer);
+    return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
   }

-  void VisitStmt_(const BufferStoreNode *node) final {
+  Stmt VisitStmt_(const BufferStoreNode *node) final {
     if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
         node->buffer.scope() == "shared.dyn")
       has_nonlocal_memory_access_ = true;
     UpdateVectorSize(node->indices, node->buffer);
-    return arith::IRVisitorWithAnalyzer::VisitExpr(node->value);
+    return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
   }

-  void VisitStmt_(const IfThenElseNode *node) final {
+  Stmt VisitStmt_(const IfThenElseNode *node) final {
     CheckConditionVectorized(node->condition);
-    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
+    return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
   }

-  void VisitExpr_(const CallNode *node) final {
+  PrimExpr VisitExpr_(const CallNode *node) final {
     if (node->op == builtin::if_then_else()) {
       CheckConditionVectorized(node->args[0]);
     } else if (node->op == builtin::call_extern()) {
       // do not vectorize extern calls
       vector_size_ = 1;
     }
-    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
+    return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
   }

   void CheckConditionVectorized(const PrimExpr &cond) {
     // TODO: perform some checks here
   }

-  void VisitExpr_(const CastNode *node) final {
+  PrimExpr VisitExpr_(const CastNode *node) final {
     vector_size_ = arith::ZeroAwareGCD(
         vector_load_bits_max_ / node->dtype.bits(), vector_size_);
-    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
+    return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
   }

   void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
@@ -171,19 +183,16 @@ private:
     for (int i = 0; i < indices.size(); ++i) {
       elem_offset += indices[i] * strides[i];
     }
     // 2. If element offset is independent with loop_var, ignore it
-    if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
+    if (CanProveIndependent(elem_offset, inner_for_->loop_var, analyzer_)) {
       return;
     }
     // 3. Tight vectorize bound
     vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ /
                                                          buffer->dtype.bits());
     // 4. Try to vectorize buffer load
     while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
-                               inner_for_->extent, vector_size_, &analyzer_)) {
+                               inner_for_->extent, vector_size_, analyzer_)) {
       vector_size_ /= 2;
     }
   }
@@ -235,7 +244,14 @@ private:
   const int vector_size_;
 };

-int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
+int GetVectorizeSize(const For &loop) {
+  arith::Analyzer analyzer;
+  return VectorizePlanner(&analyzer).Plan(loop);
+}
+
+int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer) {
+  return VectorizePlanner(analyzer).Plan(loop);
+}

 bool CanProveIndependent(const PrimExpr &expr, Var var,
                          arith::Analyzer *analyzer) {
@@ -274,10 +290,10 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
   if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter),
                                0))
     return false;
+  auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}}));
   // The base offset must be divisible
-  if (!analyzer->CanProveEqual(
-          FloorMod(Substitute(expr, {{var, zero}}), target_size_for_expr), 0)) {
+  if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr),
+                               zero)) {
     return false;
   }
@@ -308,7 +324,20 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
 For VectorizeLoop(const For &loop, int vectorize_hint) {
   if (vectorize_hint <= 0) {
-    VectorizePlanner planner;
+    arith::Analyzer analyzer;
+    VectorizePlanner planner(&analyzer);
+    vectorize_hint = planner.Plan(loop);
+  }
+  if (vectorize_hint == 1)
+    return loop;
+  auto rewriter = VectorizeRewriter(vectorize_hint);
+  return Downcast<For>(rewriter(loop));
+}
+
+For VectorizeLoop(const For &loop, arith::Analyzer *analyzer,
+                  int vectorize_hint) {
+  if (vectorize_hint <= 0) {
+    VectorizePlanner planner(analyzer);
     vectorize_hint = planner.Plan(loop);
   }
   if (vectorize_hint == 1)
...
@@ -35,8 +35,13 @@ using namespace tir;

 int GetVectorizeSize(const For &loop);
+int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer);

 For VectorizeLoop(const For &loop, int vectorize_hint = -1);
+For VectorizeLoop(const For &loop, arith::Analyzer *analyzer,
+                  int vectorize_hint = -1);

 // Can prove expr is independent with var, i.e. the value of expr doesn't change
 // when var changes
 bool CanProveIndependent(const PrimExpr &expr, Var var,
...
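
For downstream code, a hedged usage sketch of the two overloads declared above (`loop` and the symbolic extent `n` are hypothetical; in-tree callers simply pass the analyzer they already carry):

// Seed the analyzer with outside knowledge, e.g. a bound on a symbolic
// extent, so the planner can prove divisibility by candidate vector widths.
arith::Analyzer analyzer;
tir::Var n("n", DataType::Int(32));
analyzer.Bind(n, Range::FromMinExtent(0, 1024));      // assumed constraint
int vector_size = GetVectorizeSize(loop, &analyzer);  // planning only
For vectorized = VectorizeLoop(loop, &analyzer);      // plan + rewrite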