"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "754b0043cf4b45d63151b8c5621650b00e12f7d9"
Unverified Commit 2d4b848f authored by Kurisu's avatar Kurisu Committed by GitHub
Browse files

[Fix] tilelang can now vectorize `B[i,j] = c[i] + A[i,j]` (#798)

* Fix bug 0905: vectorize with broadcasted value

* fix lint error

* [Refactor] Use `tvm::tir::UseVar` and use Vectorizer

* Add loop size check in vectorize planner

* fix lint error
parent fa4fd0b7
...@@ -24,17 +24,14 @@ ...@@ -24,17 +24,14 @@
#include "loop_vectorize.h" #include "loop_vectorize.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <numeric>
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h" #include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h" #include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h" #include "common/loop_vectorization_utils.h"
#include "tvm/tir/analysis.h"
#include "tvm/tir/var.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -56,15 +53,18 @@ public: ...@@ -56,15 +53,18 @@ public:
return vector_size_; return vector_size_;
} }
bool GetDynamic() { return dynamic_; }
PrimExpr GetCondition() { return condition_; }
private: private:
void VisitStmt_(const ForNode *node) final { void VisitStmt_(const ForNode *node) final {
inner_for_ = node; inner_for_ = node;
iter_map_.Set(node->loop_var, Range(node->min, node->extent)); auto extent_ptr = as_const_int(node->extent);
// Dynamic shape is completely disabled here.
// To support it, the Planner should accept an analyzer with
// arithmetic info from outside to prove the divisibility of the vector size
if (!extent_ptr) {
vector_size_ = 1;
return;
}
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
arith::IRVisitorWithAnalyzer::VisitStmt_(node); arith::IRVisitorWithAnalyzer::VisitStmt_(node);
} }
...@@ -113,31 +113,7 @@ private: ...@@ -113,31 +113,7 @@ private:
void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) { void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
if (!inner_for_) if (!inner_for_)
return; return;
auto extent_ptr = inner_for_->extent.as<IntImmNode>(); // 1. Compute raw element offset
if (!extent_ptr)
return;
const DataType &access_type = buffer->dtype;
// i // 2, i % 8 can also be vectorized as factor 16
int max_vector_size = vector_load_bits_max_ / access_type.bits();
// so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
auto last_dim = buffer->shape.back();
auto mod_set = analyzer_.modular_set(last_dim);
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
// conditionally tail vectorize
if (buffer->shape.back().as<IntImmNode>()) {
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
// If gcd_base is equal to the last dimension,
// we should analyze the second-to-last dimension
// in relation to the last dimension.
if (gcd_base < Downcast<IntImm>(last_dim)->value) {
max_vector_size = gcd_base;
}
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
// Generate strides if not existed
auto strides = buffer->strides; auto strides = buffer->strides;
if (buffer->strides.empty()) { if (buffer->strides.empty()) {
PrimExpr stride = 1; PrimExpr stride = 1;
...@@ -147,42 +123,37 @@ private: ...@@ -147,42 +123,37 @@ private:
} }
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()}; strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
} }
// Generate and check element offset expression
ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
PrimExpr elem_offset = 0; PrimExpr elem_offset = 0;
for (int i = 0; i < indices.size(); ++i) { for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i]; elem_offset += indices[i] * strides[i];
} }
// 2. If element offset is independent of loop_var, ignore it
if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
return;
}
// 3. Tighten the vectorize bound
vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ /
buffer->dtype.bits());
// 4. Try to vectorize buffer load
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_, inner_for_->extent, vector_size_, &analyzer_)) {
&analyzer_)) {
vector_size_ /= 2; vector_size_ /= 2;
} }
} else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
// dynamic shape load: get the vectorization condition
dynamic_ = true;
PrimExpr offset = buffer.OffsetOf(indices).back();
condition_ = (FloorMod(offset, vector_size_) == 0);
}
} }
const int vector_load_bits_max_ = 128; const int vector_load_bits_max_ = 128;
const ForNode *inner_for_{}; const ForNode *inner_for_{};
Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false; bool has_nonlocal_memory_access_ = false;
int vector_size_ = 128; int vector_size_ = 128;
// conditionally vectorize
bool dynamic_ = false;
PrimExpr condition_;
}; };
class VectorizeRewriter : public StmtExprMutator { class VectorizeRewriter : public StmtExprMutator {
public: public:
VectorizeRewriter(const VectorizePlanResult &plan) VectorizeRewriter(int vector_size) : vector_size_(vector_size) {}
: vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic) {}
private: private:
Stmt VisitStmt_(const ForNode *node) final { Stmt VisitStmt_(const ForNode *node) final {
...@@ -197,7 +168,6 @@ private: ...@@ -197,7 +168,6 @@ private:
ICHECK(extent % vector_size_ == 0) ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_; << "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min)); ICHECK(is_zero(fnode->min));
if (!dynamic_) { // check dynamic shape
if (extent == vector_size_) { if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized; fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode; return fnode;
...@@ -212,9 +182,6 @@ private: ...@@ -212,9 +182,6 @@ private:
fnode->thread_binding, fnode->annotations, fnode->span); fnode->thread_binding, fnode->annotations, fnode->span);
return body; return body;
} }
} else {
return fnode;
}
} else { } else {
return ret; return ret;
} }
...@@ -222,18 +189,25 @@ private: ...@@ -222,18 +189,25 @@ private:
const ForNode *inner_for_{}; const ForNode *inner_for_{};
const int vector_size_; const int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
}; };
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); } int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
VectorizePlanResult GetVectorizePlanResult(const For &loop) { bool CanProveIndependent(const PrimExpr &expr, Var var,
VectorizePlanner planner; arith::Analyzer *analyzer) {
int vector_size = planner.Plan(loop); // 1. if var doesn't exist, it is independent
bool dynamic = planner.GetDynamic(); bool used_var = UsesVar(
PrimExpr condition = planner.GetCondition(); expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as(var); });
return {vector_size, dynamic, condition}; if (!used_var) {
return true;
}
// 2. if \forall v_1, v_2, f(v_1) == f(v_2), f is independent of v
Var var_1("_t", var.dtype());
auto expr_1 = Substitute(expr, {{var, var_1}});
if (analyzer->CanProveEqual(expr, expr_1)) {
return true;
}
return false;
} }
bool IndiceCanVectorize(const PrimExpr &expr, Var var, bool IndiceCanVectorize(const PrimExpr &expr, Var var,
...@@ -280,14 +254,13 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, ...@@ -280,14 +254,13 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
} }
For VectorizeLoop(const For &loop, int vectorize_hint) { For VectorizeLoop(const For &loop, int vectorize_hint) {
VectorizePlanResult res{128, false, 0};
if (vectorize_hint <= 0) { if (vectorize_hint <= 0) {
res = GetVectorizePlanResult(loop); VectorizePlanner planner;
vectorize_hint = res.vector_size; vectorize_hint = planner.Plan(loop);
} }
if (vectorize_hint == 1) if (vectorize_hint == 1)
return loop; return loop;
auto rewriter = VectorizeRewriter(res); auto rewriter = VectorizeRewriter(vectorize_hint);
return Downcast<For>(rewriter(loop)); return Downcast<For>(rewriter(loop));
} }
......
...@@ -37,6 +37,10 @@ int GetVectorizeSize(const For &loop); ...@@ -37,6 +37,10 @@ int GetVectorizeSize(const For &loop);
For VectorizeLoop(const For &loop, int vectorize_hint = -1); For VectorizeLoop(const For &loop, int vectorize_hint = -1);
// Can prove expr is independent of var, i.e. the value of expr doesn't change
// when var changes
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer);
bool IndiceCanVectorize(const PrimExpr &expr, Var var, bool IndiceCanVectorize(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size, const PrimExpr &iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer); int target_vectorized_size, arith::Analyzer *analyzer);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment