Commit eba7dd5a authored by Yu Cheng, committed by GitHub
Browse files

[Feature] Add TMA Store Synchronization Support (#195)

- Introduce TMAStoreArrive and TMAStoreWait operations for CUDA TMA store synchronization
- Add new builtin operations in op/builtin.cc and op/builtin.h
- Implement TMAStoreSyncInjector to automatically inject TMA store synchronization calls
- Update CUDA codegen to support new TMA store synchronization intrinsics
- Add Python language bindings for new TMA store synchronization operations
parent 94c758ad
......@@ -89,6 +89,15 @@ TIR_DEFINE_TL_BUILTIN(FenceProxyAsyncOp)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// Define the TL builtin `TMAStoreArrive`: it is registered with zero inputs
// and its call effect kind is set to `CallEffectKind::kOpaque`.
TIR_DEFINE_TL_BUILTIN(TMAStoreArrive)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// Define the TL builtin `TMAStoreWait`: it is registered with zero inputs
// and its call effect kind is set to `CallEffectKind::kOpaque`.
TIR_DEFINE_TL_BUILTIN(TMAStoreWait)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(SetMaxNReg)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
......
......@@ -136,6 +136,22 @@ const Op &SyncThreadsPartialOp();
*/
const Op &FenceProxyAsyncOp();
/*!
* \brief Indicate arrival of the warp issuing TMA_STORE.
*
* The intrinsic takes no arguments:
*
* TMAStoreArrive()
*
* The CUDA code generator emits it as a call to `tl::tma_store_arrive`.
*
* \return Reference to the Op for this builtin.
*/
const Op &TMAStoreArrive();
/*!
* \brief Wait for TMA_STORE to finish.
*
* The intrinsic takes no arguments:
*
* TMAStoreWait()
*
* The CUDA code generator emits it as a call to `tl::tma_store_wait<0>`.
*
* \return Reference to the Op for this builtin.
*/
const Op &TMAStoreWait();
/*!
* \brief Set reg hint for warp-specialized branches
*
......
......@@ -843,6 +843,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
print_extern_call_stmt("tl::fence_proxy_async");
} else if (op->op.same_as(tl::TMAStoreArrive())) {
print_extern_call_stmt("tl::tma_store_arrive");
} else if (op->op.same_as(tl::TMAStoreWait())) {
print_extern_call_stmt("tl::tma_store_wait<0>");
} else if (op->op.same_as(tl::SetMaxNReg())) {
this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
......
......@@ -227,6 +227,15 @@ TL_DEVICE void fence_proxy_async() {
asm volatile("fence.proxy.async.shared::cta;" : :);
}
// Indicate arrival of the warp issuing TMA_STORE.
// Emits the PTX instruction `cp.async.bulk.commit_group`. The asm statement
// takes no operands and declares no clobbers.
TL_DEVICE void tma_store_arrive() {
asm volatile("cp.async.bulk.commit_group;");
}
// Emits the PTX instruction `cp.async.bulk.wait_group.read` with `Count`
// passed as an immediate ("n" constraint) operand; the asm statement declares
// a "memory" clobber.
// NOTE(review): `Count` is presumably the number of most recently committed
// bulk groups that may still be pending when the wait returns -- confirm
// against the PTX ISA documentation.
template <int Count> TL_DEVICE void tma_store_wait() {
asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(Count) : "memory");
}
TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint64_t state = 0;
......
......@@ -112,6 +112,31 @@ private:
std::unordered_map<const StmtNode *, Proxy> map_;
};
/*!
 * \brief Mutator that appends TMA store synchronization after each TMA store.
 *
 * Every `Evaluate` statement whose value is a call to `TMAStoreOp()` is
 * replaced by a three-statement sequence: the original statement, a call to
 * `TMAStoreArrive()`, and a call to `TMAStoreWait()`. Any other `Evaluate`
 * statement is handed to the default `StmtExprMutator` visitor.
 */
class TMAStoreSyncInjector : public StmtExprMutator {
public:
  /*!
   * \brief Rewrite the body of \p f with this mutator.
   * \param f The function whose body is rewritten.
   * \return The function with its body replaced by the rewritten statement.
   */
  static PrimFunc Substitute(PrimFunc f) {
    TMAStoreSyncInjector injector{};
    Stmt rewritten = injector(f->body);
    f.CopyOnWrite()->body = std::move(rewritten);
    return f;
  }

private:
  Stmt VisitStmt_(const EvaluateNode *op) final {
    const auto *call = op->value.as<CallNode>();
    const bool is_tma_store =
        call != nullptr && call->op.same_as(TMAStoreOp());
    if (!is_tma_store) {
      // Not a TMA store: fall back to the default traversal.
      return StmtExprMutator::VisitStmt_(op);
    }

    // Keep the store itself untouched and follow it with arrive + wait.
    Array<Stmt> seq;
    seq.push_back(GetRef<Evaluate>(op));
    seq.push_back(Evaluate(Call(DataType::Handle(), TMAStoreArrive(), {})));
    seq.push_back(Evaluate(Call(DataType::Handle(), TMAStoreWait(), {})));
    return SeqStmt(std::move(seq));
  }
};
class InjectFenceProxy : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc f) {
......@@ -161,6 +186,7 @@ using namespace tir::transform;
tvm::transform::Pass InjectFenceProxy() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
f = TMAStoreSyncInjector::Substitute(f);
return InjectFenceProxy::Substitute(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", {});
......
......@@ -57,5 +57,4 @@ def test_lower_hopper_intrin_barrier():
if __name__ == "__main__":
    # Hand the run over to the shared test entry point only; the previous
    # direct call to test_lower_hopper_intrin_barrier() and the stale
    # commented-out line are dropped.
    # NOTE(review): tilelang.testing.main() presumably collects and runs the
    # tests in this file, which makes the direct call redundant -- confirm.
    tilelang.testing.main()
\ No newline at end of file
......@@ -25,6 +25,14 @@ def FenceProxyAsyncOp(*args):
return tir.call_intrin("handle", tir.op.Op.get("tl.FenceProxyAsyncOp"), *args)
def TMAStoreArrive(*args):
    """Build a call to the ``tl.TMAStoreArrive`` intrinsic.

    All positional arguments are forwarded unchanged as the call's operands.

    Returns:
        The result of ``tir.call_intrin`` with dtype ``"handle"``.
    """
    intrin = tir.op.Op.get("tl.TMAStoreArrive")
    return tir.call_intrin("handle", intrin, *args)
def TMAStoreWait(*args):
    """Build a call to the ``tl.TMAStoreWait`` intrinsic.

    All positional arguments are forwarded unchanged as the call's operands.

    Returns:
        The result of ``tir.call_intrin`` with dtype ``"handle"``.
    """
    intrin = tir.op.Op.get("tl.TMAStoreWait")
    return tir.call_intrin("handle", intrin, *args)
def SetMaxNReg(*args):
    """Build a call to the ``tl.SetMaxNReg`` intrinsic.

    All positional arguments are forwarded unchanged as the call's operands.

    Returns:
        The result of ``tir.call_intrin`` with dtype ``"handle"``.
    """
    intrin = tir.op.Op.get("tl.SetMaxNReg")
    return tir.call_intrin("handle", intrin, *args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment