"git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "bf8a6fc14245e193c262293179122a645764a486"
Commit eba7dd5a authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Feature] Add TMA Store Synchronization Support (#195)

- Introduce TMAStoreArrive and TMAStoreWait operations for CUDA TMA store synchronization
- Add new builtin operations in op/builtin.cc and op/builtin.h
- Implement TMAStoreSyncInjector to automatically inject TMA store synchronization calls
- Update CUDA codegen to support new TMA store synchronization intrinsics
- Add Python language bindings for new TMA store synchronization operations
parent 94c758ad
...@@ -89,6 +89,15 @@ TIR_DEFINE_TL_BUILTIN(FenceProxyAsyncOp) ...@@ -89,6 +89,15 @@ TIR_DEFINE_TL_BUILTIN(FenceProxyAsyncOp)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(TMAStoreArrive)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(TMAStoreWait)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(SetMaxNReg) TIR_DEFINE_TL_BUILTIN(SetMaxNReg)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
......
...@@ -136,6 +136,22 @@ const Op &SyncThreadsPartialOp(); ...@@ -136,6 +136,22 @@ const Op &SyncThreadsPartialOp();
*/ */
const Op &FenceProxyAsyncOp(); const Op &FenceProxyAsyncOp();
/*!
* \brief Indicate arrival of warp issuing TMA_STORE
*
* TMAStoreArrive()
*
*/
const Op &TMAStoreArrive();
/*!
* \brief Wait for TMA_STORE to finish
*
* TMAStoreWait()
*
*/
const Op &TMAStoreWait();
/*! /*!
* \brief Set reg hint for warp-specialized branched * \brief Set reg hint for warp-specialized branched
* *
......
...@@ -843,6 +843,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -843,6 +843,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
print_extern_call_stmt(func_name, 2); print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::FenceProxyAsyncOp())) { } else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
print_extern_call_stmt("tl::fence_proxy_async"); print_extern_call_stmt("tl::fence_proxy_async");
} else if (op->op.same_as(tl::TMAStoreArrive())) {
print_extern_call_stmt("tl::tma_store_arrive");
} else if (op->op.same_as(tl::TMAStoreWait())) {
print_extern_call_stmt("tl::tma_store_wait<0>");
} else if (op->op.same_as(tl::SetMaxNReg())) { } else if (op->op.same_as(tl::SetMaxNReg())) {
this->PrintIndent(); this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value; int nreg = Downcast<IntImm>(op->args[0])->value;
......
...@@ -227,6 +227,15 @@ TL_DEVICE void fence_proxy_async() { ...@@ -227,6 +227,15 @@ TL_DEVICE void fence_proxy_async() {
asm volatile("fence.proxy.async.shared::cta;" : :); asm volatile("fence.proxy.async.shared::cta;" : :);
} }
// Indicate arrival of warp issuing TMA_STORE
TL_DEVICE void tma_store_arrive() {
asm volatile("cp.async.bulk.commit_group;");
}
template <int Count> TL_DEVICE void tma_store_wait() {
asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(Count) : "memory");
}
TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) { TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier); uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint64_t state = 0; uint64_t state = 0;
......
...@@ -112,6 +112,31 @@ private: ...@@ -112,6 +112,31 @@ private:
std::unordered_map<const StmtNode *, Proxy> map_; std::unordered_map<const StmtNode *, Proxy> map_;
}; };
class TMAStoreSyncInjector : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc f) {
auto T = TMAStoreSyncInjector();
f.CopyOnWrite()->body = T(f->body);
return f;
}
private:
Stmt VisitStmt_(const EvaluateNode *op) final {
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(TMAStoreOp())) {
Array<Stmt> new_body;
new_body.push_back(GetRef<Evaluate>(op));
new_body.push_back(
Evaluate(Call(DataType::Handle(), TMAStoreArrive(), {})));
new_body.push_back(
Evaluate(Call(DataType::Handle(), TMAStoreWait(), {})));
return SeqStmt(std::move(new_body));
}
}
return StmtExprMutator::VisitStmt_(op);
}
};
class InjectFenceProxy : public StmtExprMutator { class InjectFenceProxy : public StmtExprMutator {
public: public:
static PrimFunc Substitute(PrimFunc f) { static PrimFunc Substitute(PrimFunc f) {
...@@ -161,6 +186,7 @@ using namespace tir::transform; ...@@ -161,6 +186,7 @@ using namespace tir::transform;
tvm::transform::Pass InjectFenceProxy() { tvm::transform::Pass InjectFenceProxy() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
f = TMAStoreSyncInjector::Substitute(f);
return InjectFenceProxy::Substitute(f); return InjectFenceProxy::Substitute(f);
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", {});
......
...@@ -57,5 +57,4 @@ def test_lower_hopper_intrin_barrier(): ...@@ -57,5 +57,4 @@ def test_lower_hopper_intrin_barrier():
if __name__ == "__main__": if __name__ == "__main__":
# tilelang.testing.main() tilelang.testing.main()
test_lower_hopper_intrin_barrier() \ No newline at end of file
...@@ -25,6 +25,14 @@ def FenceProxyAsyncOp(*args): ...@@ -25,6 +25,14 @@ def FenceProxyAsyncOp(*args):
return tir.call_intrin("handle", tir.op.Op.get("tl.FenceProxyAsyncOp"), *args) return tir.call_intrin("handle", tir.op.Op.get("tl.FenceProxyAsyncOp"), *args)
def TMAStoreArrive(*args):
return tir.call_intrin("handle", tir.op.Op.get("tl.TMAStoreArrive"), *args)
def TMAStoreWait(*args):
return tir.call_intrin("handle", tir.op.Op.get("tl.TMAStoreWait"), *args)
def SetMaxNReg(*args): def SetMaxNReg(*args):
return tir.call_intrin("handle", tir.op.Op.get("tl.SetMaxNReg"), *args) return tir.call_intrin("handle", tir.op.Op.get("tl.SetMaxNReg"), *args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment