"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3fab6624fdd2753233a10984b62025076a7e9889"
Unverified Commit 17a63976 authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Enhancement] Add missing `fence_barrier_init` primitive after mbarrier init (#1121)

* [Enhancement] Add missing `fence_barrier_init` primitive after mbarrier init

* lint
parent 0dc50a54
...@@ -503,6 +503,7 @@ TVM_DLL const Op &initialize_descriptor(); ...@@ -503,6 +503,7 @@ TVM_DLL const Op &initialize_descriptor();
* This op is used to represent a descriptor start address setting operation in * This op is used to represent a descriptor start address setting operation in
* tilelang. * tilelang.
*/ */
TVM_DLL const Op &increase_descriptor_offset(); TVM_DLL const Op &increase_descriptor_offset();
/*! /*!
* \brief tilelang intrinsic for element-wise atomic addition. * \brief tilelang intrinsic for element-wise atomic addition.
......
...@@ -133,6 +133,10 @@ TL_DEVICE void fence_proxy_async() { ...@@ -133,6 +133,10 @@ TL_DEVICE void fence_proxy_async() {
asm volatile("fence.proxy.async.shared::cta;" : :); asm volatile("fence.proxy.async.shared::cta;" : :);
} }
// Emits the PTX instruction `fence.mbarrier_init.release.cluster`.
// Takes no arguments and returns nothing; the asm statement declares no
// input or output operands.
// NOTE(review): presumably meant to be issued once after a batch of mbarrier
// initializations and then followed by a synchronization of the appropriate
// scope so the initialized barriers become visible to other threads --
// confirm against the call sites.
TL_DEVICE void fence_barrier_init() {
asm volatile("fence.mbarrier_init.release.cluster;" : :);
}
// Indicate arrival of warp issuing TMA_STORE // Indicate arrival of warp issuing TMA_STORE
TL_DEVICE void tma_store_arrive() { TL_DEVICE void tma_store_arrive() {
asm volatile("cp.async.bulk.commit_group;"); asm volatile("cp.async.bulk.commit_group;");
......
...@@ -83,6 +83,16 @@ public: ...@@ -83,6 +83,16 @@ public:
stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]); stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
stmt_seq.push_back(stmt_); stmt_seq.push_back(stmt_);
if (!init_mbarrier_calls_.empty()) { if (!init_mbarrier_calls_.empty()) {
// Note from FlashAttention:
// Helps with visibility of barrier init operations across warps /
// CTA / cluster. Available as a separate function so as to batch
// inits across barriers and fence once. Note: it must be composed
// with an appropriate sync instruction with the right scope to
// ensure visibility, e.g. __syncthreads() or a cluster_arrive() +
// cluster_wait().
Stmt mem_fence = Evaluate(Call(
DataType::Handle(), tvm::tl::ptx_fence_barrier_init(), {}));
stmt_seq.push_back(mem_fence);
Stmt mem_sync = Stmt mem_sync =
Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(), Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
{StringImm("shared")})); {StringImm("shared")}));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment