Unverified Commit 340bfc50 authored by Yuqi Dong, committed by GitHub

[Bugfix] Fix atomicadd auto vectorize identify var error (#883)

* update

* update

* update

* update
parent 4a229ddb
@@ -13,6 +13,7 @@
 #include "../target/utils.h"
 #include "../transform/atomicadd_vectorize.h"
 #include "../transform/common/loop_fusion_utils.h"
+#include "../transform/common/loop_parallel_transform_utils.h"
 #include "../transform/loop_partition.h"
 #include "builtin.h"
@@ -21,31 +22,6 @@ namespace tl {
 using namespace tir;
 
-/**
- * @brief Extracts a numeric architecture identifier from a Target's "arch"
- * attribute.
- *
- * Reads the Target's "arch" string (must be defined) and, if it has the form
- * "sm_<N>", parses and returns N as an integer. For any other arch string,
- * returns 0.
- *
- * @param target Target whose "arch" attribute will be inspected (ICHECKs that
- * the attribute is defined).
- * @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
- */
-static int GetArchInt(Target target) {
-  int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
-  std::string arch = s.value();
-  if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
-  } else {
-    arch_int = 0;
-  }
-  return arch_int;
-}
-
 /**
  * @brief Construct an AtomicAdd operator from call arguments and a buffer map.
  *
@@ -328,6 +304,47 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   return Downcast<For>(body);
 }
 
+/**
+ * @brief Infer and return the layout map for the atomic add operator.
+ *
+ * Constructs a cached ParallelOp (by building the SIMT loop) if not already
+ * present, validates that local.fragment layouts for src and dst match when
+ * both are provided, and then delegates layout inference to the underlying
+ * ParallelOp.
+ *
+ * @param T Layout inference inputs, including an optional mapping of buffers to
+ * layouts.
+ * @param level Inference strictness level.
+ * @return LayoutMap The inferred layout mapping for buffers used by this
+ * operator.
+ *
+ * @note This method mutates the AtomicAddNode by creating and storing a
+ * ParallelOp on first invocation.
+ * @throws If both src and dst have layouts in `local.fragment` and their
+ * fragment layouts differ, an ICHECK failure is raised with diagnostic output.
+ */
+LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
+                                     InferLevel level) const {
+  if (!par_op_.defined()) {
+    arith::Analyzer analyzer;
+    par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
+  }
+  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
+    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
+      const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
+      const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
+      if (src_layout && dst_layout) {
+        ICHECK(src_layout->IsEqual(dst_layout, true))
+            << "Get different layout for " << src << " and " << dst
+            << "\nLHS = " << src_layout->DebugOutput()
+            << "\nRHS = " << dst_layout->DebugOutput()
+            << "\nYou may need to use a shared memory to transform the layout";
+      }
+    }
+  }
+  return par_op_->InferLayout(T, level);
+}
+
 /**
  * @brief Lower the atomic-add top-level operator into a parallel, vectorized
  * TIR loop.
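
Aside (illustration, not part of the diff): the ICHECK in the InferLayout hunk above requires that, when both src and dst live in local.fragment scope, the two fragments assign every element to the same thread; otherwise a register-level atomic add would have to reach into another thread's fragment. A minimal standalone model of that ownership check follows; OwnerThread is a hypothetical stand-in for Fragment::ForwardThread, not a TVM API.

#include <cassert>

// Toy element -> thread mapping standing in for Fragment::ForwardThread.
// A "fragment layout" here is just a stride; identical strides mean identical
// ownership, which is what the IsEqual check above demands.
static int OwnerThread(int elem, int num_threads, int stride) {
  return (elem / stride) % num_threads;
}

int main() {
  const int num_threads = 32;
  const int src_stride = 1, matching_dst_stride = 1, mismatched_dst_stride = 2;

  bool mismatch_found = false;
  for (int e = 0; e < 256; ++e) {
    // Matching layouts: the thread that owns dst[e] also owns src[e], so the
    // element-wise atomic add never touches another thread's registers.
    assert(OwnerThread(e, num_threads, src_stride) ==
           OwnerThread(e, num_threads, matching_dst_stride));
    // Mismatched layouts: ownership diverges for some element; this is the
    // configuration the ICHECK rejects (route through shared memory instead).
    if (OwnerThread(e, num_threads, src_stride) !=
        OwnerThread(e, num_threads, mismatched_dst_stride))
      mismatch_found = true;
  }
  assert(mismatch_found);
  return 0;
}
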
@@ -389,70 +406,142 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   }
   auto simt_loop = MakeSIMTLoop(analyzer);
   auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
-  auto par_op = ParallelOp(fused_loop);
-
-  std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
-                                    InferLevel::kFree};
-  for (auto level : levels) {
-    (par_op)->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                           false, T.buffer_remap},
-                          level);
-  }
-  auto loop_layout = par_op->GetLoopLayout();
-  Var thread_var = T.thread_var;
-  Range thread_bounds = T.thread_bounds;
-  auto thread_loop =
-      PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
-  auto vectorized_thread_loop = VectorizeAtomicAdd(
-      thread_loop, thread_var, thread_bounds, GetArchInt(target));
-  if (par_op->GetPredicate(T.thread_var).defined()) {
-    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
-                      vectorized_thread_loop);
-  }
-
-  return vectorized_thread_loop;
-}
-
-/**
- * @brief Infer and return the layout map for the atomic add operator.
- *
- * Constructs a cached ParallelOp (by building the SIMT loop) if not already
- * present, validates that local.fragment layouts for src and dst match when
- * both are provided, and then delegates layout inference to the underlying
- * ParallelOp.
- *
- * @param T Layout inference inputs, including an optional mapping of buffers to
- * layouts.
- * @param level Inference strictness level.
- * @return LayoutMap The inferred layout mapping for buffers used by this
- * operator.
- *
- * @note This method mutates the AtomicAddNode by creating and storing a
- * ParallelOp on first invocation.
- * @throws If both src and dst have layouts in `local.fragment` and their
- * fragment layouts differ, an ICHECK failure is raised with diagnostic output.
- */
-LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
-                                     InferLevel level) const {
-  if (!par_op_.defined()) {
-    arith::Analyzer analyzer;
-    par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
-  }
-  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
-    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
-      const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
-      const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
-      if (src_layout && dst_layout) {
-        ICHECK(src_layout->IsEqual(dst_layout, true))
-            << "Get different layout for " << src << " and " << dst
-            << "\nLHS = " << src_layout->DebugOutput()
-            << "\nRHS = " << dst_layout->DebugOutput()
-            << "\nYou may need to use a shared memory to transform the layout";
-      }
-    }
-  }
-  return par_op_->InferLayout(T, level);
+  auto transformed_loop =
+      Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
+
+  auto GetArchInt = [&](const Target &tgt) -> int {
+    int arch_int = 0;
+    if (auto s = tgt->GetAttr<String>("arch")) {
+      std::string arch = s.value();
+      if (arch.rfind("sm_", 0) == 0)
+        arch_int = std::stoi(arch.substr(3));
+    }
+    return arch_int;
+  };
+
+  struct AtomicLoopNestCollector : tir::StmtExprVisitor {
+    Array<IterVar> loop_vars;
+    Map<Buffer, Array<PrimExpr>> indice_map;
+    std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> writes;
+    arith::Analyzer analyzer;
+
+    void Run(const Stmt &s) { StmtExprVisitor::VisitStmt(s); }
+
+    void VisitStmt_(const ForNode *op) final {
+      if (op->kind == ForKind::kParallel) {
+        loop_vars.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
+                                    IterVarType::kDataPar));
+      }
+      analyzer.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+      StmtExprVisitor::VisitStmt_(op);
+    }
+    void VisitStmt_(const BufferStoreNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        indice_map.Set(op->buffer, op->indices);
+        writes.insert(op->buffer);
+      }
+      StmtExprVisitor::VisitStmt_(op);
+    }
+    void VisitExpr_(const BufferLoadNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        indice_map.Set(op->buffer, op->indices);
+      }
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  };
+
+  auto ComputeLoopLayoutFromBuffer =
+      [&](const Buffer &buf, const Array<PrimExpr> &indices,
+          const LayoutMap &layout_map, const Range &thread_bounds,
+          const Array<IterVar> &loop_vars) -> Fragment {
+    Fragment src = layout_map[buf].as<Fragment>().value();
+    Var rep;
+    auto rep_iter =
+        IterVar(Range(0, src->ReplicateExtent()), rep, IterVarType::kDataPar);
+    PrimExpr fth = src->ForwardThread(indices, rep);
+    fth = analyzer->Simplify(fth);
+    Fragment out = Fragment(loop_vars, /*forward_index=*/{}, fth, rep_iter)
+                       ->BindThreadRange(thread_bounds);
+    return out;
+  };
+
+  struct AtomicInferResult {
+    Fragment loop_layout;
+    Optional<PrimExpr> predicate;
+  };
+
+  auto AtomicAddInferLayout =
+      [&](const For &loop, const LayoutInferArgs &args) -> AtomicInferResult {
+    AtomicLoopNestCollector C;
+    C.Run(loop);
+    Optional<Buffer> read_src;
+    int best_rank = -1;
+    for (auto kv : C.indice_map) {
+      const Buffer &buf = kv.first;
+      if (buf.scope() != "local.fragment")
+        continue;
+      if (!args.layout_map.count(buf))
+        continue;
+      int rank = static_cast<int>(kv.second.size());
+      if (rank > best_rank) {
+        best_rank = rank;
+        read_src = buf;
+      }
+    }
+    AtomicAddVectorizePlanner planner;
+    int sm = GetArchInt(target);
+    auto plan = planner.Plan(loop, sm);
+    int vec = std::max(plan.vector_size, 1);
+    if (auto cw = loop->annotations.Get("coalesced_width")) {
+      if (const auto *imm = cw->as<IntImmNode>()) {
+        int expected = imm->value;
+        ICHECK_GT(expected, 0);
+        ICHECK(vec % expected == 0)
+            << "vector_size " << vec << " not divisible by coalesced_width "
+            << expected;
+        vec = expected;
+      } else {
+        LOG(FATAL) << "coalesced_width should be IntImmNode.";
+      }
+    }
+    PrimExpr total = 1;
+    for (Stmt s = loop; s.as<For>().has_value(); s = s.as<For>().value()->body)
+      total = total * s.as<For>().value()->extent;
+    PrimExpr denom = args.thread_bounds->extent * vec;
+    while (!analyzer->CanProve(floormod(total, denom) == 0) && vec > 1) {
+      vec >>= 1;
+      denom = args.thread_bounds->extent * vec;
+    }
+    if (vec < 1)
+      vec = 1;
+    Fragment loop_layout;
+    if (read_src) {
+      loop_layout = ComputeLoopLayoutFromBuffer(
+          read_src.value(), C.indice_map[read_src.value()], args.layout_map,
+          args.thread_bounds, C.loop_vars);
+    } else {
+      const For &remapped = loop;
+      loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
+    }
+    Optional<PrimExpr> pred;
+    if (plan.dynamic && plan.condition.defined()) {
+      pred = plan.condition;
+    }
+    DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec
+               << " loop_layout=" << loop_layout->DebugOutput();
+    return {loop_layout, pred};
+  };
+
+  auto ret = AtomicAddInferLayout(transformed_loop,
+                                  {T.target, T.thread_bounds, T.layout_map,
+                                   analyzer, false, T.buffer_remap});
+  Fragment loop_layout = ret.loop_layout;
+  auto thread_loop =
+      PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
+  auto vectorized_thread_loop =
+      VectorizeAtomicAdd(thread_loop, GetArchInt(target));
+
+  return vectorized_thread_loop;
 }
 
 TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
...
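
Aside (illustration, not part of the diff): inside the AtomicAddInferLayout lambda above, the planned vector width is repeatedly halved until the flattened loop extent is provably divisible by thread_bounds->extent * vec. A plain-integer sketch of that shrinking rule; ShrinkVectorWidth is a hypothetical name, not part of the codebase.

#include <cassert>

// Halve vec until total % (threads * vec) == 0, mirroring the
// CanProve(floormod(total, denom) == 0) loop above (integer model only).
static int ShrinkVectorWidth(int total, int threads, int vec) {
  while (vec > 1 && total % (threads * vec) != 0)
    vec >>= 1;
  return vec < 1 ? 1 : vec;
}

int main() {
  // 4096 elements over 128 threads: 128 * 4 divides 4096, so vec = 4 survives.
  assert(ShrinkVectorWidth(4096, 128, 4) == 4);
  // 4096 elements over 96 threads: neither 96*4 nor 96*2 divides 4096 -> vec = 1.
  assert(ShrinkVectorWidth(4096, 96, 4) == 1);
  return 0;
}
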
@@ -3,18 +3,7 @@
  * \brief A tool to automatically vectorize atomic add
  */
-#include "../layout/layout.h"
-#include "../layout/utils.h"
-#include "arith/int_operator.h"
-#include "arith/ir_visitor_with_analyzer.h"
-#include "common/loop_vectorization_utils.h"
-#include <numeric>
-#include <tvm/arith/analyzer.h>
-#include <tvm/arith/iter_affine_map.h>
-#include <tvm/tir/builtin.h>
-#include <tvm/tir/op.h>
-#include <tvm/tir/stmt_functor.h>
-#include <utility>
+#include "atomicadd_vectorize.h"
 
 namespace tvm {
 namespace tl {
@@ -23,132 +12,151 @@ using namespace tir;
 using arith::IRMutatorWithAnalyzer;
 using arith::IRVisitorWithAnalyzer;
 
-struct AtomicAddVectorizePlanResult {
-  int vector_size;
-  bool dynamic;
-  PrimExpr condition;
-};
-
-class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer {
-public:
-  AtomicAddVectorizePlanner() = default;
-  int max_vector_size = 1;
-  AtomicAddVectorizePlanResult Plan(const For &node, Var thread_var,
-                                    Range thread_bounds, int vectorize_hint) {
-    this->max_vector_size = vectorize_hint;
-    this->thread_var = std::move(thread_var);
-    this->thread_bounds = std::move(thread_bounds);
-    this->operator()(node);
-    return {vector_size_, dynamic_, condition_};
-  }
-
-private:
-  void VisitStmt_(const ForNode *node) final {
-    inner_for_ = node;
-    iter_map_.Set(node->loop_var, Range(node->min, node->extent));
-
-    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
-  }
-
-  void VisitExpr_(const CallNode *node) final {
-    if (node->op == builtin::call_extern() && node->args.size() >= 2) {
-      if (const auto *func_name = node->args[0].as<StringImmNode>()) {
-        if (func_name->value == "AtomicAdd") {
-          const BufferLoadNode *buffer_load_dst =
-              node->args[1].as<BufferLoadNode>();
-          const BufferLoadNode *buffer_load_src =
-              node->args[2].as<BufferLoadNode>();
-          if (buffer_load_src && buffer_load_src->buffer.defined() &&
-              buffer_load_dst && buffer_load_dst->buffer.defined()) {
-            Buffer dst_buffer = buffer_load_dst->buffer;
-            Array<PrimExpr> indices_dst = buffer_load_dst->indices;
-            UpdateVectorSize(indices_dst, dst_buffer);
-            Buffer src_buffer = buffer_load_src->buffer;
-            Array<PrimExpr> indices_src = buffer_load_src->indices;
-            UpdateVectorSize(indices_src, src_buffer);
-          }
-        }
-      }
-    }
-    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
-  }
-
-  void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
-    if (!inner_for_)
-      return;
-    auto extent_ptr = inner_for_->extent.as<IntImmNode>();
-    if (!extent_ptr)
-      return;
-
-    const DataType &access_type = buffer->dtype;
-    // i // 2, i % 8 can also be vectorized as factor 16
-    // so we should disable this GCD optimization
-    max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
-
-    auto last_dim = buffer->shape.back();
-    auto mod_set = analyzer_.modular_set(last_dim);
-    // when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
-    // conditionally tail vectorize
-    if (buffer->shape.back().as<IntImmNode>()) {
-      max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
-
-      auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
-      // If gcd_base is equal to the last dimension,
-      // we should analyze the second-to-last dimension
-      // in relation to the last dimension.
-      if (gcd_base < Downcast<IntImm>(last_dim)->value) {
-        max_vector_size = gcd_base;
-      }
-
-      vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
-
-      PrimExpr elem_offset = 0;
-      PrimExpr stride = 1;
-      for (int i = indices.size() - 1; i >= 0; --i) {
-        elem_offset = elem_offset + indices[i] * stride;
-        stride = stride * buffer->shape[i];
-      }
-      PrimExpr thread_extent = thread_bounds->extent;
-      while (!IndiceCanVectorize(elem_offset, thread_var, thread_extent,
-                                 vector_size_, &analyzer_)) {
-        vector_size_ /= 2;
-      }
-    } else if (vector_size_ <= 4) {
-      // dynamic shape load: get the vectorization condition
-      dynamic_ = true;
-      PrimExpr offset = buffer.OffsetOf(indices).back();
-      condition_ = (truncmod(offset, vector_size_) == 0);
-    }
-  }
-
-  const ForNode *inner_for_;
-  Map<Var, Range> iter_map_;
-  bool has_nonlocal_memory_access_ = false;
-  int vector_size_ = 4;
-  Var thread_var;
-  Range thread_bounds;
-  bool dynamic_ = false;
-  PrimExpr condition_;
-};
+AtomicAddVectorizePlanner::AtomicAddVectorizePlanner() = default;
+
+AtomicAddVectorizePlanResult
+AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
+  int vectorize_size_max = 1;
+  this->vector_size_ = 4;
+  this->dynamic_ = false;
+  this->condition_ = PrimExpr();
+
+  PostOrderVisit(node, [&](const ObjectRef &obj) {
+    if (const auto *call = obj.as<CallNode>()) {
+      if (call->op == builtin::call_extern() && call->args.size() >= 2) {
+        const auto *func_name = call->args[0].as<StringImmNode>();
+        if (!func_name)
+          return;
+        if (func_name->value == "AtomicAdd") {
+          DataType dtype;
+          if (const auto *load = call->args[1].as<BufferLoadNode>()) {
+            dtype = load->dtype;
+            vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+          } else if (const auto *ite = call->args[1].as<IfThenElseNode>()) {
+            if (const auto *then_load = ite->then_case.as<BufferLoadNode>()) {
+              dtype = then_load->dtype;
+              vectorize_size_max =
+                  GetVectorizeSizeMax(compute_capability, dtype);
+            } else if (const auto *else_load =
+                           ite->else_case.as<BufferLoadNode>()) {
+              dtype = else_load->dtype;
+              vectorize_size_max =
+                  GetVectorizeSizeMax(compute_capability, dtype);
+            } else {
+              // fallback
+              vectorize_size_max = 1;
+              DLOG(WARNING) << "[AtomicAddVectorizePlanner] IfThenElse case "
                               "has no BufferLoad; Fallback to no vectorize";
+            }
+          } else {
+            // fallback
+            vectorize_size_max = 1;
+            DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg1 type "
+                          << call->args[1]->GetTypeKey()
+                          << "; Fallback to no vectorize";
+          }
+        }
+      }
+    }
+  });
+
+  if (vectorize_size_max <= 1) {
+    return {1, dynamic_, condition_};
+  }
+  this->max_vector_size = vectorize_size_max;
+  this->operator()(node);
+  return {vector_size_, dynamic_, condition_};
+}
+
+void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) {
+  inner_for_ = node;
+  arith::IRVisitorWithAnalyzer::VisitStmt_(node);
+}
+
+void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) {
+  if (node->op == builtin::call_extern() && node->args.size() >= 2) {
+    if (const auto *func_name = node->args[0].as<StringImmNode>()) {
+      if (func_name->value == "AtomicAdd") {
+        const BufferLoadNode *buffer_load_dst =
+            node->args[1].as<BufferLoadNode>();
+        const BufferLoadNode *buffer_load_src =
+            node->args[2].as<BufferLoadNode>();
+        if (buffer_load_src && buffer_load_src->buffer.defined() &&
+            buffer_load_dst && buffer_load_dst->buffer.defined()) {
+          Buffer dst_buffer = buffer_load_dst->buffer;
+          UpdateVectorSize(buffer_load_dst->indices, dst_buffer);
+          Buffer src_buffer = buffer_load_src->buffer;
+          UpdateVectorSize(buffer_load_src->indices, src_buffer);
+        }
+      }
+    }
+  }
+  return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
+}
+
+int AtomicAddVectorizePlanner::GetVectorizeSizeMax(int compute_capability,
+                                                   DataType dtype) {
+  if (dtype == DataType::Float(16)) {
+    return 2;
+  }
+  if (dtype == DataType::BFloat(16)) {
+    return compute_capability > 75 ? 2 : 1;
+  }
+  if (dtype == DataType::Float(32)) {
+    return compute_capability >= 90 ? 4 : 1;
+  }
+  return 1;
+}
+
+void AtomicAddVectorizePlanner::UpdateVectorSize(const Array<PrimExpr> &indices,
+                                                 const Buffer &buffer) {
+  if (!inner_for_)
+    return;
+  auto extent_ptr = inner_for_->extent.as<IntImmNode>();
+  if (!extent_ptr)
+    return;
+
+  const DataType &access_type = buffer->dtype;
+  max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
+  auto last_dim = buffer->shape.back();
+  auto mod_set = analyzer_.modular_set(last_dim);
+
+  if (buffer->shape.back().as<IntImmNode>()) {
+    max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
+    auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
+    if (gcd_base < Downcast<IntImm>(last_dim)->value) {
+      max_vector_size = gcd_base;
+    }
+    vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
+
+    PrimExpr elem_offset = 0;
+    PrimExpr stride = 1;
+    for (int i = indices.size() - 1; i >= 0; --i) {
+      elem_offset = elem_offset + indices[i] * stride;
+      stride = stride * buffer->shape[i];
+    }
+
+    while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
+                               inner_for_->extent, vector_size_, &analyzer_)) {
+      vector_size_ /= 2;
+    }
+  } else if (vector_size_ <= 4) {
+    dynamic_ = true;
+    PrimExpr offset = buffer.OffsetOf(indices).back();
+    condition_ = (truncmod(offset, vector_size_) == 0);
+  }
+}
 
 class AtomicAddVectorizeRewriter : public StmtExprMutator {
 public:
-  AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan,
-                             Var thread_var, PrimExpr by_var, PrimExpr bx_var,
-                             const Range &thread_bounds, int stride_y,
-                             int stride_x)
-      : vector_size_(plan.vector_size), condition_(plan.condition),
-        dynamic_(plan.dynamic), tx_var_(std::move(thread_var)),
-        by_var_(std::move(by_var)), bx_var_(std::move(bx_var)),
-        stride_y_(stride_y), stride_x_(stride_x) {
-    const int64_t *tx_ext = as_const_int(thread_bounds->extent);
-    ICHECK(tx_ext)
-        << "thread_bounds->extent must be a constant for vectorization.";
-    extent_tx_ = static_cast<int>(*tx_ext);
-  }
+  AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan)
+      : vector_size_(plan.vector_size), dynamic_(plan.dynamic),
+        condition_(plan.condition) {}
 
 private:
   /**
@@ -179,10 +187,11 @@ private:
   */
   Stmt VisitStmt_(const ForNode *node) final {
     inner_for_ = node;
-    iter_var_ = Var(node->loop_var->name_hint + "_outer");
     auto ret = StmtExprMutator::VisitStmt_(node);
-    if (inner_for_ == node) { // rewrite the innermost loop
+    if (inner_for_ == node) {
       For fnode = ret.as<For>().value();
+      auto old_var = fnode->loop_var;
+      auto new_var = Var(old_var->name_hint);
       auto extent_ptr = as_const_int(fnode->extent);
       ICHECK(extent_ptr) << fnode->extent;
       int extent = *extent_ptr;
@@ -191,9 +200,9 @@ private:
       ICHECK(is_zero(fnode->min));
       if (!dynamic_) {
         Map<Var, PrimExpr> vmap;
-        vmap.Set(fnode->loop_var, iter_var_);
+        vmap.Set(old_var, new_var * vector_size_);
         Stmt body = Substitute(fnode->body, vmap);
-        return For(iter_var_, 0, extent / vector_size_, fnode->kind, body,
+        return For(new_var, 0, extent / vector_size_, fnode->kind, body,
                    fnode->thread_binding, fnode->annotations, fnode->span);
       }
     }
@@ -208,57 +217,18 @@ private:
     if (node->op == builtin::call_extern() && node->args.size() >= 2) {
       if (const auto *func_name = node->args[0].as<StringImmNode>()) {
         if (func_name->value == "AtomicAdd") {
-          // Matrix[by * stride_y + i / (stride_x / (tx_txtent *
-          // vector_size_)) + tx_var_ / (stride_x / vector_size_),
-          // bx * stride_x + (i % (stride_x / (tx_extent *
-          // vector_size_)) * (tx_extent * vector_size_) + (tx_var_ %
-          // (stride / vector_size_)) * vector_size_]
-          const BufferLoadNode *old_dst_node =
+          const BufferLoadNode *temp_dst_node =
               node->args[1].as<BufferLoadNode>();
-          const BufferLoadNode *old_value_node =
+          const BufferLoadNode *temp_value_node =
              node->args[2].as<BufferLoadNode>();
-          if (!old_dst_node || !old_value_node) {
+          if (!temp_dst_node || !temp_value_node) {
            return StmtExprMutator::VisitExpr_(node);
          }
-          Array<PrimExpr> dst_indices, value_indices;
-          if ((extent_tx_ * vector_size_) > stride_x_) {
-            dst_indices.push_back(
-                by_var_ * stride_y_ +
-                iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
-                truncdiv(tx_var_, stride_x_ / vector_size_));
-            dst_indices.push_back(
-                bx_var_ * stride_x_ +
-                truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
-            value_indices.push_back(
-                iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
-                truncdiv(tx_var_ * vector_size_, stride_x_));
-            value_indices.push_back(
-                truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
-          } else {
-            dst_indices.push_back(
-                by_var_ * stride_y_ +
-                truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
-                truncdiv(tx_var_, stride_x_ / vector_size_));
-            dst_indices.push_back(
-                bx_var_ * stride_x_ +
-                truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
-                    (extent_tx_ * vector_size_) +
-                truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
-            value_indices.push_back(
-                truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
-                truncdiv(tx_var_, stride_x_ / vector_size_));
-            value_indices.push_back(
-                truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
-                    (extent_tx_ * vector_size_) +
-                truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
-          }
-          BufferLoad dst_node =
-              BufferLoad(old_dst_node->buffer, dst_indices,
-                         old_dst_node->predicate, old_dst_node->span);
-          BufferLoad value_node =
-              BufferLoad(old_value_node->buffer, value_indices,
-                         old_value_node->predicate, old_value_node->span);
+          const BufferLoad dst_node =
+              Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
+          const BufferLoad value_node =
+              Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
           Call address_of_dst =
               Call(DataType::Handle(), builtin::address_of(), {dst_node});
           Call address_of_value =
@@ -287,89 +257,17 @@ private:
   const int vector_size_;
   const PrimExpr condition_;
   const bool dynamic_;
-  const PrimExpr by_var_, bx_var_;
-  int stride_y_, stride_x_;
-  const Var tx_var_;
-  Var iter_var_;
-  int extent_tx_;
 };
 
-static int GetVectorizeSizeMax(int compute_capability, DataType dtype) {
-  if (dtype == DataType::Float(16)) {
-    return 2;
-  }
-  if (dtype == DataType::BFloat(16)) {
-    if (compute_capability > 75) {
-      return 2;
-    } else {
-      return 1;
-    }
-  }
-  if (dtype == DataType::Float(32)) {
-    if (compute_capability >= 90) {
-      return 4;
-    } else {
-      return 1;
-    }
-  }
-  return 1;
-}
-
-For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
-                       const Range &thread_bounds, int compute_capability) {
-  int vectorize_size_max = 1;
-  int stride_x = -1, stride_y = -1;
-  PrimExpr bx_var, by_var;
-
-  PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
-    if (const auto *call = obj.as<CallNode>()) {
-      if (call->op == builtin::call_extern() && call->args.size() >= 2) {
-        const auto *func_name = call->args[0].as<StringImmNode>();
-        if (func_name->value == "AtomicAdd") {
-          DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
-          vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
-        }
-      }
-    }
-    if (const MulNode *mul = obj.as<MulNode>()) {
-      const VarNode *var = nullptr;
-      const IntImmNode *imm = nullptr;
-      PrimExpr var_expr;
-      if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
-        var_expr = mul->a;
-      } else if ((var = mul->b.as<VarNode>()) &&
-                 (imm = mul->a.as<IntImmNode>())) {
-        var_expr = mul->b;
-      }
-      if (var && imm) {
-        if (var->name_hint == "bx") {
-          stride_x = imm->value;
-          bx_var = var_expr;
-        } else if (var->name_hint == "by") {
-          stride_y = imm->value;
-          by_var = var_expr;
-        }
-      }
-    }
-  });
-
-  if (vectorize_size_max != 1) {
-    int vectorize_hint = vectorize_size_max;
-    AtomicAddVectorizePlanResult res = {1, false, 0};
-    AtomicAddVectorizePlanner planner;
-    res = planner.Plan(for_node, thread_var, thread_bounds, vectorize_hint);
-    vectorize_hint = res.vector_size;
-    if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
-        !bx_var.defined() || !by_var.defined())
-      return for_node;
-    auto rewriter = AtomicAddVectorizeRewriter(
-        res, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x);
-    return Downcast<For>(rewriter(for_node));
-  } else {
-    return for_node;
-  }
+For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
+  AtomicAddVectorizePlanResult res = {1, false, 0};
+  AtomicAddVectorizePlanner planner;
+  res = planner.Plan(for_node, compute_capability);
+  int vectorize_hint = res.vector_size;
+  if (vectorize_hint == 1)
+    return for_node;
+  auto rewriter = AtomicAddVectorizeRewriter(res);
+  return Downcast<For>(rewriter(for_node));
 }
 
 } // namespace tl
...
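
Aside (illustration, not part of the diff): AtomicAddVectorizePlanner::GetVectorizeSizeMax above caps the vector width per element type and compute capability. The sketch below mirrors that rule with hypothetical names (Elem and MaxAtomicVectorWidth are not TVM types).

#include <cassert>

// Illustrative mirror of the GetVectorizeSizeMax rule in the diff above.
enum class Elem { kFloat16, kBFloat16, kFloat32, kOther };

static int MaxAtomicVectorWidth(int compute_capability, Elem dtype) {
  switch (dtype) {
  case Elem::kFloat16:
    return 2;                                // fp16 pairs are always allowed
  case Elem::kBFloat16:
    return compute_capability > 75 ? 2 : 1;  // bf16 pairs only above sm_75
  case Elem::kFloat32:
    return compute_capability >= 90 ? 4 : 1; // fp32 x4 only from sm_90 on
  default:
    return 1;                                // everything else stays scalar
  }
}

int main() {
  assert(MaxAtomicVectorWidth(90, Elem::kFloat32) == 4);
  assert(MaxAtomicVectorWidth(89, Elem::kFloat32) == 1);
  assert(MaxAtomicVectorWidth(75, Elem::kBFloat16) == 1);
  assert(MaxAtomicVectorWidth(80, Elem::kBFloat16) == 2);
  return 0;
}
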
@@ -6,16 +6,53 @@
 #ifndef TVM_TL_ATOMICADD_VECTORIZE_H_
 #define TVM_TL_ATOMICADD_VECTORIZE_H_
 
+#include "../layout/layout.h"
+#include "../layout/utils.h"
+#include "arith/int_operator.h"
+#include "arith/ir_visitor_with_analyzer.h"
+#include "atomicadd_vectorize.h"
+#include "common/loop_vectorization_utils.h"
+#include <numeric>
 #include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/builtin.h>
 #include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <utility>
 
 namespace tvm {
 namespace tl {
 
 using namespace tir;
 
-For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
-                       const Range &thread_bounds, int compute_capability);
+For VectorizeAtomicAdd(const For &for_node, int compute_capability);
+
+struct AtomicAddVectorizePlanResult {
+  int vector_size;
+  bool dynamic;
+  PrimExpr condition;
+};
+
+class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer {
+public:
+  AtomicAddVectorizePlanner();
+  AtomicAddVectorizePlanResult Plan(const For &node, int compute_capability);
+
+private:
+  void VisitStmt_(const ForNode *node) final;
+  void VisitExpr_(const CallNode *node) final;
+  int GetVectorizeSizeMax(int compute_capability, DataType dtype);
+  void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer);
+
+  const ForNode *inner_for_ = nullptr;
+  bool has_nonlocal_memory_access_ = false;
+  int vector_size_ = 4;
+  int max_vector_size = 1;
+  bool dynamic_ = false;
+  PrimExpr condition_;
+};
 
 } // namespace tl
 } // namespace tvm
...
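
Aside (illustration, not part of the diff): with the simplified interface declared above, AtomicAddVectorizeRewriter only splits the innermost serial loop by the planned vector width, substituting the old loop index with new_var * vector_size_. A scalar model of that rewrite in plain C++ rather than TIR:

#include <cassert>
#include <vector>

// Scalar model (not TVM IR) of the rewriter's loop transform:
// `for i in [0, extent)` becomes `for i_new in [0, extent / vec)`, every use
// of i is replaced by i_new * vec, and each iteration then issues one
// vectorized atomic add covering `vec` consecutive elements.
int main() {
  const int extent = 16;
  const int vec = 4;  // plan.vector_size from the planner
  std::vector<int> touched(extent, 0);

  for (int i_new = 0; i_new < extent / vec; ++i_new) {
    const int base = i_new * vec;           // old_var -> new_var * vector_size_
    for (int lane = 0; lane < vec; ++lane)  // lanes of one vectorized atomic add
      touched[base + lane] += 1;
  }

  for (int i = 0; i < extent; ++i)
    assert(touched[i] == 1);  // every original element is covered exactly once
  return 0;
}
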