Unverified Commit 9f7bac4c authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Back up Analyzer to get the appropriate arith information (#1311)

* [Refactor] Update Vectorization Functions to Accept Analyzer Parameter

- Modified `VectorizeLoop` and related functions to accept an `arith::Analyzer` parameter, enhancing their capability to perform analysis during vectorization.
- Updated multiple instances in `copy.cc`, `fill.cc`, `parallel.cc`, and layout inference files to utilize the new analyzer parameter for improved performance and correctness.
- Ensured consistency across vectorization logic by integrating the analyzer into existing workflows, facilitating better optimization opportunities.

* [Fix] Corrected PostOrderVisit call in loop_vectorize.cc

- Updated the PostOrderVisit function to analyze the body of the loop node instead of the node itself, ensuring proper handling of nested loops during vectorization analysis.

* fix

* lint fix

* fix
parent 721baedb
Subproject commit bc31e7ad9f9fafd7659dfabafe359fd55a0ffc1e
Subproject commit cd2b2b6013d155b5822300b0a0740fa65320dd9e
......@@ -852,7 +852,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
auto par_op = ParallelOp(transformed_loop);
if (is_cpu_target) {
vectorized_thread_loop = VectorizeLoop(transformed_loop);
vectorized_thread_loop = VectorizeLoop(transformed_loop, analyzer);
} else {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
......@@ -865,7 +865,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
auto thread_var = T.thread_var;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
vectorized_thread_loop = VectorizeLoop(thread_loop);
vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
}
if (par_op->GetPredicate(T.thread_var).defined()) {
......
......@@ -207,7 +207,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
......@@ -215,7 +215,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
} else if (dst.scope() == "local") {
auto init_loop = MakeSIMTLoop(analyzer);
auto vectorized_thread_loop = VectorizeLoop(init_loop);
auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer);
return vectorized_thread_loop;
} else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
dst.scope() == "global") {
......@@ -225,7 +225,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
......
......@@ -452,8 +452,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
// As the pass will do post processing to the layout
auto maybe_remapped_root_ =
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_);
int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer);
DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';
PrimExpr loop_total_size = 1;
......
......@@ -12,6 +12,7 @@
#include <tvm/tir/utils.h>
#include <algorithm>
#include <memory>
#include <queue>
#include "../layout/utils.h"
......@@ -85,6 +86,7 @@ public:
auto &next = infer_list_[cur_infer_id];
auto iter_var = thread_var_vec_[cur_infer_id];
auto thread_bounds = thread_bounds_vec_[cur_infer_id];
arith::Analyzer *cur_analyzer = analyzer_vec_[cur_infer_id].get();
auto buffer_oob = buffer_oob_vec_[cur_infer_id];
// Double-check that 'next' is valid
ICHECK(next.defined()) << "infer_list_[" << cur_infer_id
......@@ -108,7 +110,7 @@ public:
// Run InferLayout
auto updates =
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
&analyzer_, buffer_oob},
cur_analyzer, buffer_oob},
level);
// Process the returned updates
for (const auto &[buffer, layout] : updates) {
......@@ -266,6 +268,9 @@ public:
ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
<< "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
"length.";
ICHECK_EQ(analyzer_vec_.size(), infer_list_.size())
<< "Size mismatch: analyzer_vec_ and infer_list_ must match in "
"length.";
ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size())
<< "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
"length.";
......@@ -452,6 +457,7 @@ private:
} else {
thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
}
analyzer_vec_.push_back(analyzer_.Clone());
// Compute buffer oob for each buffer in the op
if (const auto *copy = p.as<CopyNode>()) {
......@@ -542,6 +548,7 @@ private:
} else {
thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
}
analyzer_vec_.push_back(analyzer_.Clone());
buffer_oob_vec_.push_back(false);
} else {
IRVisitorWithAnalyzer::VisitStmt(op->body);
......@@ -683,6 +690,7 @@ private:
IterVarType::kDataPar);
std::vector<IterVar> thread_var_vec_;
std::vector<Range> thread_bounds_vec_;
std::vector<std::unique_ptr<arith::Analyzer>> analyzer_vec_;
std::vector<bool> buffer_oob_vec_;
Target target_;
LayoutMap annotated_layout_map_;
......@@ -1024,7 +1032,7 @@ private:
});
if ((has_non_local || has_cast_operations) && !has_reducer) {
for_node = VectorizeLoop(for_node);
for_node = VectorizeLoop(for_node, analyzer_);
}
if (result_.predicate_map.count(root) && parallel_loop) {
......
......@@ -73,7 +73,7 @@ private:
// Change the loop kind from vectorized to serial
for_node.CopyOnWrite()->kind = ForKind::kSerial;
// Apply vectorization transformation to the loop
return VectorizeLoop(for_node);
return VectorizeLoop(for_node, analyzer_);
}
};
......
......@@ -45,7 +45,7 @@ struct VectorizePlanResult {
PrimExpr condition;
};
class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer {
class VectorizeFindGlobalAccess : public StmtExprVisitor {
public:
VectorizeFindGlobalAccess() = default;
......@@ -60,19 +60,20 @@ private:
void VisitStmt_(const BufferStoreNode *node) final {
if (node->buffer.scope() == "global")
has_global_access_ = true;
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
return StmtExprVisitor::VisitStmt_(node);
}
void VisitExpr_(const BufferLoadNode *node) final {
if (node->buffer.scope() == "global")
has_global_access_ = true;
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
return StmtExprVisitor::VisitExpr_(node);
}
};
class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
public:
VectorizePlanner() = default;
explicit VectorizePlanner(arith::Analyzer *analyzer)
: arith::IRMutatorWithAnalyzer(analyzer) {}
int Plan(const For &node) {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
......@@ -92,21 +93,31 @@ public:
}
private:
void VisitStmt_(const ForNode *node) final {
Stmt VisitStmt_(const ForNode *node) final {
inner_for_ = node;
auto extent_ptr = as_const_int(analyzer_.Simplify(node->extent));
// Here I disable dynamic shape completely,
// In order to do it, the Planner should accept an analyzer with
// arithmetic info outside to prove the divisibility of vector size
if (!extent_ptr) {
vector_size_ = 1;
return;
bool contains_nested_for = false;
// Must analyze vectorization on the innermost loop
PostOrderVisit(Downcast<Stmt>(node->body), [&](const ObjectRef &obj) {
if (obj.as<ForNode>()) {
contains_nested_for = true;
}
});
if (!contains_nested_for) {
auto extent_ptr = as_const_int(analyzer_->Simplify(node->extent));
// Here I disable dynamic shape completely,
// In order to do it, the Planner should accept an analyzer with
// arithmetic info outside to prove the divisibility of vector size
if (!extent_ptr) {
vector_size_ = 1;
return ffi::GetRef<Stmt>(node);
}
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
}
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
arith::IRVisitorWithAnalyzer::VisitStmt_(node);
return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
}
void VisitExpr_(const BufferLoadNode *node) final {
PrimExpr VisitExpr_(const BufferLoadNode *node) final {
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
node->buffer.scope() == "shared.dyn")
has_nonlocal_memory_access_ = true;
......@@ -115,43 +126,44 @@ private:
// constant buffer that tl hack to use as local register.
auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
if (boundary_check && boundary_check->value == 1) {
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}
}
UpdateVectorSize(node->indices, node->buffer);
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}
void VisitStmt_(const BufferStoreNode *node) final {
Stmt VisitStmt_(const BufferStoreNode *node) final {
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
node->buffer.scope() == "shared.dyn")
has_nonlocal_memory_access_ = true;
UpdateVectorSize(node->indices, node->buffer);
return arith::IRVisitorWithAnalyzer::VisitExpr(node->value);
return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
}
void VisitStmt_(const IfThenElseNode *node) final {
Stmt VisitStmt_(const IfThenElseNode *node) final {
CheckConditionVectorized(node->condition);
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
}
void VisitExpr_(const CallNode *node) final {
PrimExpr VisitExpr_(const CallNode *node) final {
if (node->op == builtin::if_then_else()) {
CheckConditionVectorized(node->args[0]);
} else if (node->op == builtin::call_extern()) {
// do not vectorize extern calls
vector_size_ = 1;
}
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}
void CheckConditionVectorized(const PrimExpr &cond) {
// TODO: perform some checks here
}
void VisitExpr_(const CastNode *node) final {
PrimExpr VisitExpr_(const CastNode *node) final {
vector_size_ = arith::ZeroAwareGCD(
vector_load_bits_max_ / node->dtype.bits(), vector_size_);
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}
void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
......@@ -171,19 +183,16 @@ private:
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}
// 2. If element offset is independent with loop_var, ignore it
if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
if (CanProveIndependent(elem_offset, inner_for_->loop_var, analyzer_)) {
return;
}
// 3. Tight vectorize bound
vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ /
buffer->dtype.bits());
// 4. Try to vectorize buffer load
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_, &analyzer_)) {
inner_for_->extent, vector_size_, analyzer_)) {
vector_size_ /= 2;
}
}
......@@ -235,7 +244,14 @@ private:
const int vector_size_;
};
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
int GetVectorizeSize(const For &loop) {
arith::Analyzer analyzer;
return VectorizePlanner(&analyzer).Plan(loop);
}
int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer) {
return VectorizePlanner(analyzer).Plan(loop);
}
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer) {
......@@ -274,10 +290,10 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter),
0))
return false;
auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}}));
// The base offset must be divisible
if (!analyzer->CanProveEqual(
FloorMod(Substitute(expr, {{var, zero}}), target_size_for_expr), 0)) {
if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr),
zero)) {
return false;
}
......@@ -308,7 +324,20 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
For VectorizeLoop(const For &loop, int vectorize_hint) {
if (vectorize_hint <= 0) {
VectorizePlanner planner;
arith::Analyzer analyzer;
VectorizePlanner planner(&analyzer);
vectorize_hint = planner.Plan(loop);
}
if (vectorize_hint == 1)
return loop;
auto rewriter = VectorizeRewriter(vectorize_hint);
return Downcast<For>(rewriter(loop));
}
For VectorizeLoop(const For &loop, arith::Analyzer *analyzer,
int vectorize_hint) {
if (vectorize_hint <= 0) {
VectorizePlanner planner(analyzer);
vectorize_hint = planner.Plan(loop);
}
if (vectorize_hint == 1)
......
......@@ -35,8 +35,13 @@ using namespace tir;
int GetVectorizeSize(const For &loop);
int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer);
For VectorizeLoop(const For &loop, int vectorize_hint = -1);
For VectorizeLoop(const For &loop, arith::Analyzer *analyzer,
int vectorize_hint = -1);
// Can prove expr is independent with var, i.e. the value of expr doesn't change
// when var changes
bool CanProveIndependent(const PrimExpr &expr, Var var,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment