Unverified commit a7a29c09 authored by yyttt6, committed by GitHub

[Bugfix]: Fix atomic add auto-vectorize negative optimization (#765)

* [Bugfix]: Fix atomic add auto-vectorize negative optimization

* fix bug

* format

* fix bug
parent 2af3f22e
/*!
* \file tl/op/atomic_add.cc
*
* Define elment-wise operators.
* Define element-wise operators.
*/
#include "./atomic_add.h"
@@ -368,10 +368,8 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Range thread_bounds = T.thread_bounds;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
// TODO(@dyq): buggy implementation, need to fix
// vectorized_thread_loop = VectorizeAtomicAdd(
// thread_loop, thread_var, thread_bounds, GetArchInt(target));
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
auto vectorized_thread_loop = VectorizeAtomicAdd(
thread_loop, thread_var, thread_bounds, GetArchInt(target));
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
......
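
For context (not part of the commit): the effect of switching from the generic VectorizeLoop back to VectorizeAtomicAdd is that eligible loops issue one packed AtomicAddx2/AtomicAddx4 extern call per chunk instead of one scalar AtomicAdd per element. A minimal host-side sketch of that loop restructuring, with hypothetical atomic_add_scalar / atomic_add_x4 helpers standing in for the extern device functions:

#include <atomic>
#include <cstdio>

// Hypothetical stand-ins for the extern AtomicAdd / AtomicAddx4 device helpers.
void atomic_add_scalar(std::atomic<float> *dst, float v) {
  float old = dst->load();
  while (!dst->compare_exchange_weak(old, old + v)) {
    // retry with the refreshed value of `old`
  }
}

void atomic_add_x4(std::atomic<float> *dst, const float *v) {
  for (int lane = 0; lane < 4; ++lane)  // one packed call covers 4 lanes
    atomic_add_scalar(dst + lane, v[lane]);
}

int main() {
  constexpr int extent = 8, vec = 4;
  std::atomic<float> dst[extent];
  for (auto &d : dst) d.store(0.0f);
  float src[extent] = {0, 1, 2, 3, 4, 5, 6, 7};

  // Before the rewrite: one scalar atomic add per element.
  for (int i = 0; i < extent; ++i) atomic_add_scalar(&dst[i], src[i]);

  // After the rewrite: the loop extent shrinks by `vec` and each
  // iteration updates a contiguous chunk through one packed call.
  for (int io = 0; io < extent / vec; ++io)
    atomic_add_x4(&dst[io * vec], &src[io * vec]);

  printf("dst[3] = %.1f\n", dst[3].load());  // 3 + 3 = 6.0
  return 0;
}

The real rewriter additionally checks alignment and rewrites the buffer indices, as the hunks below show.
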
@@ -125,7 +125,7 @@ private:
// dynamic shape load: get the vectorization condition
dynamic_ = true;
PrimExpr offset = buffer.OffsetOf(indices).back();
condition_ = (FloorMod(offset, vector_size_) == 0);
condition_ = (truncmod(offset, vector_size_) == 0);
}
}
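
Illustrative note (not from the commit): for dynamic shapes the planner only records a runtime predicate stating that the flattened access offset is a multiple of the vector width, now expressed with truncmod so it matches the division semantics used in the rewritten index expressions. The check itself is a plain alignment test, e.g.:

#include <cstdio>

// A flattened element offset can start a vector_size-wide access
// only when it is a multiple of the vector width.
bool can_vectorize(long long offset, int vector_size) {
  return offset % vector_size == 0;  // mirrors truncmod(offset, vector_size_) == 0
}

int main() {
  printf("%d\n", can_vectorize(128, 4));  // 1: aligned, packed path is legal
  printf("%d\n", can_vectorize(130, 4));  // 0: misaligned, fall back to scalar
  return 0;
}
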
@@ -141,9 +141,17 @@ private:
class AtomicAddVectorizeRewriter : public StmtExprMutator {
public:
AtomicAddVectorizeRewriter(AtomicAddVectorizePlanResult plan)
AtomicAddVectorizeRewriter(AtomicAddVectorizePlanResult plan, Var thread_var,
PrimExpr by_var, PrimExpr bx_var,
Range thread_bounds, int stride_y, int stride_x)
: vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic) {}
dynamic_(plan.dynamic), tx_var_(thread_var), by_var_(by_var),
bx_var_(bx_var), stride_y_(stride_y), stride_x_(stride_x) {
const int64_t *tx_ext = as_const_int(thread_bounds->extent);
ICHECK(tx_ext)
<< "thread_bounds->extent must be a constant for vectorization.";
extent_tx_ = static_cast<int>(*tx_ext);
}
private:
/**
@@ -174,10 +182,10 @@ private:
*/
Stmt VisitStmt_(const ForNode *node) final {
inner_for_ = node;
iter_var_ = Var(node->loop_var->name_hint + "_outer");
auto ret = StmtExprMutator::VisitStmt_(node);
if (inner_for_ == node) { // rewrite the innermost loop
For fnode = ret.as<For>().value();
auto old_var = fnode->loop_var;
auto extent_ptr = as_const_int(fnode->extent);
ICHECK(extent_ptr) << fnode->extent;
int extent = *extent_ptr;
@@ -185,23 +193,10 @@ private:
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
if (!dynamic_) {
Var tx_var;
PostOrderVisit(fnode->body, [&tx_var](const ObjectRef &node) {
if (const VarNode *var = node.as<VarNode>()) {
if (var->name_hint == "tx") {
tx_var = GetRef<Var>(var);
}
}
});
ICHECK(tx_var.defined()) << "Failed to find tx var";
Var outer_var = Var(old_var->name_hint + "_outer");
Map<Var, PrimExpr> vmap;
// Scale thread index (tx) and loop variable by vector_size to map each
// new iteration to a vectorized chunk
vmap.Set(tx_var, tx_var * vector_size_);
vmap.Set(fnode->loop_var, outer_var * vector_size_);
vmap.Set(fnode->loop_var, iter_var_);
Stmt body = Substitute(fnode->body, vmap);
return For(outer_var, 0, extent / vector_size_, fnode->kind, body,
return For(iter_var_, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
}
}
@@ -209,24 +204,80 @@ private:
}
PrimExpr VisitExpr_(const CallNode *node) final {
if (dynamic_) {
return StmtExprMutator::VisitExpr_(node);
}
if (vector_size_ == 2 || vector_size_ == 4) {
if (node->op == builtin::call_extern() && node->args.size() >= 2) {
if (const auto *func_name = node->args[0].as<StringImmNode>()) {
if (func_name->value == "AtomicAdd") {
PrimExpr value_node = node->args[2];
Call address_of_value = tvm::tir::Call(
DataType::Handle(), builtin::address_of(), {value_node});
// Matrix[by * stride_y + i / (stride_x / (tx_extent * vector_size_)) +
//            tx_var_ / (stride_x / vector_size_),
//        bx * stride_x +
//            (i % (stride_x / (tx_extent * vector_size_))) *
//                (tx_extent * vector_size_) +
//            (tx_var_ % (stride_x / vector_size_)) * vector_size_]
const CallNode *addr_call = node->args[1].as<CallNode>();
if (!addr_call || addr_call->op != builtin::address_of() ||
addr_call->args.size() != 1) {
return StmtExprMutator::VisitExpr_(node);
}
const BufferLoadNode *old_dst_node =
addr_call->args[0].as<BufferLoadNode>();
const BufferLoadNode *old_value_node =
node->args[2].as<BufferLoadNode>();
if (!old_dst_node || !old_value_node) {
return StmtExprMutator::VisitExpr_(node);
}
Array<PrimExpr> dst_indices, value_indices;
if ((extent_tx_ * vector_size_) > stride_x_) {
dst_indices.push_back(
by_var_ * stride_y_ +
iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
truncdiv(tx_var_, stride_x_ / vector_size_));
dst_indices.push_back(
bx_var_ * stride_x_ +
truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
value_indices.push_back(
iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
truncdiv(tx_var_ * vector_size_, stride_x_));
value_indices.push_back(
truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
} else {
dst_indices.push_back(
by_var_ * stride_y_ +
truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
truncdiv(tx_var_, stride_x_ / vector_size_));
dst_indices.push_back(
bx_var_ * stride_x_ +
truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
(extent_tx_ * vector_size_) +
truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
value_indices.push_back(
truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
truncdiv(tx_var_, stride_x_ / vector_size_));
value_indices.push_back(
truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
(extent_tx_ * vector_size_) +
truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
}
BufferLoad dst_node =
BufferLoad(old_dst_node->buffer, dst_indices,
old_dst_node->predicate, old_dst_node->span);
BufferLoad value_node =
BufferLoad(old_value_node->buffer, value_indices,
old_value_node->predicate, old_value_node->span);
Call address_of_dst =
Call(DataType::Handle(), builtin::address_of(), {dst_node});
Call address_of_value =
Call(DataType::Handle(), builtin::address_of(), {value_node});
Array<PrimExpr> new_args;
if (vector_size_ == 2) {
new_args.push_back(StringImm("AtomicAddx2"));
} else {
new_args.push_back(StringImm("AtomicAddx4"));
}
new_args.push_back(node->args[1]);
new_args.push_back(address_of_dst);
new_args.push_back(address_of_value);
Call new_call =
@@ -244,6 +295,11 @@ private:
const int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
const PrimExpr by_var_, bx_var_;
int stride_y_, stride_x_;
const Var tx_var_;
Var iter_var_;
int extent_tx_;
};
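
A small host-side harness (illustrative only, with made-up sizes) can help inspect the first branch of the index mapping above, the extent_tx_ * vector_size_ > stride_x_ case, using plain integer division and modulo in place of truncdiv/truncmod and with by = bx = 0:

#include <cstdio>

int main() {
  // Toy configuration (hypothetical values, chosen so that
  // extent_tx * vector_size > stride_x holds, as in the first branch).
  const int extent_tx = 8, vector_size = 4, stride_x = 16, stride_y = 64;
  const int by = 0, bx = 0;
  const int extent = 8;                    // innermost loop extent before splitting
  const int outer = extent / vector_size;  // iterations of the rewritten loop

  for (int i = 0; i < outer; ++i) {           // plays the role of iter_var_
    for (int tx = 0; tx < extent_tx; ++tx) {  // plays the role of tx_var_
      int row = by * stride_y + i * (extent_tx * vector_size / stride_x) +
                tx / (stride_x / vector_size);
      int col = bx * stride_x + tx % (stride_x / vector_size) * vector_size;
      printf("i=%d tx=%d -> dst[%d][%d..%d]\n", i, tx, row, col,
             col + vector_size - 1);
    }
  }
  return 0;
}

For these values each (i, tx) pair lands on a distinct vector_size-wide column chunk, which is the layout the rewritten AtomicAddx4 call relies on.
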
static int GetVectorizeSizeMax(int compute_capability, DataType dtype) {
@@ -272,6 +328,8 @@ For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds,
int compute_capability) {
int vectorize_size_max = 1;
int stride_x = -1, stride_y = -1;
PrimExpr bx_var, by_var;
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
if (const auto *call = obj.as<CallNode>()) {
@@ -284,8 +342,27 @@ For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds,
}
}
}
if (const MulNode *mul = obj.as<MulNode>()) {
const VarNode *var = nullptr;
const IntImmNode *imm = nullptr;
PrimExpr var_expr;
if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
var_expr = mul->a;
} else if ((var = mul->b.as<VarNode>()) &&
(imm = mul->a.as<IntImmNode>())) {
var_expr = mul->b;
}
if (var && imm) {
if (var->name_hint == "bx") {
stride_x = imm->value;
bx_var = var_expr;
} else if (var->name_hint == "by") {
stride_y = imm->value;
by_var = var_expr;
}
}
}
});
if (vectorize_size_max != 1) {
int vectorize_hint = vectorize_size_max;
AtomicAddVectorizePlanResult res = {1, false, 0};
@@ -293,9 +370,11 @@ For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds,
res = planner.Plan(for_node, thread_var, thread_bounds, vectorize_hint);
vectorize_hint = res.vector_size;
if (vectorize_hint == 1)
if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
!bx_var.defined() || !by_var.defined())
return for_node;
auto rewriter = AtomicAddVectorizeRewriter(res);
auto rewriter = AtomicAddVectorizeRewriter(
res, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x);
return Downcast<For>(rewriter(for_node));
} else {
return for_node;
......
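
Side note (illustrative, not part of the commit): the stride detection walks the loop body for multiplications of a block variable by a constant, accepting either operand order, and the new guard returns the loop unchanged when bx/by or their strides are not found. The matching idea in isolation, over toy expression types rather than TVM nodes:

#include <cstdio>
#include <optional>
#include <string>
#include <utility>

// Toy stand-ins for TIR nodes, just to illustrate the matching logic.
struct Operand {
  bool is_var;       // true: a named variable, false: an integer immediate
  std::string name;  // set when is_var
  long long value;   // set when !is_var
};

struct Mul { Operand a, b; };

// Match `var * imm` or `imm * var`, as the MulNode visitor in the diff does.
std::optional<std::pair<std::string, long long>> MatchVarTimesImm(const Mul &m) {
  if (m.a.is_var && !m.b.is_var) return std::make_pair(m.a.name, m.b.value);
  if (m.b.is_var && !m.a.is_var) return std::make_pair(m.b.name, m.a.value);
  return std::nullopt;
}

int main() {
  Mul e1{{true, "bx", 0}, {false, "", 256}};  // bx * 256
  Mul e2{{false, "", 64}, {true, "by", 0}};   // 64 * by
  for (const Mul &e : {e1, e2}) {
    if (auto hit = MatchVarTimesImm(e))
      printf("%s stride = %lld\n", hit->first.c_str(), hit->second);
  }
  return 0;
}
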