Unverified Commit b45e9c45 authored by yyttt6, committed by GitHub

[Feature]:Add auto vectorize for atomic add (#686)

* [Feature]:Add auto vectorize for atomic add

* fix

* fix2

* format
parent c5df7938
......@@ -37,14 +37,6 @@ def matmul(M,
T.copy(C_local, C_shared)
# TODO: Automatically add vectorized atomic with enhancement
# https://github.com/tile-ai/tilelang/issues/523
# if DataType(dtype).bits == 16:
# for i, j in T.Parallel(block_M, block_N // 2):
# m, n = by * block_M + i, bx * block_N + j * 2
# # vectorized atomic
# T.atomic_addx2(C[m, n], C_shared[i, j * 2])
for i, j in T.Parallel(block_M, block_N):
T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j])
......
import tilelang
import tilelang.language as T
@tilelang.jit
def matmul(M,
N,
K,
block_M,
block_N,
block_K,
split_k,
dtype="float16",
accum_dtype="float",
out_dtype="float32"):
splitK = K // split_k
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0):
T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.atomic_add(C[by * block_M, bx * block_N], C_shared)
return main
def main():
M = 1024
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32
split_k = 4
kernel = matmul(M, N, K, block_M, block_N, block_K, split_k)
import torch
torch.random.manual_seed(42)
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
c = torch.zeros(M, N).cuda().float()
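    # c must start at zero: each of the split_k blocks accumulates its partial
    # (block_M, block_N) tile into it with atomic adds.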
kernel(a, b, c)
ref_c = a @ b
torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
main()
import tilelang.testing
import example_tilelang_gemm_splitk
import example_tilelang_gemm_splitk_vectorize_atomicadd
def test_example_tilelang_gemm_splitk():
    example_tilelang_gemm_splitk.main()
def test_example_tilelang_gemm_splitk_vectorize_atomicadd():
    example_tilelang_gemm_splitk_vectorize_atomicadd.main()
if __name__ == "__main__":
......
/*!
* \file tl/op/atomic_add.cc
*
 * \brief Define the atomic add operator.
*/
#include "atomic_add.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
#include "../transform/atomicadd_vectorize.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/loop_partition.h"
#include "builtin.h"
namespace tvm {
namespace tl {
using namespace tir;
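// Parse the numeric SM architecture from the target's "arch" attribute
// (illustrative: "sm_90a" -> 90; non-"sm_" strings fall back to 0).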
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
const char *arch_str = s.value().c_str();
if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
arch_int = atoi(&arch_str[3]);
} else {
arch_int = 0;
}
return arch_int;
}
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto expr = args[i];
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region.GetRanges();
bf[i] = region.GetBuffer();
}
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
coalesced_width = Downcast<IntImm>(args[2]);
}
}
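// Build one data-parallel itervar per non-unit extent of src_range.
// Illustrative: for ranges {Range(m, 1), Range(0, block_N)} only the second
// range yields an itervar (over [0, block_N)); unit extents are skipped here
// and re-emitted as fixed offsets in MakeIndices below.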
Array<IterVar> AtomicAdd::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
if (is_one(src_range[i]->extent))
continue;
Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
idx++;
loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
}
return loop_vars;
}
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> AtomicAdd::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent))
indices.push_back(ranges[i]->min);
else {
indices.push_back(ranges[i]->min + ivs[idx]->var);
idx++;
}
}
ICHECK(idx == ivs.size())
<< "idx = " << idx << ", ivs.size() = " << ivs.size()
<< "src name = " << src->name << ", dst name = " << dst->name;
return indices;
}
PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent))
continue;
PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
cond = ranges[i]->min + ivs[idx]->var >= 0;
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
idx++;
}
if (cond_list.empty())
return {};
else {
PrimExpr cond = cond_list[0];
for (size_t i = 1; i < cond_list.size(); i++)
cond = And(cond, cond_list[i]);
return cond;
}
}
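// Sketch of the loop MakeSIMTLoop builds for a 2-D tile (casts and
// out-of-bounds predicates omitted):
//   for i in parallel(0, extent_i):
//     for j in parallel(0, extent_j):
//       call_extern("AtomicAdd", address_of(dst[di + i, dj + j]),
//                   src[si + i, sj + j])
// The extern "AtomicAdd" call is what atomicadd_vectorize.cc pattern-matches.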
For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0;
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial,
BufferStore(dst, BufferLoad(src, {0}), {0}));
}
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
ICHECK(loop_vars.size() <= dst_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
Array<PrimExpr> new_args;
new_args.push_back(StringImm("AtomicAdd"));
PrimExpr src_value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype)
src_value = Cast(dst->dtype, src_value);
if (src_predicate.defined())
src_value = if_then_else(src_predicate, src_value, make_zero(dst->dtype));
PrimExpr dst_value = BufferLoad(dst, dst_indices);
if (dst_predicate.defined())
dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype));
Call address_of_value =
tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value});
new_args.push_back(address_of_value);
new_args.push_back(src_value);
Call atomicadd_call =
tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args);
Stmt body = tvm::tir::Evaluate(atomicadd_call);
for (int i = loop_vars.size() - 1; i >= 0; i--) {
Map<String, ObjectRef> annotations = {};
if (coalesced_width.defined()) {
annotations.Set("coalesced_width", coalesced_width);
}
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body, std::nullopt, annotations);
}
return Downcast<For>(body);
}
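// Lowering pipeline: build the SIMT loop, fuse nested parallel loops, infer a
// loop layout, partition the loop across the thread block, and (on GPU
// targets) hand the per-thread loop to VectorizeAtomicAdd.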
Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
For vectorized_thread_loop;
auto par_op = std::make_unique<ParallelOp>(fused_loop);
if (!is_cpu_target) {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
}
auto loop_layout = par_op->GetLoopLayout();
Var thread_var = T.thread_var;
Range thread_bounds = T.thread_bounds;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
vectorized_thread_loop = VectorizeAtomicAdd(
thread_loop, thread_var, thread_bounds, GetArchInt(target));
}
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
}
return vectorized_thread_loop;
}
LayoutMap AtomicAdd::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (par_op_ == nullptr) {
arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
}
if (T.layout_map.count(src) && T.layout_map.count(dst)) {
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
if (src_layout && dst_layout) {
ICHECK(src_layout->IsEqual(dst_layout, true))
<< "Get different layout for " << src << " and " << dst
<< "\nLHS = " << src_layout->DebugOutput()
<< "\nRHS = " << dst_layout->DebugOutput()
<< "\nYou may need to use a shared memory to transform the layout";
}
}
}
return par_op_->InferLayout(T, level);
}
TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// TVM_REGISTER_OP("tl.atomicadd")
// .set_num_inputs(2)
// .add_argument("ref", "Buffer", "The destination buffer")
// .add_argument("val", "Expr", "The value to be added atomically");
} // namespace tl
} // namespace tvm
\ No newline at end of file
/*!
* \file tl/op/atomic_add.h
* \brief Define atomic add operator.
*
*/
#ifndef TVM_TL_OP_ATOMIC_ADD_H_
#define TVM_TL_OP_ATOMIC_ADD_H_
#include "op.h"
#include "parallel.h"
namespace tvm {
namespace tl {
using namespace tir;
class AtomicAdd : public Operator {
public:
AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
protected:
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
Array<IterVar> MakeIterVars() const;
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
Array<PrimExpr> args_;
Buffer src, dst;
Array<Range> src_range, dst_range;
IntImm coalesced_width;
std::unique_ptr<ParallelOp> par_op_;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_ATOMIC_ADD_H_
\ No newline at end of file
/*!
* \file atomicadd_vectorize.cc
 * \brief A tool to automatically vectorize atomic add operations
*/
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"
#include <numeric>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
namespace tvm {
namespace tl {
using namespace tir;
using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;
struct AtomicAddVectorizePlanResult {
int vector_size;
bool dynamic;
PrimExpr condition;
};
class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer {
public:
AtomicAddVectorizePlanner() = default;
int max_vector_size = 1;
AtomicAddVectorizePlanResult Plan(const For &node, Var thread_var,
Range thread_bounds, int vectorize_hint) {
this->max_vector_size = vectorize_hint;
this->thread_var = thread_var;
this->thread_bounds = thread_bounds;
this->operator()(node);
return {vector_size_, dynamic_, condition_};
}
private:
void VisitStmt_(const ForNode *node) final {
inner_for_ = node;
iter_map_.Set(node->loop_var, Range(node->min, node->extent));
arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitExpr_(const CallNode *node) final {
if (node->op == builtin::call_extern() && node->args.size() >= 2) {
if (const auto *func_name = node->args[0].as<StringImmNode>()) {
if (func_name->value == "AtomicAdd") {
const CallNode *addr_call = node->args[1].as<CallNode>();
if (addr_call && addr_call->op == builtin::address_of() &&
addr_call->args.size() == 1) {
const BufferLoadNode *buffer_load_dst =
addr_call->args[0].as<BufferLoadNode>();
const BufferLoadNode *buffer_load_src =
node->args[2].as<BufferLoadNode>();
if (buffer_load_src && buffer_load_src->buffer.defined() &&
buffer_load_dst && buffer_load_dst->buffer.defined()) {
Buffer dst_buffer = buffer_load_dst->buffer;
Array<PrimExpr> indices_dst = buffer_load_dst->indices;
UpdateVectorSize(indices_dst, dst_buffer);
Buffer src_buffer = buffer_load_src->buffer;
Array<PrimExpr> indices_src = buffer_load_src->indices;
UpdateVectorSize(indices_src, src_buffer);
}
}
}
}
}
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
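  // Shrink the planned vector size so it divides the innermost loop extent and
  // the buffer's statically known last dimension, then halve it until the
  // flattened element offset is provably vectorizable for the thread mapping;
  // for dynamic shapes, record a runtime divisibility condition instead.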
void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
if (!inner_for_)
return;
auto extent_ptr = inner_for_->extent.as<IntImmNode>();
if (!extent_ptr)
return;
const DataType &access_type = buffer->dtype;
// i // 2, i % 8 can also be vectorized as factor 16
// so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
auto last_dim = buffer->shape.back();
auto mod_set = analyzer_.modular_set(last_dim);
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
// conditionally tail vectorize
if (buffer->shape.back().as<IntImmNode>()) {
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
// If gcd_base is equal to the last dimension,
// we should analyze the second-to-last dimension
// in relation to the last dimension.
if (gcd_base < Downcast<IntImm>(last_dim)->value) {
max_vector_size = gcd_base;
}
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
PrimExpr elem_offset = 0;
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
elem_offset = elem_offset + indices[i] * stride;
stride = stride * buffer->shape[i];
}
PrimExpr thread_extent = thread_bounds->extent;
while (!IndiceCanVectorize(elem_offset, thread_var, thread_extent,
vector_size_, &analyzer_)) {
vector_size_ /= 2;
}
} else if (vector_size_ <= 4) {
// dynamic shape load: get the vectorization condition
dynamic_ = true;
PrimExpr offset = buffer.OffsetOf(indices).back();
condition_ = (FloorMod(offset, vector_size_) == 0);
}
}
const ForNode *inner_for_;
Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false;
int vector_size_ = 4;
Var thread_var;
Range thread_bounds;
bool dynamic_ = false;
PrimExpr condition_;
};
class AtomicAddVectorizeRewriter : public StmtExprMutator {
public:
AtomicAddVectorizeRewriter(AtomicAddVectorizePlanResult plan)
: vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic) {}
private:
Stmt VisitStmt_(const ForNode *node) final {
inner_for_ = node;
auto ret = StmtExprMutator::VisitStmt_(node);
if (inner_for_ == node) { // rewrite the innermost loop
For fnode = ret.as<For>().value();
auto old_var = fnode->loop_var;
auto extent_ptr = as_const_int(fnode->extent);
ICHECK(extent_ptr) << fnode->extent;
int extent = *extent_ptr;
ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
if (!dynamic_) {
Var tx_var;
PostOrderVisit(fnode->body, [&tx_var](const ObjectRef &node) {
if (const VarNode *var = node.as<VarNode>()) {
if (var->name_hint == "tx") {
tx_var = GetRef<Var>(var);
}
}
});
ICHECK(tx_var.defined()) << "Failed to find tx var";
Var outer_var = Var(old_var->name_hint + "_outer");
Map<Var, PrimExpr> vmap;
vmap.Set(tx_var,
truncmod(tx_var, extent / vector_size_) * vector_size_);
vmap.Set(fnode->loop_var, outer_var * vector_size_ +
truncdiv(tx_var, extent / vector_size_));
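        // Illustrative remap for extent = 8, vector_size_ = 2: the loop
        // shrinks to i_outer in [0, 4), tx is rewritten to (tx % 4) * 2 and
        // the old loop var to i_outer * 2 + tx / 4, which is intended to give
        // each thread vector_size_ consecutive elements that the call rewrite
        // below can fuse into a single AtomicAddx2/AtomicAddx4.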
Stmt body = Substitute(fnode->body, vmap);
return For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
} else {
return fnode;
}
} else {
return ret;
}
}
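  // Rewrite extern "AtomicAdd"(address_of(dst), value) into
  // "AtomicAddx2"/"AtomicAddx4"(address_of(dst), address_of(value)) when a
  // vector width of 2 or 4 was planned.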
PrimExpr VisitExpr_(const CallNode *node) final {
if (vector_size_ == 2 || vector_size_ == 4) {
if (node->op == builtin::call_extern() && node->args.size() >= 2) {
if (const auto *func_name = node->args[0].as<StringImmNode>()) {
if (func_name->value == "AtomicAdd") {
PrimExpr value_node = node->args[2];
Call address_of_value = tvm::tir::Call(
DataType::Handle(), builtin::address_of(), {value_node});
Array<PrimExpr> new_args;
if (vector_size_ == 2) {
new_args.push_back(StringImm("AtomicAddx2"));
} else {
new_args.push_back(StringImm("AtomicAddx4"));
}
new_args.push_back(node->args[1]);
new_args.push_back(address_of_value);
Call new_call =
tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
return new_call;
}
}
}
}
return StmtExprMutator::VisitExpr_(node);
}
const ForNode *inner_for_;
const int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
};
static int GetVectorizeSizeMax(int compute_capability, DataType dtype) {
if (dtype == DataType::Float(16)) {
return 2;
}
if (dtype == DataType::BFloat(16)) {
if (compute_capability > 75) {
return 2;
} else {
return 1;
}
}
if (dtype == DataType::Float(32)) {
if (compute_capability >= 90) {
return 4;
} else {
return 1;
}
}
return 1;
}
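// Pass entry point: find extern "AtomicAdd" calls in the loop body, bound the
// vector width by dtype and architecture (GetVectorizeSizeMax), let the
// planner shrink it to what the loop extent, buffer shape and thread mapping
// allow, and rewrite the loop and calls only if the planned width exceeds 1.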
For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds,
int compute_capability) {
int vectorize_size_max = 1;
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
if (const auto *call = obj.as<CallNode>()) {
if (call->op == builtin::call_extern() && call->args.size() >= 2) {
const auto *func_name = call->args[0].as<StringImmNode>();
        if (func_name && func_name->value == "AtomicAdd") {
DataType dtype =
call->args[1].as<CallNode>()->args[0].as<BufferLoadNode>()->dtype;
vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
}
}
}
});
if (vectorize_size_max != 1) {
int vectorize_hint = vectorize_size_max;
AtomicAddVectorizePlanResult res = {1, false, 0};
AtomicAddVectorizePlanner planner;
res = planner.Plan(for_node, thread_var, thread_bounds, vectorize_hint);
vectorize_hint = res.vector_size;
if (vectorize_hint == 1)
return for_node;
auto rewriter = AtomicAddVectorizeRewriter(res);
return Downcast<For>(rewriter(for_node));
} else {
return for_node;
}
}
} // namespace tl
} // namespace tvm
/*!
* \file atomicadd_vectorize.h
 * \brief A tool to automatically vectorize atomic add loops
*/
#ifndef TVM_TL_ATOMICADD_VECTORIZE_H_
#define TVM_TL_ATOMICADD_VECTORIZE_H_
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
namespace tvm {
namespace tl {
using namespace tir;
For VectorizeAtomicAdd(const For &for_node, Var thread_var, Range thread_bounds,
int compute_capability);
} // namespace tl
} // namespace tvm
#endif // TVM_TL_ATOMICADD_VECTORIZE_H_
\ No newline at end of file
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
"""The language interface for tl programs."""
import tilelang.language as T
from tvm import ir
from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, Var, op
from typing import List, Union
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
"""Create a memory region descriptor for tile operations.
Args:
buffer (tir.BufferLoad): The buffer to create a region for
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
*args (tir.PrimExpr): Extent expressions defining the region size
Returns:
tir.Call: A region descriptor for tile operations
"""
access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args)
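# Usage sketch (hypothetical 128x32 buffer A, full-tile read access):
#   region(T.BufferLoad(A, [0, 0]), "r", 128, 32)
# The access type is encoded as r=1, w=2, rw=3 before calling tl.region.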
def buffer_to_tile_region(buffer: Buffer, access_type: str):
"""Convert a TVM buffer to a tile region descriptor.
Args:
buffer (tir.Buffer): The buffer to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor covering the entire buffer
"""
mins = [0 for _ in buffer.shape]
extents = [x for x in buffer.shape]
return region(T.BufferLoad(buffer, mins), access_type, *extents)
def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List[PrimExpr]):
"""Convert a buffer load operation to a tile region descriptor.
Args:
load (tir.BufferLoad): The buffer load operation
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
extents (List[tir.PrimExpr]): List of expressions defining the region size
Returns:
tir.Call: A region descriptor for the loaded area
"""
indices = load.indices
if len(indices) > len(extents):
# (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
# f"region will be expanded in the last 2 dimensions")
new_extents = []
for _ in range(len(indices) - len(extents)):
new_extents.append(1)
for extent in extents:
new_extents.append(extent)
extents = new_extents
assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
return region(load, access_type, *extents)
def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
extents: List[PrimExpr]):
"""Convert a buffer region to a tile region descriptor.
Args:
buffer_region (tir.BufferRegion): The buffer region to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor for the specified buffer region
"""
mins = [x.min for x in buffer_region.region]
region_extents = [x.extent for x in buffer_region.region]
assert len(region_extents) >= len(
extents
), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
def atomic_add(dst: Buffer, value: PrimExpr) -> PrimExpr:
"""Perform an atomic addition operation.
......@@ -15,7 +93,41 @@ def atomic_add(dst: Buffer, value: PrimExpr) -> PrimExpr:
Returns:
PrimExpr: Handle to the atomic addition operation
"""
return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value)
if isinstance(dst, BufferLoad) and isinstance(value, BufferLoad):
return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value)
if isinstance(dst, Buffer) and isinstance(value, Buffer):
ir.assert_structural_equal(dst.shape, value.shape)
def get_extent(data):
if isinstance(data, Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, Buffer):
return data.shape
elif isinstance(data, BufferRegion):
return [x.extent for x in data.region]
else:
return None
src_extent = get_extent(value)
dst_extent = get_extent(dst)
assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
extent = max(src_extent, dst_extent)
def _to_region(data, access_type):
if isinstance(data, Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, Buffer):
return buffer_to_tile_region(data, access_type)
elif isinstance(data, BufferRegion):
return buffer_region_to_tile_region(data, access_type, extent)
else:
return buffer_load_to_tile_region(data, access_type, extent)
value = _to_region(value, "r")
dst = _to_region(dst, "w")
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst)
def atomic_addx2(dst: Buffer, value: PrimExpr) -> PrimExpr:
......@@ -32,14 +144,14 @@ def atomic_addx2(dst: Buffer, value: PrimExpr) -> PrimExpr:
def atomic_addx4(dst: Buffer, value: PrimExpr) -> PrimExpr:
"""Perform an atomic addition operation with double-width operands.
"""Perform an atomic addition operation with quad-width operands.
Args:
dst (Buffer): Destination buffer where the atomic addition will be performed
value (PrimExpr): Value to be atomically added (double-width)
value (PrimExpr): Value to be atomically added (quad-width)
Returns:
PrimExpr: Handle to the double-width atomic addition operation
PrimExpr: Handle to the quad-width atomic addition operation
"""
return T.call_extern("handle", "AtomicAddx4", T.address_of(dst), T.address_of(value))
......