Unverified commit 0ff4f427, authored by Yuqi Dong, committed by GitHub
Browse files

[Feature]: Add test for atomicadd auto vectorize and remove useless code (#1019)

* update

* format

* rabbit
parent bd1c7b39
...@@ -272,7 +272,6 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -272,7 +272,6 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
Array<PrimExpr> new_args; Array<PrimExpr> new_args;
new_args.push_back(StringImm("AtomicAdd"));
PrimExpr src_value = BufferLoad(src, src_indices); PrimExpr src_value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype) if (src->dtype != dst->dtype)
...@@ -288,7 +287,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -288,7 +287,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
new_args.push_back(src_value); new_args.push_back(src_value);
Call atomicadd_call = Call atomicadd_call =
tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args); tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args);
Stmt body = tvm::tir::Evaluate(atomicadd_call); Stmt body = tvm::tir::Evaluate(atomicadd_call);
...@@ -325,10 +324,6 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -325,10 +324,6 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
*/ */
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T, LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const { InferLevel level) const {
if (!par_op_.defined()) {
arith::Analyzer analyzer;
par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
}
if (T.layout_map.count(src) && T.layout_map.count(dst)) { if (T.layout_map.count(src) && T.layout_map.count(dst)) {
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") { if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>(); const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
...@@ -342,7 +337,7 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T, ...@@ -342,7 +337,7 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
} }
} }
} }
return par_op_->InferLayout(T, level); return {};
} }
/** /**
......
...@@ -295,5 +295,10 @@ TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset) ...@@ -295,5 +295,10 @@ TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
// Registers the tilelang builtin op `atomicadd_elem_op`, the element-wise
// atomic-add intrinsic. It is declared with exactly two inputs and its call
// effect kind is set to opaque.
// NOTE(review): passes in this change read args[0] as the destination
// BufferLoad and args[1] as the value to add — confirm against the producer.
TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -501,6 +501,13 @@ TVM_DLL const Op &initialize_descriptor(); ...@@ -501,6 +501,13 @@ TVM_DLL const Op &initialize_descriptor();
* tilelang. * tilelang.
*/ */
TVM_DLL const Op &increase_descriptor_offset(); TVM_DLL const Op &increase_descriptor_offset();
/*!
* \brief tilelang intrinsic for element-wise atomic addition.
*
* This op is used to represent an element-wise atomic add operation in
* tilelang.
*
* The op is registered with two inputs. The atomic-add vectorize pass
* inspects args[0] as the destination load and args[1] as the source value
* (presumably in that order for every producer; verify against callers).
*
* \return A reference to the registered Op.
*/
TVM_DLL const Op &atomicadd_elem_op();
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
......
...@@ -23,25 +23,27 @@ AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) { ...@@ -23,25 +23,27 @@ AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
PostOrderVisit(node, [&](const ObjectRef &obj) { PostOrderVisit(node, [&](const ObjectRef &obj) {
if (const auto *call = obj.as<CallNode>()) { if (const auto *call = obj.as<CallNode>()) {
if (call->op == builtin::call_extern() && call->args.size() >= 2) { if (call->op == atomicadd_elem_op()) {
const auto *func_name = call->args[0].as<StringImmNode>(); if (call->args.size() < 2) {
if (!func_name) // Fallback: unexpected arity
vectorize_size_max = 1;
DLOG(WARNING) << "[AtomicAddVectorizePlanner] atomicadd_elem_op "
"expects 2 args, got "
<< call->args.size() << "; Fallback to no vectorize";
return; return;
if (func_name->value == "AtomicAdd") { }
DataType dtype; DataType dtype;
if (const auto *load = call->args[1].as<BufferLoadNode>()) { if (const auto *load = call->args[0].as<BufferLoadNode>()) {
dtype = load->dtype; dtype = load->dtype;
vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype); vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
} else if (const auto *ite = call->args[1].as<IfThenElseNode>()) { } else if (const auto *ite = call->args[0].as<IfThenElseNode>()) {
if (const auto *then_load = ite->then_case.as<BufferLoadNode>()) { if (const auto *then_load = ite->then_case.as<BufferLoadNode>()) {
dtype = then_load->dtype; dtype = then_load->dtype;
vectorize_size_max = vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
GetVectorizeSizeMax(compute_capability, dtype);
} else if (const auto *else_load = } else if (const auto *else_load =
ite->else_case.as<BufferLoadNode>()) { ite->else_case.as<BufferLoadNode>()) {
dtype = else_load->dtype; dtype = else_load->dtype;
vectorize_size_max = vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
GetVectorizeSizeMax(compute_capability, dtype);
} else { } else {
// fallback // fallback
vectorize_size_max = 1; vectorize_size_max = 1;
...@@ -57,7 +59,6 @@ AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) { ...@@ -57,7 +59,6 @@ AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
} }
} }
} }
}
}); });
if (vectorize_size_max <= 1) { if (vectorize_size_max <= 1) {
...@@ -75,13 +76,12 @@ void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) { ...@@ -75,13 +76,12 @@ void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) {
} }
void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) { void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) {
if (node->op == builtin::call_extern() && node->args.size() >= 2) { if (node->op == atomicadd_elem_op() && !node->args.empty()) {
if (const auto *func_name = node->args[0].as<StringImmNode>()) { if (node->args.size() < 2) {
if (func_name->value == "AtomicAdd") { return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
const BufferLoadNode *buffer_load_dst = }
node->args[1].as<BufferLoadNode>(); const BufferLoadNode *buffer_load_dst = node->args[0].as<BufferLoadNode>();
const BufferLoadNode *buffer_load_src = const BufferLoadNode *buffer_load_src = node->args[1].as<BufferLoadNode>();
node->args[2].as<BufferLoadNode>();
if (buffer_load_src && buffer_load_src->buffer.defined() && if (buffer_load_src && buffer_load_src->buffer.defined() &&
buffer_load_dst && buffer_load_dst->buffer.defined()) { buffer_load_dst && buffer_load_dst->buffer.defined()) {
Buffer dst_buffer = buffer_load_dst->buffer; Buffer dst_buffer = buffer_load_dst->buffer;
...@@ -91,8 +91,6 @@ void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) { ...@@ -91,8 +91,6 @@ void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) {
UpdateVectorSize(buffer_load_src->indices, src_buffer); UpdateVectorSize(buffer_load_src->indices, src_buffer);
} }
} }
}
}
return arith::IRVisitorWithAnalyzer::VisitExpr_(node); return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
} }
...@@ -188,6 +186,8 @@ private: ...@@ -188,6 +186,8 @@ private:
Stmt VisitStmt_(const ForNode *node) final { Stmt VisitStmt_(const ForNode *node) final {
inner_for_ = node; inner_for_ = node;
auto ret = StmtExprMutator::VisitStmt_(node); auto ret = StmtExprMutator::VisitStmt_(node);
if (vector_size_ == 1)
return ret;
if (inner_for_ == node) { if (inner_for_ == node) {
For fnode = ret.as<For>().value(); For fnode = ret.as<For>().value();
auto old_var = fnode->loop_var; auto old_var = fnode->loop_var;
...@@ -210,48 +210,55 @@ private: ...@@ -210,48 +210,55 @@ private:
} }
PrimExpr VisitExpr_(const CallNode *node) final { PrimExpr VisitExpr_(const CallNode *node) final {
if (dynamic_) { bool legal_vectorize = true;
return StmtExprMutator::VisitExpr_(node); if (dynamic_)
} legal_vectorize = false;
if (vector_size_ == 2 || vector_size_ == 4) { if (!(node->op == atomicadd_elem_op()))
if (node->op == builtin::call_extern() && node->args.size() >= 2) { legal_vectorize = false;
if (const auto *func_name = node->args[0].as<StringImmNode>()) { if (node->args.size() < 2)
if (func_name->value == "AtomicAdd") { legal_vectorize = false;
const BufferLoadNode *temp_dst_node = if (legal_vectorize) {
node->args[1].as<BufferLoadNode>(); const BufferLoadNode *temp_dst_node = node->args[0].as<BufferLoadNode>();
const BufferLoadNode *temp_value_node = const BufferLoadNode *temp_value_node =
node->args[2].as<BufferLoadNode>(); node->args[1].as<BufferLoadNode>();
if (!temp_dst_node || !temp_value_node) { if (!temp_dst_node || !temp_value_node)
return StmtExprMutator::VisitExpr_(node); legal_vectorize = false;
} }
const BufferLoad dst_node = if (legal_vectorize) {
Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()); const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
const BufferLoad value_node = const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);
Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
Call address_of_dst = Call address_of_dst =
Call(DataType::Handle(), builtin::address_of(), {dst_node}); Call(DataType::Handle(), builtin::address_of(), {dst_node});
Call address_of_value = Call address_of_value =
Call(DataType::Handle(), builtin::address_of(), {value_node}); Call(DataType::Handle(), builtin::address_of(), {value_node});
Array<PrimExpr> new_args; Array<PrimExpr> new_args;
if (vector_size_ == 2) { if (vector_size_ == 4) {
new_args.push_back(StringImm("AtomicAddx4"));
} else if (vector_size_ == 2) {
new_args.push_back(StringImm("AtomicAddx2")); new_args.push_back(StringImm("AtomicAddx2"));
} else { } else {
new_args.push_back(StringImm("AtomicAddx4")); new_args.push_back(StringImm("AtomicAdd"));
} }
new_args.push_back(address_of_dst); new_args.push_back(address_of_dst);
new_args.push_back(address_of_value); new_args.push_back(address_of_value);
Call new_call =
tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
return new_call;
} else {
Array<PrimExpr> new_args;
new_args.push_back(StringImm("AtomicAdd"));
for (auto x : node->args)
new_args.push_back(x);
Call new_call = Call new_call =
tvm::tir::Call(node->dtype, builtin::call_extern(), new_args); tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
return new_call; return new_call;
} }
} }
}
}
return StmtExprMutator::VisitExpr_(node);
}
const ForNode *inner_for_; const ForNode *inner_for_;
const int vector_size_; const int vector_size_;
...@@ -263,9 +270,6 @@ For VectorizeAtomicAdd(const For &for_node, int compute_capability) { ...@@ -263,9 +270,6 @@ For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
AtomicAddVectorizePlanResult res = {1, false, 0}; AtomicAddVectorizePlanResult res = {1, false, 0};
AtomicAddVectorizePlanner planner; AtomicAddVectorizePlanner planner;
res = planner.Plan(for_node, compute_capability); res = planner.Plan(for_node, compute_capability);
int vectorize_hint = res.vector_size;
if (vectorize_hint == 1)
return for_node;
auto rewriter = AtomicAddVectorizeRewriter(res); auto rewriter = AtomicAddVectorizeRewriter(res);
return Downcast<For>(rewriter(for_node)); return Downcast<For>(rewriter(for_node));
} }
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "../layout/layout.h" #include "../layout/layout.h"
#include "../layout/utils.h" #include "../layout/utils.h"
#include "../op/builtin.h"
#include "arith/int_operator.h" #include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h" #include "arith/ir_visitor_with_analyzer.h"
#include "atomicadd_vectorize.h" #include "atomicadd_vectorize.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment