Unverified Commit a7a29c09 authored by yyttt6, committed by GitHub

[Bugfix]:Fix atomic add auto vectorize negative optimization (#765)

* [Bugfix]:Fix atomic add auto vectorize negative optimization

* fixbug

* format

* fix bug
parent 2af3f22e
 /*!
  * \file tl/op/atomic_add.cc
  *
- * Define elment-wise operators.
+ * Define element-wise operators.
  */
 #include "./atomic_add.h"
...
@@ -368,10 +368,8 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   Range thread_bounds = T.thread_bounds;
   auto thread_loop =
       PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
-  // TODO(@dyq): buggy implementation, need to fix
-  // vectorized_thread_loop = VectorizeAtomicAdd(
-  //     thread_loop, thread_var, thread_bounds, GetArchInt(target));
-  auto vectorized_thread_loop = VectorizeLoop(thread_loop);
+  auto vectorized_thread_loop = VectorizeAtomicAdd(
+      thread_loop, thread_var, thread_bounds, GetArchInt(target));
   if (par_op->GetPredicate(T.thread_var).defined()) {
     return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
...
@@ -125,7 +125,7 @@ private:
       // dynamic shape load: get the vectorization condition
       dynamic_ = true;
       PrimExpr offset = buffer.OffsetOf(indices).back();
-      condition_ = (FloorMod(offset, vector_size_) == 0);
+      condition_ = (truncmod(offset, vector_size_) == 0);
     }
   }
@@ -141,9 +141,17 @@ private:
 class AtomicAddVectorizeRewriter : public StmtExprMutator {
 public:
-  AtomicAddVectorizeRewriter(AtomicAddVectorizePlanResult plan)
+  AtomicAddVectorizeRewriter(AtomicAddVectorizePlanResult plan, Var thread_var,
+                             PrimExpr by_var, PrimExpr bx_var,
+                             Range thread_bounds, int stride_y, int stride_x)
       : vector_size_(plan.vector_size), condition_(plan.condition),
-        dynamic_(plan.dynamic) {}
+        dynamic_(plan.dynamic), tx_var_(thread_var), by_var_(by_var),
+        bx_var_(bx_var), stride_y_(stride_y), stride_x_(stride_x) {
+    const int64_t *tx_ext = as_const_int(thread_bounds->extent);
+    ICHECK(tx_ext)
+        << "thread_bounds->extent must be a constant for vectorization.";
+    extent_tx_ = static_cast<int>(*tx_ext);
+  }
 
 private:
   /**
@@ -174,10 +182,10 @@
    */
   Stmt VisitStmt_(const ForNode *node) final {
     inner_for_ = node;
+    iter_var_ = Var(node->loop_var->name_hint + "_outer");
     auto ret = StmtExprMutator::VisitStmt_(node);
     if (inner_for_ == node) { // rewrite the innermost loop
       For fnode = ret.as<For>().value();
-      auto old_var = fnode->loop_var;
       auto extent_ptr = as_const_int(fnode->extent);
       ICHECK(extent_ptr) << fnode->extent;
       int extent = *extent_ptr;
@@ -185,23 +193,10 @@
           << "extent: " << extent << " vector_size_: " << vector_size_;
       ICHECK(is_zero(fnode->min));
       if (!dynamic_) {
-        Var tx_var;
-        PostOrderVisit(fnode->body, [&tx_var](const ObjectRef &node) {
-          if (const VarNode *var = node.as<VarNode>()) {
-            if (var->name_hint == "tx") {
-              tx_var = GetRef<Var>(var);
-            }
-          }
-        });
-        ICHECK(tx_var.defined()) << "Failed to find tx var";
-        Var outer_var = Var(old_var->name_hint + "_outer");
         Map<Var, PrimExpr> vmap;
-        // Scale thread index (tx) and loop variable by vector_size to map each
-        // new iteration to a vectorized chunk
-        vmap.Set(tx_var, tx_var * vector_size_);
-        vmap.Set(fnode->loop_var, outer_var * vector_size_);
+        vmap.Set(fnode->loop_var, iter_var_);
         Stmt body = Substitute(fnode->body, vmap);
-        return For(outer_var, 0, extent / vector_size_, fnode->kind, body,
+        return For(iter_var_, 0, extent / vector_size_, fnode->kind, body,
                    fnode->thread_binding, fnode->annotations, fnode->span);
       }
     }
@@ -209,24 +204,80 @@
   }
 
   PrimExpr VisitExpr_(const CallNode *node) final {
+    if (dynamic_) {
+      return StmtExprMutator::VisitExpr_(node);
+    }
     if (vector_size_ == 2 || vector_size_ == 4) {
       if (node->op == builtin::call_extern() && node->args.size() >= 2) {
         if (const auto *func_name = node->args[0].as<StringImmNode>()) {
           if (func_name->value == "AtomicAdd") {
-            PrimExpr value_node = node->args[2];
-            Call address_of_value = tvm::tir::Call(
-                DataType::Handle(), builtin::address_of(), {value_node});
+            // Matrix[by * stride_y + i / (stride_x / (tx_txtent *
+            // vector_size_)) + tx_var_ / (stride_x / vector_size_),
+            // bx * stride_x + (i % (stride_x / (tx_extent *
+            // vector_size_)) * (tx_extent * vector_size_) + (tx_var_ %
+            // (stride / vector_size_)) * vector_size_]
+            const CallNode *addr_call = node->args[1].as<CallNode>();
+            if (!addr_call || addr_call->op != builtin::address_of() ||
+                addr_call->args.size() != 1) {
+              return StmtExprMutator::VisitExpr_(node);
+            }
+            const BufferLoadNode *old_dst_node =
+                addr_call->args[0].as<BufferLoadNode>();
+            const BufferLoadNode *old_value_node =
+                node->args[2].as<BufferLoadNode>();
+            if (!old_dst_node || !old_value_node) {
+              return StmtExprMutator::VisitExpr_(node);
+            }
+            Array<PrimExpr> dst_indices, value_indices;
+            if ((extent_tx_ * vector_size_) > stride_x_) {
+              dst_indices.push_back(
+                  by_var_ * stride_y_ +
+                  iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
+                  truncdiv(tx_var_, stride_x_ / vector_size_));
+              dst_indices.push_back(
+                  bx_var_ * stride_x_ +
+                  truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
+              value_indices.push_back(
+                  iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
+                  truncdiv(tx_var_ * vector_size_, stride_x_));
+              value_indices.push_back(
+                  truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
+            } else {
+              dst_indices.push_back(
+                  by_var_ * stride_y_ +
+                  truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
+                  truncdiv(tx_var_, stride_x_ / vector_size_));
+              dst_indices.push_back(
+                  bx_var_ * stride_x_ +
+                  truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
+                      (extent_tx_ * vector_size_) +
+                  truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
+              value_indices.push_back(
+                  truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
+                  truncdiv(tx_var_, stride_x_ / vector_size_));
+              value_indices.push_back(
+                  truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
+                      (extent_tx_ * vector_size_) +
+                  truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
+            }
+            BufferLoad dst_node =
+                BufferLoad(old_dst_node->buffer, dst_indices,
+                           old_dst_node->predicate, old_dst_node->span);
+            BufferLoad value_node =
+                BufferLoad(old_value_node->buffer, value_indices,
+                           old_value_node->predicate, old_value_node->span);
+            Call address_of_dst =
+                Call(DataType::Handle(), builtin::address_of(), {dst_node});
+            Call address_of_value =
+                Call(DataType::Handle(), builtin::address_of(), {value_node});
             Array<PrimExpr> new_args;
             if (vector_size_ == 2) {
               new_args.push_back(StringImm("AtomicAddx2"));
             } else {
               new_args.push_back(StringImm("AtomicAddx4"));
             }
-            new_args.push_back(node->args[1]);
+            new_args.push_back(address_of_dst);
             new_args.push_back(address_of_value);
             Call new_call =
@@ -244,6 +295,11 @@
   const int vector_size_;
   const PrimExpr condition_;
   const bool dynamic_;
+  const PrimExpr by_var_, bx_var_;
+  int stride_y_, stride_x_;
+  const Var tx_var_;
+  Var iter_var_;
+  int extent_tx_;
 };
 
 static int GetVectorizeSizeMax(int compute_capability, DataType dtype) {
@@ -272,6 +328,8 @@ For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds,
                        int compute_capability) {
   int vectorize_size_max = 1;
+  int stride_x = -1, stride_y = -1;
+  PrimExpr bx_var, by_var;
 
   PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
     if (const auto *call = obj.as<CallNode>()) {
@@ -284,8 +342,27 @@ For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds,
         }
       }
     }
+    if (const MulNode *mul = obj.as<MulNode>()) {
+      const VarNode *var = nullptr;
+      const IntImmNode *imm = nullptr;
+      PrimExpr var_expr;
+      if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
+        var_expr = mul->a;
+      } else if ((var = mul->b.as<VarNode>()) &&
+                 (imm = mul->a.as<IntImmNode>())) {
+        var_expr = mul->b;
+      }
+      if (var && imm) {
+        if (var->name_hint == "bx") {
+          stride_x = imm->value;
+          bx_var = var_expr;
+        } else if (var->name_hint == "by") {
+          stride_y = imm->value;
+          by_var = var_expr;
+        }
+      }
+    }
   });
 
   if (vectorize_size_max != 1) {
     int vectorize_hint = vectorize_size_max;
     AtomicAddVectorizePlanResult res = {1, false, 0};
@@ -293,9 +370,11 @@ For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds,
     res = planner.Plan(for_node, thread_var, thread_bounds, vectorize_hint);
     vectorize_hint = res.vector_size;
-    if (vectorize_hint == 1)
+    if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
+        !bx_var.defined() || !by_var.defined())
       return for_node;
-    auto rewriter = AtomicAddVectorizeRewriter(res);
+    auto rewriter = AtomicAddVectorizeRewriter(
+        res, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x);
     return Downcast<For>(rewriter(for_node));
   } else {
     return for_node;
...
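
For readers checking the index arithmetic in the rewriter above: the sketch below is not part of the patch. It replays the static-case mapping emitted by AtomicAddVectorizeRewriter with plain ints in place of TIR expressions, so the comment's formula can be verified by hand. The names tx, iter, bx, by, stride_x, stride_y, extent_tx and vector_size mirror tx_var_, iter_var_, bx_var_, by_var_, stride_x_, stride_y_, extent_tx_ and vector_size_ from the diff; the sample sizes are made up.

#include <cstdio>

int main() {
  // Assumed sample configuration, chosen only for illustration.
  const int vector_size = 4;   // AtomicAddx4 path
  const int extent_tx = 32;    // thread_bounds->extent
  const int stride_x = 256;    // coefficient of bx in the destination index
  const int stride_y = 128;    // coefficient of by in the destination index
  const int bx = 0, by = 0;    // block indices

  for (int iter = 0; iter < 2; ++iter) {   // devectorized outer loop (iter_var_)
    for (int tx = 0; tx < 4; ++tx) {       // first few threads (tx_var_)
      int row, col;
      if (extent_tx * vector_size > stride_x) {
        // One pass of the thread block spans more than one destination row.
        row = by * stride_y + iter * (extent_tx * vector_size / stride_x) +
              tx / (stride_x / vector_size);
        col = bx * stride_x + tx % (stride_x / vector_size) * vector_size;
      } else {
        // One destination row needs several passes of the thread block.
        row = by * stride_y + iter / (stride_x / (extent_tx * vector_size)) +
              tx / (stride_x / vector_size);
        col = bx * stride_x +
              iter % (stride_x / (extent_tx * vector_size)) *
                  (extent_tx * vector_size) +
              tx % (stride_x / vector_size) * vector_size;
      }
      // Each thread covers vector_size consecutive destination columns.
      std::printf("iter=%d tx=%d -> dst[%d][%d..%d]\n", iter, tx, row, col,
                  col + vector_size - 1);
    }
  }
  return 0;
}

With these sample numbers each outer iteration lets the thread block cover half of a 256-wide row: iter = 0 maps to columns 0..127 and iter = 1 to columns 128..255 of the same row, four consecutive elements per thread, which is what the AtomicAddx4 rewrite relies on.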