"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "cbc6273ae8de00a2c3e900987c7a42978d33ed11"
Unverified Commit 340bfc50 authored by Yuqi Dong's avatar Yuqi Dong Committed by GitHub
Browse files

[Bugfix] Fix atomicadd auto vectorize identify var error (#883)

* update

* update

* update

* update
parent 4a229ddb
@@ -13,6 +13,7 @@
#include "../target/utils.h"
#include "../transform/atomicadd_vectorize.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h"
#include "builtin.h"
@@ -21,31 +22,6 @@ namespace tl {
using namespace tir;
/**
* @brief Extracts a numeric architecture identifier from a Target's "arch"
* attribute.
*
* Reads the Target's "arch" string (must be defined) and, if it has the form
* "sm_<N>", parses and returns N as an integer. For any other arch string,
* returns 0.
*
* @param target Target whose "arch" attribute will be inspected (ICHECKs that
* the attribute is defined).
* @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
} else {
arch_int = 0;
}
return arch_int;
}
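// Illustrative sketch (not part of this commit): GetArchInt above, and the
// lambda that replaces it inside Lower(), only understand CUDA arch strings of
// the form "sm_<N>". The same parsing in a minimal standalone form, with
// hypothetical inputs:
//
//   static int ParseSmArch(const std::string &arch) {  // assumes <string>
//     return arch.rfind("sm_", 0) == 0 ? std::stoi(arch.substr(3)) : 0;
//   }
//   // ParseSmArch("sm_90") == 90, ParseSmArch("sm_75") == 75,
//   // ParseSmArch("gfx942") == 0, so any non-"sm_" target reports no compute
//   // capability and the capability-gated vectorization stays disabled.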
/**
* @brief Construct an AtomicAdd operator from call arguments and a buffer map.
*
@@ -328,6 +304,47 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
return Downcast<For>(body);
}
/**
* @brief Infer and return the layout map for the atomic add operator.
*
* Constructs a cached ParallelOp (by building the SIMT loop) if not already
* present, validates that local.fragment layouts for src and dst match when
* both are provided, and then delegates layout inference to the underlying
* ParallelOp.
*
* @param T Layout inference inputs, including an optional mapping of buffers to
* layouts.
* @param level Inference strictness level.
* @return LayoutMap The inferred layout mapping for buffers used by this
* operator.
*
* @note This method mutates the AtomicAddNode by creating and storing a
* ParallelOp on first invocation.
* @throws If both src and dst have layouts in `local.fragment` and their
* fragment layouts differ, an ICHECK failure is raised with diagnostic output.
*/
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (!par_op_.defined()) {
arith::Analyzer analyzer;
par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
}
if (T.layout_map.count(src) && T.layout_map.count(dst)) {
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
if (src_layout && dst_layout) {
ICHECK(src_layout->IsEqual(dst_layout, true))
<< "Get different layout for " << src << " and " << dst
<< "\nLHS = " << src_layout->DebugOutput()
<< "\nRHS = " << dst_layout->DebugOutput()
<< "\nYou may need to use a shared memory to transform the layout";
}
}
}
return par_op_->InferLayout(T, level);
}
/**
* @brief Lower the atomic-add top-level operator into a parallel, vectorized
* TIR loop.
@@ -389,70 +406,142 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto par_op = ParallelOp(fused_loop);
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
(par_op)->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
false, T.buffer_remap},
level);
}
auto loop_layout = par_op->GetLoopLayout();
Var thread_var = T.thread_var;
Range thread_bounds = T.thread_bounds;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
auto vectorized_thread_loop = VectorizeAtomicAdd(
thread_loop, thread_var, thread_bounds, GetArchInt(target));
auto transformed_loop =
Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
auto GetArchInt = [&](const Target &tgt) -> int {
int arch_int = 0;
if (auto s = tgt->GetAttr<String>("arch")) {
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0)
arch_int = std::stoi(arch.substr(3));
}
return arch_int;
};
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
}
  struct AtomicLoopNestCollector : tir::StmtExprVisitor {
    Array<IterVar> loop_vars;
    Map<Buffer, Array<PrimExpr>> indice_map;
    std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> writes;
    arith::Analyzer analyzer;

    void Run(const Stmt &s) { StmtExprVisitor::VisitStmt(s); }

    void VisitStmt_(const ForNode *op) final {
      if (op->kind == ForKind::kParallel) {
        loop_vars.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
                                    IterVarType::kDataPar));
      }
      analyzer.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
      StmtExprVisitor::VisitStmt_(op);
    }

    void VisitStmt_(const BufferStoreNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        indice_map.Set(op->buffer, op->indices);
        writes.insert(op->buffer);
      }
      StmtExprVisitor::VisitStmt_(op);
    }

    void VisitExpr_(const BufferLoadNode *op) final {
      if (op->buffer.scope() == "local.fragment") {
        indice_map.Set(op->buffer, op->indices);
      }
      StmtExprVisitor::VisitExpr_(op);
    }
  };
auto ComputeLoopLayoutFromBuffer =
[&](const Buffer &buf, const Array<PrimExpr> &indices,
const LayoutMap &layout_map, const Range &thread_bounds,
const Array<IterVar> &loop_vars) -> Fragment {
Fragment src = layout_map[buf].as<Fragment>().value();
Var rep;
auto rep_iter =
IterVar(Range(0, src->ReplicateExtent()), rep, IterVarType::kDataPar);
PrimExpr fth = src->ForwardThread(indices, rep);
fth = analyzer->Simplify(fth);
Fragment out = Fragment(loop_vars, /*forward_index=*/{}, fth, rep_iter)
->BindThreadRange(thread_bounds);
return out;
};
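  // Note (illustrative, not from this commit): ComputeLoopLayoutFromBuffer
  // reuses an existing fragment layout instead of planning a fresh one. If,
  // say, the destination is a local.fragment buffer whose layout maps element
  // (i, j) to a particular thread, ForwardThread(indices, rep) binds the loop
  // iteration touching (i, j) to that same thread, so the atomic adds are not
  // reshuffled across threads by layout inference.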
struct AtomicInferResult {
Fragment loop_layout;
Optional<PrimExpr> predicate;
};
auto AtomicAddInferLayout =
[&](const For &loop, const LayoutInferArgs &args) -> AtomicInferResult {
AtomicLoopNestCollector C;
C.Run(loop);
Optional<Buffer> read_src;
int best_rank = -1;
for (auto kv : C.indice_map) {
const Buffer &buf = kv.first;
if (buf.scope() != "local.fragment")
continue;
if (!args.layout_map.count(buf))
continue;
int rank = static_cast<int>(kv.second.size());
if (rank > best_rank) {
best_rank = rank;
read_src = buf;
}
}
AtomicAddVectorizePlanner planner;
int sm = GetArchInt(target);
auto plan = planner.Plan(loop, sm);
int vec = std::max(plan.vector_size, 1);
if (auto cw = loop->annotations.Get("coalesced_width")) {
if (const auto *imm = cw->as<IntImmNode>()) {
int expected = imm->value;
ICHECK_GT(expected, 0);
ICHECK(vec % expected == 0)
<< "vector_size " << vec << " not divisible by coalesced_width "
<< expected;
vec = expected;
} else {
LOG(FATAL) << "coalesced_width should be IntImmNode.";
}
}
PrimExpr total = 1;
for (Stmt s = loop; s.as<For>().has_value(); s = s.as<For>().value()->body)
total = total * s.as<For>().value()->extent;
PrimExpr denom = args.thread_bounds->extent * vec;
while (!analyzer->CanProve(floormod(total, denom) == 0) && vec > 1) {
vec >>= 1;
denom = args.thread_bounds->extent * vec;
}
if (vec < 1)
vec = 1;
Fragment loop_layout;
if (read_src) {
loop_layout = ComputeLoopLayoutFromBuffer(
read_src.value(), C.indice_map[read_src.value()], args.layout_map,
args.thread_bounds, C.loop_vars);
} else {
const For &remapped = loop;
loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
}
Optional<PrimExpr> pred;
if (plan.dynamic && plan.condition.defined()) {
pred = plan.condition;
}
DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec
<< " loop_layout=" << loop_layout->DebugOutput();
return {loop_layout, pred};
};
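  // Worked example of the divisibility loop above (illustrative numbers, not
  // from this commit): with 72 * 72 = 5184 loop elements, a thread extent of
  // 128 and an initial vec = 4, denom = 512 and 5184 % 512 = 64, so vec is
  // halved to 2 (denom 256, remainder 64 again) and finally to 1; with
  // 64 * 64 = 4096 elements the first check already passes and vec stays 4.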
auto ret = AtomicAddInferLayout(transformed_loop,
{T.target, T.thread_bounds, T.layout_map,
analyzer, false, T.buffer_remap});
Fragment loop_layout = ret.loop_layout;
auto thread_loop =
PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
auto vectorized_thread_loop =
VectorizeAtomicAdd(thread_loop, GetArchInt(target));
return vectorized_thread_loop;
}
TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
@@ -3,18 +3,7 @@
* \brief A tool to automatically vectorize atomic add
*/
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"
#include <numeric>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <utility>
#include "atomicadd_vectorize.h"
namespace tvm {
namespace tl {
@@ -23,132 +12,151 @@ using namespace tir;
using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;
struct AtomicAddVectorizePlanResult {
int vector_size;
bool dynamic;
PrimExpr condition;
};
AtomicAddVectorizePlanner::AtomicAddVectorizePlanner() = default;
AtomicAddVectorizePlanResult
AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
  int vectorize_size_max = 1;
  this->vector_size_ = 4;
  this->dynamic_ = false;
  this->condition_ = PrimExpr();

  PostOrderVisit(node, [&](const ObjectRef &obj) {
    if (const auto *call = obj.as<CallNode>()) {
      if (call->op == builtin::call_extern() && call->args.size() >= 2) {
        const auto *func_name = call->args[0].as<StringImmNode>();
        if (!func_name)
          return;
        if (func_name->value == "AtomicAdd") {
          DataType dtype;
          if (const auto *load = call->args[1].as<BufferLoadNode>()) {
            dtype = load->dtype;
            vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
          } else if (const auto *ite = call->args[1].as<IfThenElseNode>()) {
            if (const auto *then_load = ite->then_case.as<BufferLoadNode>()) {
              dtype = then_load->dtype;
              vectorize_size_max =
                  GetVectorizeSizeMax(compute_capability, dtype);
            } else if (const auto *else_load =
                           ite->else_case.as<BufferLoadNode>()) {
              dtype = else_load->dtype;
              vectorize_size_max =
                  GetVectorizeSizeMax(compute_capability, dtype);
            } else {
              // fallback
              vectorize_size_max = 1;
              DLOG(WARNING) << "[AtomicAddVectorizePlanner] IfThenElse case "
                               "has no BufferLoad; Fallback to no vectorize";
            }
          } else {
            // fallback
            vectorize_size_max = 1;
            DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg1 type "
                          << call->args[1]->GetTypeKey()
                          << "; Fallback to no vectorize";
          }
        }
      }
    }
  });

  if (vectorize_size_max <= 1) {
    return {1, dynamic_, condition_};
  }
  this->max_vector_size = vectorize_size_max;
  this->operator()(node);
  return {vector_size_, dynamic_, condition_};
}

void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) {
  inner_for_ = node;
  arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}

void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) {
  if (node->op == builtin::call_extern() && node->args.size() >= 2) {
    if (const auto *func_name = node->args[0].as<StringImmNode>()) {
      if (func_name->value == "AtomicAdd") {
        const BufferLoadNode *buffer_load_dst =
            node->args[1].as<BufferLoadNode>();
        const BufferLoadNode *buffer_load_src =
            node->args[2].as<BufferLoadNode>();
        if (buffer_load_src && buffer_load_src->buffer.defined() &&
            buffer_load_dst && buffer_load_dst->buffer.defined()) {
          Buffer dst_buffer = buffer_load_dst->buffer;
          UpdateVectorSize(buffer_load_dst->indices, dst_buffer);
          Buffer src_buffer = buffer_load_src->buffer;
          UpdateVectorSize(buffer_load_src->indices, src_buffer);
        }
      }
    }
  }
  return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
int AtomicAddVectorizePlanner::GetVectorizeSizeMax(int compute_capability,
                                                   DataType dtype) {
  if (dtype == DataType::Float(16)) {
    return 2;
  }
  if (dtype == DataType::BFloat(16)) {
    return compute_capability > 75 ? 2 : 1;
  }
  if (dtype == DataType::Float(32)) {
    return compute_capability >= 90 ? 4 : 1;
  }
  return 1;
}

void AtomicAddVectorizePlanner::UpdateVectorSize(const Array<PrimExpr> &indices,
                                                 const Buffer &buffer) {
  if (!inner_for_)
    return;
  auto extent_ptr = inner_for_->extent.as<IntImmNode>();
  if (!extent_ptr)
    return;

  const DataType &access_type = buffer->dtype;
  // i // 2, i % 8 can also be vectorized as factor 16,
  // so we should disable this GCD optimization.
  max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
  auto last_dim = buffer->shape.back();
  auto mod_set = analyzer_.modular_set(last_dim);
  // When the shape is dynamic, e.g. [m, k] (coeff=1, base=0), the GCD would
  // block vectorization, so fall through to the conditional tail path below.
  if (buffer->shape.back().as<IntImmNode>()) {
    max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
    auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
    // If gcd_base is equal to the last dimension, we should analyze the
    // second-to-last dimension in relation to the last dimension.
    if (gcd_base < Downcast<IntImm>(last_dim)->value) {
      max_vector_size = gcd_base;
    }
    vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

    PrimExpr elem_offset = 0;
    PrimExpr stride = 1;
    for (int i = indices.size() - 1; i >= 0; --i) {
      elem_offset = elem_offset + indices[i] * stride;
      stride = stride * buffer->shape[i];
    }
    while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                               inner_for_->extent, vector_size_, &analyzer_)) {
      vector_size_ /= 2;
    }
  } else if (vector_size_ <= 4) {
    // Dynamic-shape load: record the runtime vectorization condition.
    dynamic_ = true;
    PrimExpr offset = buffer.OffsetOf(indices).back();
    condition_ = (truncmod(offset, vector_size_) == 0);
  }
}
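// Worked example (illustrative, not from this commit): for a float16 buffer on
// sm_90, Plan() starts from GetVectorizeSizeMax(90, float16) = 2. With an
// innermost loop extent of 8, ZeroAwareGCD(2, 8) keeps 2; a static last
// dimension of 130 gives gcd(2, 130) = 2, so vector_size_ stays 2, while a
// last dimension of 129 gives gcd(2, 129) = 1 and the loop is left scalar.
// For dynamic shapes the same requirement is deferred to the runtime
// predicate truncmod(offset, vector_size_) == 0 recorded in condition_.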
class AtomicAddVectorizeRewriter : public StmtExprMutator {
public:
  AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan)
      : vector_size_(plan.vector_size), dynamic_(plan.dynamic),
        condition_(plan.condition) {}
private:
/**
@@ -179,10 +187,11 @@ private:
*/
  Stmt VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
    if (inner_for_ == node) { // rewrite the innermost loop
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto new_var = Var(old_var->name_hint);
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
@@ -191,9 +200,9 @@ private:
      ICHECK(is_zero(fnode->min));
      if (!dynamic_) {
        Map<Var, PrimExpr> vmap;
        vmap.Set(old_var, new_var * vector_size_);
        Stmt body = Substitute(fnode->body, vmap);
        return For(new_var, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->span);
      }
    }
@@ -208,57 +217,18 @@ private:
    if (node->op == builtin::call_extern() && node->args.size() >= 2) {
      if (const auto *func_name = node->args[0].as<StringImmNode>()) {
        if (func_name->value == "AtomicAdd") {
          const BufferLoadNode *temp_dst_node =
              node->args[1].as<BufferLoadNode>();
          const BufferLoadNode *temp_value_node =
              node->args[2].as<BufferLoadNode>();
          if (!temp_dst_node || !temp_value_node) {
            return StmtExprMutator::VisitExpr_(node);
          }
          const BufferLoad dst_node =
              Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
          const BufferLoad value_node =
              Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
          Call address_of_dst =
              Call(DataType::Handle(), builtin::address_of(), {dst_node});
          Call address_of_value =
@@ -287,89 +257,17 @@ private:
const int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
};
static int GetVectorizeSizeMax(int compute_capability, DataType dtype) {
if (dtype == DataType::Float(16)) {
return 2;
}
if (dtype == DataType::BFloat(16)) {
if (compute_capability > 75) {
return 2;
} else {
return 1;
}
}
if (dtype == DataType::Float(32)) {
if (compute_capability >= 90) {
return 4;
} else {
return 1;
}
}
return 1;
}
For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
const Range &thread_bounds, int compute_capability) {
int vectorize_size_max = 1;
int stride_x = -1, stride_y = -1;
PrimExpr bx_var, by_var;
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
if (const auto *call = obj.as<CallNode>()) {
if (call->op == builtin::call_extern() && call->args.size() >= 2) {
const auto *func_name = call->args[0].as<StringImmNode>();
if (func_name->value == "AtomicAdd") {
DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
}
}
}
if (const MulNode *mul = obj.as<MulNode>()) {
const VarNode *var = nullptr;
const IntImmNode *imm = nullptr;
PrimExpr var_expr;
if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
var_expr = mul->a;
} else if ((var = mul->b.as<VarNode>()) &&
(imm = mul->a.as<IntImmNode>())) {
var_expr = mul->b;
}
if (var && imm) {
if (var->name_hint == "bx") {
stride_x = imm->value;
bx_var = var_expr;
} else if (var->name_hint == "by") {
stride_y = imm->value;
by_var = var_expr;
}
}
}
});
if (vectorize_size_max != 1) {
int vectorize_hint = vectorize_size_max;
AtomicAddVectorizePlanResult res = {1, false, 0};
AtomicAddVectorizePlanner planner;
res = planner.Plan(for_node, thread_var, thread_bounds, vectorize_hint);
vectorize_hint = res.vector_size;
if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
!bx_var.defined() || !by_var.defined())
return for_node;
auto rewriter = AtomicAddVectorizeRewriter(
res, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x);
return Downcast<For>(rewriter(for_node));
  } else {
    return for_node;
  }
}

For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
  AtomicAddVectorizePlanResult res = {1, false, 0};
  AtomicAddVectorizePlanner planner;
  res = planner.Plan(for_node, compute_capability);
  int vectorize_hint = res.vector_size;
  if (vectorize_hint == 1)
    return for_node;
  auto rewriter = AtomicAddVectorizeRewriter(res);
  return Downcast<For>(rewriter(for_node));
}
} // namespace tl
@@ -6,16 +6,53 @@
#ifndef TVM_TL_ATOMICADD_VECTORIZE_H_
#define TVM_TL_ATOMICADD_VECTORIZE_H_
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "atomicadd_vectorize.h"
#include "common/loop_vectorization_utils.h"
#include <numeric>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <utility>
namespace tvm {
namespace tl {
using namespace tir;
For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
const Range &thread_bounds, int compute_capability);
For VectorizeAtomicAdd(const For &for_node, int compute_capability);
struct AtomicAddVectorizePlanResult {
int vector_size;
bool dynamic;
PrimExpr condition;
};
class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer {
public:
AtomicAddVectorizePlanner();
AtomicAddVectorizePlanResult Plan(const For &node, int compute_capability);
private:
void VisitStmt_(const ForNode *node) final;
void VisitExpr_(const CallNode *node) final;
int GetVectorizeSizeMax(int compute_capability, DataType dtype);
void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer);
const ForNode *inner_for_ = nullptr;
bool has_nonlocal_memory_access_ = false;
int vector_size_ = 4;
int max_vector_size = 1;
bool dynamic_ = false;
PrimExpr condition_;
};
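// Usage sketch (illustrative, not part of the header): callers normally go
// through VectorizeAtomicAdd(), which wires the planner and the .cc-internal
// rewriter together roughly as follows:
//
//   AtomicAddVectorizePlanner planner;
//   AtomicAddVectorizePlanResult plan =
//       planner.Plan(loop, /*compute_capability=*/90);
//   For result = plan.vector_size > 1
//                    ? Downcast<For>(AtomicAddVectorizeRewriter(plan)(loop))
//                    : loop;
//
// where compute_capability is the integer suffix of the target's "sm_<N>"
// arch string.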
} // namespace tl
} // namespace tvm