Unverified Commit 2d4b848f authored by Kurisu's avatar Kurisu Committed by GitHub
Browse files

[Fix] tilelang can now vectorize `B[i,j] = c[i] + A[i,j]` (#798)

* Fix bug 0905: vectorize with broadcasted value

* fix lint error

* [Refactor] Use `tvm::tir::UseVar` and use Vectorizer

* Add loop size check in vectorize planner

* fix lint error
parent fa4fd0b7
......@@ -24,17 +24,14 @@
#include "loop_vectorize.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <numeric>
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"
#include "tvm/tir/analysis.h"
#include "tvm/tir/var.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
namespace tvm {
namespace tl {
......@@ -56,15 +53,18 @@ public:
return vector_size_;
}
bool GetDynamic() { return dynamic_; }
PrimExpr GetCondition() { return condition_; }
private:
void VisitStmt_(const ForNode *node) final {
inner_for_ = node;
iter_map_.Set(node->loop_var, Range(node->min, node->extent));
auto extent_ptr = as_const_int(node->extent);
// Here I disable dynamic shape completely,
// In order to do it, the Planner should accept an analyzer with
// arithmetic info outside to prove the dividiblity of vector size
if (!extent_ptr) {
vector_size_ = 1;
return;
}
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
......@@ -113,76 +113,47 @@ private:
// Shrinks vector_size_ so that the access buffer[indices] can be vectorized
// along the loop variable of inner_for_. Does nothing when no enclosing loop
// has been recorded (inner_for_ is null).
//
// NOTE(review): this span appears to interleave pre-change and post-change
// lines of a diff (for example `strides` and `elem_offset` are each declared
// twice, and `if (!extent_ptr)` has no body of its own) -- confirm against the
// real file before relying on the statement order shown here.
void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
if (!inner_for_)
return;
auto extent_ptr = inner_for_->extent.as<IntImmNode>();
if (!extent_ptr)
// 1. Compute the raw element offset: sum over i of indices[i] * strides[i]
auto strides = buffer->strides;
// When the buffer carries no explicit strides, derive them from the shape:
// accumulate the running product from the last dimension backwards, then
// reverse so that strides[i] lines up with indices[i].
if (buffer->strides.empty()) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
}
PrimExpr elem_offset = 0;
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}
// 2. If the element offset is independent of loop_var, this access does not
// constrain the vector size, so ignore it
if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
return;
}
const DataType &access_type = buffer->dtype;
// i // 2, i % 8 can also be vectorized as factor 16
int max_vector_size = vector_load_bits_max_ / access_type.bits();
// so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
auto last_dim = buffer->shape.back();
auto mod_set = analyzer_.modular_set(last_dim);
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
// conditionally tail vectorize
if (buffer->shape.back().as<IntImmNode>()) {
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
// If gcd_base is equal to the last dimension,
// we should analyze the second-to-last dimension
// in relation to the last dimension.
if (gcd_base < Downcast<IntImm>(last_dim)->value) {
max_vector_size = gcd_base;
}
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
// Generate strides if the buffer does not provide them
auto strides = buffer->strides;
if (buffer->strides.empty()) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
}
// 3. Tighten the vectorization bound: the vector size must also divide the
// number of buffer->dtype elements that fit in vector_load_bits_max_ bits
vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ /
buffer->dtype.bits());
// Generate and check element offset expression
ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
PrimExpr elem_offset = 0;
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}
// Halve vector_size_ until IndiceCanVectorize accepts the element offset
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_,
&analyzer_)) {
vector_size_ /= 2;
}
} else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
// dynamic shape load: get the vectorization condition
dynamic_ = true;
PrimExpr offset = buffer.OffsetOf(indices).back();
condition_ = (FloorMod(offset, vector_size_) == 0);
// 4. Try to vectorize the buffer access: halve vector_size_ until
// IndiceCanVectorize accepts the element offset
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_, &analyzer_)) {
vector_size_ /= 2;
}
}
const int vector_load_bits_max_ = 128;
const ForNode *inner_for_{};
Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false;
int vector_size_ = 128;
// conditionally vectorize
bool dynamic_ = false;
PrimExpr condition_;
};
class VectorizeRewriter : public StmtExprMutator {
public:
VectorizeRewriter(const VectorizePlanResult &plan)
: vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic) {}
VectorizeRewriter(int vector_size) : vector_size_(vector_size) {}
private:
Stmt VisitStmt_(const ForNode *node) final {
......@@ -197,23 +168,19 @@ private:
ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
if (!dynamic_) { // check dynamic shape
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
} else {
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
} else {
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
} else {
return ret;
......@@ -222,18 +189,25 @@ private:
const ForNode *inner_for_{};
const int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
};
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
VectorizePlanResult GetVectorizePlanResult(const For &loop) {
VectorizePlanner planner;
int vector_size = planner.Plan(loop);
bool dynamic = planner.GetDynamic();
PrimExpr condition = planner.GetCondition();
return {vector_size, dynamic, condition};
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer) {
// 1. if var doesn't exist, it is independent
bool used_var = UsesVar(
expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as(var); });
if (!used_var) {
return true;
}
// 2. if \forall v_1, v_2, f(v_1) == f(v_2), f is independent with v
Var var_1("_t", var.dtype());
auto expr_1 = Substitute(expr, {{var, var_1}});
if (analyzer->CanProveEqual(expr, expr_1)) {
return true;
}
return false;
}
bool IndiceCanVectorize(const PrimExpr &expr, Var var,
......@@ -280,14 +254,13 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
}
// Vectorizes `loop` and returns the rewritten loop.
//
// When vectorize_hint <= 0 the vector size is planned from the loop itself
// rather than taken from the caller. A resulting vector size of 1 returns
// `loop` unchanged; otherwise the loop is passed through a VectorizeRewriter
// and the result is downcast to For.
//
// NOTE(review): this span appears to interleave pre-change and post-change
// lines of a diff -- `vectorize_hint` is assigned twice inside the branch and
// `rewriter` is declared twice -- confirm which statements belong to the
// current version of the file.
For VectorizeLoop(const For &loop, int vectorize_hint) {
VectorizePlanResult res{128, false, 0};
if (vectorize_hint <= 0) {
res = GetVectorizePlanResult(loop);
vectorize_hint = res.vector_size;
VectorizePlanner planner;
vectorize_hint = planner.Plan(loop);
}
// A vector size of 1 means there is nothing to vectorize.
if (vectorize_hint == 1)
return loop;
auto rewriter = VectorizeRewriter(res);
auto rewriter = VectorizeRewriter(vectorize_hint);
return Downcast<For>(rewriter(loop));
}
......
......@@ -37,6 +37,10 @@ int GetVectorizeSize(const For &loop);
For VectorizeLoop(const For &loop, int vectorize_hint = -1);
// Returns true if expr can be proven independent of var, i.e. the value of
// expr doesn't change when var changes
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer);
bool IndiceCanVectorize(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment