"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "3d757d5025fb1cd419f5ac8cacf92f05bdac945e"
Unverified commit 05507037, authored by Yu Cheng, committed by GitHub

[Feature][Example] Support TMA reduce operation and update GQA bwd example (#969)



* [Feature][Example] Support TMA reduce operation and update GQA bwd example

* move GQA bwd with TMA reduce to new example

* [Lint]: [pre-commit.ci] auto fixes [...]

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 77b9d08e
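For context, a minimal usage sketch of the new flag, assuming the existing tilelang frontend entry points (`T.prim_func`, `T.Kernel`, `T.alloc_shared`, `T.Tensor`); the kernel shape and buffer names are hypothetical, and only the `use_tma=True` keyword comes from this change:

```python
import tilelang.language as T

# Hypothetical kernel fragment: accumulate a shared-memory tile into global
# memory with a single TMA bulk reduce instead of per-element atomicAdd.
# Everything except use_tma=True is assumed from the existing tilelang API.
@T.prim_func
def accumulate(dQ: T.Tensor((128, 64), "float32")):
    with T.Kernel(1, threads=128) as bx:
        dq_shared = T.alloc_shared((128, 64), "float32")
        # ... compute the partial dQ tile into dq_shared ...
        T.atomic_add(dQ, dq_shared, use_tma=True)
```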
@@ -80,7 +80,10 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
node->coalesced_width = Downcast<IntImm>(args[2]);
node->use_tma = Downcast<IntImm>(args[2]);
}
if (args.size() >= 4) {
node->coalesced_width = Downcast<IntImm>(args[3]);
}
data_ = std::move(node);
}
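The slot that previously held `coalesced_width` now carries `use_tma`, and `coalesced_width` moves down one position. A minimal Python sketch of the new positional layout (a stand-in for the real TIR argument parsing, not the actual constructor):

```python
def parse_atomicadd_args(args):
    """Stand-in for AtomicAdd's argument parsing after this change."""
    parsed = {"src": args[0], "dst": args[1]}     # the two tile regions
    if len(args) >= 3:
        parsed["use_tma"] = args[2]               # new flag (was coalesced_width)
    if len(args) >= 4:
        parsed["coalesced_width"] = args[3]       # shifted down by one slot
    return parsed

# e.g. the frontend call T.call_intrin(..., value, dst, use_tma) maps to:
print(parse_atomicadd_args(["value_region", "dst_region", 1]))
```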
@@ -169,6 +172,18 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
return indices;
}
std::pair<Array<PrimExpr>, PrimExpr>
AtomicAddNode::ReturnIndicesAndSize(int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
PrimExpr size = 1;
for (size_t i = 0; i < ranges.size(); i++) {
indices.push_back(ranges[i]->min);
size *= ranges[i]->extent;
}
return {indices, size};
}
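`ReturnIndicesAndSize` simply collects the region's minimum corner and its flattened element count; an equivalent Python sketch, using plain `(min, extent)` tuples in place of TIR `Range` objects:

```python
from math import prod

def indices_and_size(ranges):
    """ranges: list of (min, extent) pairs for either src_range or dst_range."""
    indices = [lo for lo, _ in ranges]            # the region's minimum corner
    size = prod(extent for _, extent in ranges)   # flattened element count
    return indices, size

# A [64, 128] region starting at (0, 0) -> ([0, 0], 8192)
print(indices_and_size([(0, 64), (0, 128)]))
```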
/**
* @brief Build a combined bound-check predicate for indexed access.
*
@@ -350,6 +365,28 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
*/
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
if (use_tma->value != 0) {
Array<PrimExpr> src_indices, dst_indices;
PrimExpr src_size, dst_size;
std::tie(src_indices, src_size) = ReturnIndicesAndSize(0);
std::tie(dst_indices, dst_size) = ReturnIndicesAndSize(1);
ICHECK(analyzer->CanProveEqual(src_size, dst_size))
<< "src_size = " << src_size << ", dst_size = " << dst_size;
BufferLoad src_node = BufferLoad(src, src_indices);
BufferLoad dst_node = BufferLoad(dst, dst_indices);
Call address_of_src =
Call(DataType::Handle(), builtin::address_of(), {src_node});
Call address_of_dst =
Call(DataType::Handle(), builtin::address_of(), {dst_node});
int need_reduce = 1;
int eviction_policy = 0;
auto body = Evaluate(Call(DataType::Handle(), tma_store(),
{address_of_src, address_of_dst,
ceildiv(src_size * src->dtype.bits(), 8),
need_reduce, eviction_policy}));
return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body);
}
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto par_op = ParallelOp(fused_loop);
......
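The byte count passed to `tma_store` on this path is `ceildiv(src_size * dtype_bits, 8)`, and the whole bulk reduce is issued by a single thread (the `thread_var == thread_bounds->min` guard) rather than by a SIMT loop. A quick worked example of the size computation, assuming a float32 region:

```python
def tma_reduce_bytes(extents, dtype_bits=32):
    """Bytes handed to tl.tma_store on the TMA-reduce path: ceildiv(size * bits, 8)."""
    size = 1
    for extent in extents:
        size *= extent
    return -(-(size * dtype_bits) // 8)

# A 128x64 float32 accumulation region -> 32768 bytes in a single bulk reduce.
print(tma_reduce_bytes([128, 64]))
```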
@@ -20,6 +20,7 @@ public:
Buffer src, dst; ///< Source and destination buffers
Array<Range> src_range,
dst_range; ///< Access ranges for source and destination
IntImm use_tma; ///< Whether to use TMA for memory operations
IntImm coalesced_width; ///< Width for memory coalescing optimization
mutable ParallelOp par_op_; ///< Associated parallel operation
@@ -39,6 +40,7 @@ public:
.def_ro("dst", &AtomicAddNode::dst)
.def_ro("src_range", &AtomicAddNode::src_range)
.def_ro("dst_range", &AtomicAddNode::dst_range)
.def_ro("use_tma", &AtomicAddNode::use_tma)
.def_ro("coalesced_width", &AtomicAddNode::coalesced_width);
}
@@ -46,6 +48,7 @@ public:
return equal(src, other->src) && equal(dst, other->dst) &&
equal(src_range, other->src_range) &&
equal(dst_range, other->dst_range) &&
equal(use_tma, other->use_tma) &&
equal(coalesced_width, other->coalesced_width);
}
@@ -54,6 +57,7 @@ public:
hash_reduce(dst);
hash_reduce(src_range);
hash_reduce(dst_range);
hash_reduce(use_tma);
hash_reduce(coalesced_width);
}
@@ -67,6 +71,8 @@ protected:
Array<IterVar> MakeIterVars() const;
/// Generate buffer indices from iteration variables
Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
/// Return buffer indices and size
std::pair<Array<PrimExpr>, PrimExpr> ReturnIndicesAndSize(int src_dst) const;
/// Create boundary predicate for memory safety
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
......
@@ -1571,6 +1571,9 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
for (auto coord : global_coords)
args.push_back(coord);
int need_reduce = 0;
if (!is_load)
args.push_back(need_reduce);
args.push_back(this->eviction_policy);
tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
Evaluate(Call(DataType::Handle(), op, args)));
@@ -1580,6 +1583,9 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
args.push_back(shared_addr);
for (auto coord : global_coords)
args.push_back(coord);
int need_reduce = 0;
if (!is_load)
args.push_back(need_reduce);
args.push_back(this->eviction_policy);
tma_copy = Evaluate(Call(DataType::Handle(), op, args));
}
@@ -1654,10 +1660,11 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
{shared_addr, global_addr, 0,
elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
} else {
int need_reduce = 0;
tma_copy = Evaluate(
Call(DataType::Handle(), tma_store(),
{global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
this->eviction_policy}));
need_reduce, this->eviction_policy}));
}
tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
return tma_copy;
......
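Both bulk-copy paths now thread a `need_reduce` flag through the `tl.tma_store` call, placed just before `eviction_policy` (plain copies pass 0; the TMA atomic-add path above passes 1). A hedged sketch of the resulting 1D argument order, with placeholder values standing in for the lowered expressions:

```python
# Sketch of the tl.tma_store argument list for the 1D bulk path after this change.
tma_store_args = [
    "global_addr",   # destination in global memory
    "shared_addr",   # source in shared memory
    "total_bytes",   # elements * dtype bytes
    0,               # need_reduce: 0 = plain bulk store, 1 = reduce-add
    0,               # eviction_policy index
]
```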
@@ -1345,6 +1345,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
print_extern_call_stmt(ss.str(), 0, 1);
} else if (op->op.same_as(tl::tma_store())) {
std::stringstream ss;
auto need_reduce = op->args[op->args.size() - 2].as<IntImmNode>()->value;
if (need_reduce) {
print_extern_call_stmt("tl::tma_store_add", 0, 2);
return;
}
auto eviction_policy =
this->eviction_policy_names_
[op->args[op->args.size() - 1].as<IntImmNode>()->value];
@@ -1353,7 +1358,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else {
ss << "tl::tma_store";
}
print_extern_call_stmt(ss.str(), 0, 1);
print_extern_call_stmt(ss.str(), 0, 2);
} else if (op->op.same_as(tl::ptx_ldmatrix())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
......
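In the CUDA codegen, the last two call arguments are now metadata: `args[-2]` carries `need_reduce` and `args[-1]` the eviction-policy index, and a set flag routes the call to the new `tl::tma_store_add` helper. A simplified Python analogue of that dispatch (the `emit` callback is hypothetical and stands in for `print_extern_call_stmt`):

```python
def emit_tma_store(args, emit):
    """Simplified analogue of the tl.tma_store branch in the CUDA codegen."""
    need_reduce = args[-2]
    if need_reduce:
        emit("tl::tma_store_add", args[:-2])   # reduce-add bulk store
    else:
        emit("tl::tma_store", args[:-2])       # plain bulk store; the real codegen
                                               # also folds the eviction policy into
                                               # the emitted name

emit_tma_store(["gmem", "smem", 4096, 1, 0], lambda name, a: print(name, a))
# -> tl::tma_store_add ['gmem', 'smem', 4096]
```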
@@ -252,6 +252,16 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
: "memory");
}
TL_DEVICE void tma_store_add(float *const smem_ptr, float *gmem_ptr,
int32_t const &store_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 "
"[%0], [%1], %2;\n"
:
: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
: "memory");
}
TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
......
@@ -116,7 +116,8 @@ def atomic_min(dst: Buffer,
def atomic_add(dst: Buffer,
value: PrimExpr,
memory_order: Optional[str] = None,
return_prev: bool = False) -> PrimExpr:
return_prev: bool = False,
use_tma: bool = False) -> PrimExpr:
"""
Atomically add `value` into `dst`, returning a handle to the operation.
@@ -225,7 +226,7 @@ def atomic_add(dst: Buffer,
raise NotImplementedError(
"return_prev is not supported for tile-region-based atomic operations")
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst)
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma)
def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr:
......