Unverified commit ddfaac36 authored by Lei Wang, committed by GitHub

[Refactor] Refactor Pass `InjectFenceProxy` and expose some warp group primitives in frontend (#977)

* InjectFenceProxy docs and tests

  - annotate proxy fence injector with context comments for async/generic detection
  - add compiler internals doc covering the pass mechanics and link it in docs index
  - repair fence proxy test by fixing descriptor init usage and fence counter logic

* do not consider call_extern as async.

* doc update.

* reduce test size for sparse mla
parent 77e31e52
# InjectFenceProxy Pass
`tl.InjectFenceProxy` is a TIR-level transform that keeps the GPU proxy state consistent on NVIDIA Hopper (SM90+) by inserting `fence.proxy.async` instructions when control flow switches from generic memory operations to asynchronous proxy operations.
## Why Fences Are Needed
Hopper separates memory instructions into generic and asynchronous proxy paths. When an asynchronous instruction (for example, `cp.async` or `tma.load`) issues after generic traffic (like `ldmatrix` or plain buffer stores), the hardware requires a `fence.proxy.async` to guarantee ordering. Missing fences can lead to race conditions or undefined behaviour.
## What the Pass Does
- Walks every statement in the `PrimFunc`, tracking whether it behaves as a **generic**, **async**, or **neutral** proxy (neutral statements, such as an explicit fence, reset the tracked state).
- Automatically lowers `tma_store` intrinsics into the required `arrive`/`wait` handshake so that TMA stores participate correctly in synchronization.
- Injects an explicit `fence.proxy.async` whenever a generic statement is followed by an async statement without an intervening neutral barrier.
The pass is conservative: unknown extern calls are treated as async so that the fence is inserted rather than accidentally omitted.
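In pseudocode, the tracking boils down to a tiny state machine. The sketch below is an illustrative Python rendering of that rule rather than the pass itself (the real logic lives in C++ in `src/transform/inject_fence_proxy.cc`); the `kinds` list stands in for the per-statement classification the pass computes.

```python
from enum import Enum, auto


class ProxyKind(Enum):
    UNKNOWN = auto()
    GENERIC = auto()
    ASYNC = auto()
    NEUTRAL = auto()  # e.g. an explicit fence.proxy.async; resets the state


def needs_fence(prev: ProxyKind, curr: ProxyKind) -> bool:
    # A fence is only required on a generic -> async transition.
    return prev is ProxyKind.GENERIC and curr is ProxyKind.ASYNC


def fence_positions(kinds):
    """Return the indices of statements that must be preceded by a fence."""
    positions, prev = [], ProxyKind.UNKNOWN
    for i, curr in enumerate(kinds):
        if needs_fence(prev, curr):
            positions.append(i)
        prev = curr
    return positions


# generic descriptor init, generic store, async wgmma -> one fence before index 2
assert fence_positions(
    [ProxyKind.GENERIC, ProxyKind.GENERIC, ProxyKind.ASYNC]) == [2]
```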
### Timeline View
```
initialize_descriptor ──▶ shared-memory store ──▶ wgmma
     generic proxy            generic proxy        async proxy
                                         │             ▲
                                         └─ fence.proxy.async ─┘
                                            injected here
```
The proxy tracker scans the sequence from left to right. The moment it detects a transition from a generic statement to an async one (between the shared-memory store and the `wgmma` above), it synthesizes a `fence.proxy.async` to reset the hardware proxy state before the async path runs.
## Coverage of Intrinsics
The tracker understands the TileLang intrinsics for TMA load/store, shared-memory MMA (`wgmma`), and TVM/PTX async copy intrinsics (`cp.async` variants). Generic operations currently include `ldmatrix`, `stmatrix`, and descriptor initialization. Other IR nodes (loops, blocks, attributes) receive a proxy kind derived from their bodies so that the analysis survives structured control flow.
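As a rough lookup table, the classification behaves like the sketch below. It is illustrative only: the authoritative lists are `IsAsyncIntrinsic` and `IsKnownGeneric` in the C++ source, and the exact operator name strings are assumed here.

```python
# Illustrative classification; the authoritative lists live in
# IsAsyncIntrinsic / IsKnownGeneric in src/transform/inject_fence_proxy.cc.
ASYNC_OPS = {
    "tl.tma_load", "tl.tma_load_im2col", "tl.tma_store",
    "tl.tma_store_arrive", "tl.tma_store_wait",
    "tl.ptx_wgmma_ss", "tl.ptx_wgmma_rs",
    "tir.ptx_cp_async", "tir.ptx_cp_async_barrier", "tir.ptx_cp_async_bulk",
}
GENERIC_OPS = {"tl.ptx_ldmatrix", "tl.ptx_stmatrix", "tl.initialize_descriptor"}


def classify(op_name: str) -> str:
    """Map an operator name to the proxy kind the tracker assigns to it."""
    if op_name == "tl.fence_proxy_async":
        return "neutral"   # an explicit fence resets the state
    if op_name in ASYNC_OPS:
        return "async"
    if op_name in GENERIC_OPS:
        return "generic"
    return "async"         # conservative default for unknown calls


assert classify("tl.ptx_wgmma_ss") == "async"
assert classify("tl.ptx_ldmatrix") == "generic"
assert classify("some_extern_call") == "async"
```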
## Usage
The pass is part of the default TileLang lowering pipeline. To apply it manually:
```python
import tvm
from tvm import IRModule

import tilelang

# `prim_func` is the T.prim_func you want to transform.
mod = IRModule({"main": prim_func})
with tvm.transform.PassContext():
    mod = tilelang.transform.InjectFenceProxy()(mod)
```
## End-to-End Example
Before the pass:
```python
@T.prim_func
def kernel():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
        smem[0] = T.float16(0)
        T.ptx_wgmma_ss(
            "float16",
            "m64n64k16",
            T.bool(True),
            T.bool(True),
            "fp16",
            "fp16",
            "fp16",
            desc.data,
            T.int32(0),
            desc.data,
            T.int32(0),
            smem.data,
            T.int32(0),
            T.bool(True),
            1,
            1,
        )
```
After `tl.transform.InjectFenceProxy`:
```python
@T.prim_func
def kernel():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
        smem[0] = T.float16(0)
        T.fence_proxy_async()
        T.ptx_wgmma_ss(
            "float16",
            "m64n64k16",
            T.bool(True),
            T.bool(True),
            "fp16",
            "fp16",
            "fp16",
            desc.data,
            T.int32(0),
            desc.data,
            T.int32(0),
            smem.data,
            T.int32(0),
            T.bool(True),
            1,
            1,
        )
```
The only change is the `fence_proxy_async` between the generic descriptor setup / shared-memory write and the async `wgmma`. In larger kernels the pass performs the same operation across nested blocks, loops, and conditional branches.
## Extending the Pass
If you introduce a new intrinsic that behaves like an async proxy, add it to `IsAsyncIntrinsic` in `src/transform/inject_fence_proxy.cc`. Likewise, extend `IsKnownGeneric` for additional generic operations. When adding new neutral barriers, make sure they set the proxy kind to `kNeutral` so the state resets correctly.
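A quick way to validate such an extension is a unit test in the style of the ones added alongside this pass. The helper below is a hedged sketch: `op_name` would be something like a hypothetical `"tl.my_new_async_op"`, and the real tests additionally bind a CUDA target (`tir.transform.BindTarget`) before running the pass.

```python
import tvm
from tvm import tir
import tilelang


def check_fence_before(op_name, prim_func):
    """Assert that InjectFenceProxy emits a fence before the given op.

    `prim_func` is a T.prim_func whose body issues generic traffic followed
    by a call to `op_name` (a hypothetical newly added async intrinsic).
    The tests shipped with the pass also bind a CUDA target first.
    """
    mod = tvm.IRModule.from_expr(prim_func.with_attr("global_symbol", "main"))
    mod = tilelang.transform.InjectFenceProxy()(mod)

    order = []

    def visit(node):
        if isinstance(node, tir.Evaluate) and isinstance(node.value, tir.Call):
            order.append(getattr(node.value.op, "name", ""))

    tir.stmt_functor.post_order_visit(mod["main"].body, visit)
    assert "tl.fence_proxy_async" in order
    assert order.index("tl.fence_proxy_async") < order.index(op_name)
```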
@@ -40,6 +40,7 @@ deeplearning_operators/deepseek_mla
 :caption: COMPILER INTERNALS
 compiler_internals/letstmt_inline
+compiler_internals/inject_fence_proxy
 :::
 :::{toctree}
@@ -54,4 +55,4 @@ autoapi/tilelang/index
 :caption: Privacy
 privacy
-:::
\ No newline at end of file
+:::
@@ -21,7 +21,7 @@ def test_example_fp8_lighting_indexer():
 def test_example_sparse_mla_fwd():
     # small shapes for testing
     test_sparse_mla_fwd(
-        S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
 @tilelang.testing.requires_cuda
@@ -29,14 +29,14 @@ def test_example_sparse_mla_fwd():
 def test_example_sparse_mla_fwd_pipelined():
     # small shapes for testing
     test_sparse_mla_fwd_pipelined(
-        S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_sparse_mla_bwd():
     test_sparse_mla_bwd(
-        S=1024, SKV=2048, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
+        S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
 if __name__ == "__main__":
...
@@ -203,6 +203,21 @@ TIR_DEFINE_TL_BUILTIN(no_set_max_nreg)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warpgroup_arrive)
+    .set_num_inputs(0)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warpgroup_commit_batch)
+    .set_num_inputs(0)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(wait_wgmma)
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
...
@@ -334,6 +334,30 @@ TVM_DLL const Op &set_max_nreg();
  */
 TVM_DLL const Op &no_set_max_nreg();
+
+/*!
+ * \brief Arrive at a warpgroup fence for WGMMA sequences
+ *
+ * warpgroup_arrive()
+ *
+ */
+TVM_DLL const Op &warpgroup_arrive();
+
+/*!
+ * \brief Commit the current warpgroup batch for WGMMA sequences
+ *
+ * warpgroup_commit_batch()
+ *
+ */
+TVM_DLL const Op &warpgroup_commit_batch();
+
+/*!
+ * \brief Wait for the warpgroup batch identified by num_mma
+ *
+ * warpgroup_wait(num_mma)
+ *
+ */
+TVM_DLL const Op &warpgroup_wait();
+
 /*!
  * \brief Wait the previous wgmma to finish
  *
...
@@ -1374,6 +1374,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     print_extern_call_stmt("tl::tma_store_arrive");
   } else if (op->op.same_as(tl::tma_store_wait())) {
     print_extern_call_stmt("tl::tma_store_wait<0>");
+  } else if (op->op.same_as(tl::warpgroup_arrive())) {
+    print_extern_call_stmt("tl::warpgroup_arrive");
+  } else if (op->op.same_as(tl::warpgroup_commit_batch())) {
+    print_extern_call_stmt("tl::warpgroup_commit_batch");
+  } else if (op->op.same_as(tl::warpgroup_wait())) {
+    this->PrintIndent();
+    int num_mma = Downcast<IntImm>(op->args[0])->value;
+    this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma)
+                 << ">();\n";
   } else if (op->op.same_as(tl::set_max_nreg())) {
     this->PrintIndent();
     int nreg = Downcast<IntImm>(op->args[0])->value;
...
@@ -2,9 +2,18 @@
 #if __CUDA_ARCH_LIST__ >= 900
 #include "cute/arch/cluster_sm90.hpp"
+#include "cute/arch/mma_sm90_gmma.hpp"
 #include "cutlass/cutlass.h"
 namespace tl {
+
+TL_DEVICE void warpgroup_arrive() { cute::warpgroup_arrive(); }
+
+TL_DEVICE void warpgroup_commit_batch() { cute::warpgroup_commit_batch(); }
+
+template <int NumMma> TL_DEVICE void warpgroup_wait() {
+  cute::warpgroup_wait<NumMma>();
+}
+
 // Template parameter:
 //   thread_extent: the logical size (in number of threads) of each "group"
 //   within which we want to elect exactly ONE representative
@@ -53,4 +62,4 @@ template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
   asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
 }
 } // namespace tl
-#endif
\ No newline at end of file
+#endif
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
 /*!
  * \file inject_fence_proxy.cc
- * \brief Inject fence between generic and async proxies (sm90+)
+ * \brief Inject proxy fences between generic and async proxies (sm90+)
  */
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/ir/transform.h>
+#include <tvm/runtime/logging.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
+#include <unordered_map>
+#include <utility>
 #include "../op/builtin.h"
 namespace tvm {
 namespace tl {
 using namespace tir;
+using tvm::transform::PassContext;
-enum class Proxy : uint8_t { kGeneric, kAsync, kBoth };
+// Tracks what kind of proxy activity a statement performs so we can decide when
+// to inject fences while traversing the IR.
+enum class ProxyKind : uint8_t {
+  kUnknown,
+  kGeneric,
+  kAsync,
+  kMixed,
+  kNeutral, // Acts as a barrier and resets proxy state (e.g., fence
+            // instructions)
+};
-class ProxyMarker : public StmtVisitor {
-public:
-  ProxyMarker() = default;
-  Proxy GetProxy(const StmtNode *stmt) const {
-    auto it = map_.find(stmt);
-    // ICHECK(it != map_.end());
-    // TODO: This is a hack implementation to avoid the ICHECK failure.
-    if (it == map_.end()) {
-      return Proxy::kGeneric;
-    }
-    return it->second;
-  }
-  Proxy GetProxy(const Stmt &stmt) const { return GetProxy(stmt.get()); }
-  void VisitStmt_(const EvaluateNode *op) final {
-    Proxy proxy = Proxy::kAsync;
-    if (auto call = op->value.as<CallNode>()) {
-      if (call->op.same_as(ptx_ldmatrix()) ||
-          call->op.same_as(ptx_stmatrix())) {
-        proxy = Proxy::kGeneric;
-      }
-    }
-    SetProxy(op, proxy);
-  }
-  void VisitStmt_(const BufferStoreNode *op) final {
-    Proxy proxy = Proxy::kGeneric;
-    SetProxy(op, proxy);
-  }
-  void VisitStmt_(const SeqStmtNode *op) final {
-    StmtVisitor::VisitStmt_(op);
-    auto role = GetProxy(op->seq[0]);
-    for (auto stmt : op->seq) {
-      if (role != GetProxy(stmt)) {
-        role = Proxy::kBoth;
-        break;
-      }
-    }
-    SetProxy(op, role);
-  }
-  void VisitStmt_(const IfThenElseNode *op) final {
-    StmtVisitor::VisitStmt_(op);
-    auto role = GetProxy(op->then_case);
-    if (op->else_case.defined()) {
-      auto role_else = GetProxy(op->else_case.value());
-      if (role != role_else)
-        role = Proxy::kBoth;
-    }
-    SetProxy(op, role);
-  }
-  void VisitStmt_(const BlockRealizeNode *op) final {
-    StmtVisitor::VisitStmt_(op);
-    SetProxy(op, GetProxy(op->block));
-  }
-  template <class NodeType> void HandleBodyStmt(const NodeType *op) {
-    StmtVisitor::VisitStmt_(op);
-    SetProxy(op, GetProxy(op->body));
-  }
-  void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
-  void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
-  void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
-  void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
-  void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
-private:
-  void SetProxy(const StmtNode *stmt, Proxy proxy) { map_[stmt] = proxy; }
-  std::unordered_map<const StmtNode *, Proxy> map_;
-};
+namespace {
+inline bool IsAsync(ProxyKind kind) { return kind == ProxyKind::kAsync; }
+inline bool IsGeneric(ProxyKind kind) { return kind == ProxyKind::kGeneric; }
+// Merge two proxy kinds to represent the aggregate behaviour of a compound
+// node.
+inline ProxyKind CombineProxy(ProxyKind a, ProxyKind b) {
+  if (a == ProxyKind::kUnknown)
+    return b;
+  if (b == ProxyKind::kUnknown)
+    return a;
+  if (a == ProxyKind::kNeutral)
+    return b;
+  if (b == ProxyKind::kNeutral)
+    return a;
+  if (a == b)
+    return a;
+  return ProxyKind::kMixed;
+}
+// We only need a fence when transitioning from generic operations to async
+// ones.
+inline bool NeedsFence(ProxyKind prev, ProxyKind curr) {
+  if (prev == ProxyKind::kUnknown || curr == ProxyKind::kUnknown)
+    return false;
+  if (prev == ProxyKind::kNeutral || curr == ProxyKind::kNeutral)
+    return false;
+  if (prev == ProxyKind::kMixed || curr == ProxyKind::kMixed)
+    return false;
+  return IsGeneric(prev) && IsAsync(curr);
+}
+inline bool IsFenceCall(const CallNode *call) {
+  return call && call->op.same_as(fence_proxy_async());
+}
+// Identify async intrinsics emitted by TileLang or TVM that require a fence
+// when they follow generic proxies.
+bool IsAsyncIntrinsic(const CallNode *call) {
+  if (call == nullptr) {
+    return false;
+  }
+  // TileLang async intrinsics
+  if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) ||
+      call->op.same_as(tma_store()) || call->op.same_as(tma_store_arrive()) ||
+      call->op.same_as(tma_store_wait()) ||
+      call->op.same_as(ptx_cp_async_barrier_noinc()) ||
+      call->op.same_as(ptx_wgmma_ss()) || call->op.same_as(ptx_wgmma_rs())) {
+    return true;
+  }
+  // PTX async copy intrinsics
+  if (call->op.same_as(builtin::ptx_cp_async()) ||
+      call->op.same_as(builtin::ptx_cp_async_barrier()) ||
+      call->op.same_as(builtin::ptx_cp_async_bulk())) {
+    return true;
+  }
+  return false;
+}
+// Known ops that must be treated as generic proxies (e.g. ldmatrix/stmatrix).
+bool IsKnownGeneric(const CallNode *call) {
+  if (call == nullptr) {
+    return false;
+  }
+  return call->op.same_as(ptx_ldmatrix()) || call->op.same_as(ptx_stmatrix()) ||
+         call->op.same_as(initialize_descriptor());
+}
+ProxyKind ProxyFromAttrValue(const ObjectRef &value) {
+  if (const auto *str = value.as<StringImmNode>()) {
+    if (str->value == "async") {
+      return ProxyKind::kAsync;
+    }
+    if (str->value == "generic") {
+      return ProxyKind::kGeneric;
+    }
+    if (str->value == "neutral") {
+      return ProxyKind::kNeutral;
+    }
+  }
+  return ProxyKind::kUnknown;
+}
+// TMA stores must be followed by the arrive/wait pair. We rewrite them as part
+// of the pass to guarantee the proper synchronization semantics.
 class TMAStoreSyncInjector : public StmtExprMutator {
 public:
-  static PrimFunc Substitute(PrimFunc f) {
-    auto T = TMAStoreSyncInjector();
-    f.CopyOnWrite()->body = T(f->body);
+  static PrimFunc Apply(PrimFunc f) {
+    if (!f->body.defined()) {
+      return f;
+    }
+    auto injector = TMAStoreSyncInjector();
+    f.CopyOnWrite()->body = injector(f->body);
     return f;
   }
 private:
+  Stmt operator()(const Stmt &stmt) { return StmtExprMutator::VisitStmt(stmt); }
   Stmt VisitStmt_(const EvaluateNode *op) final {
-    if (auto call = op->value.as<CallNode>()) {
+    Stmt mutated = StmtExprMutator::VisitStmt_(op);
+    const auto *node = mutated.as<EvaluateNode>();
+    if (const auto *call = node->value.as<CallNode>()) {
       if (call->op.same_as(tma_store())) {
-        Array<Stmt> new_body;
-        new_body.push_back(GetRef<Evaluate>(op));
-        new_body.push_back(
+        Array<Stmt> seq;
+        seq.push_back(mutated);
+        seq.push_back(
             Evaluate(Call(DataType::Handle(), tma_store_arrive(), {})));
-        new_body.push_back(
-            Evaluate(Call(DataType::Handle(), tma_store_wait(), {})));
-        return SeqStmt(std::move(new_body));
+        seq.push_back(Evaluate(Call(DataType::Handle(), tma_store_wait(), {})));
+        return SeqStmt(std::move(seq));
       }
     }
-    return StmtExprMutator::VisitStmt_(op);
+    return mutated;
   }
 };
-class InjectFenceProxy : public StmtExprMutator {
+// Main pass: track the proxy state while walking the IR and inject fences when
+// switching from generic to async proxies.
+class ProxyFenceInjector : public StmtMutator {
 public:
-  static PrimFunc Substitute(PrimFunc f) {
-    auto T = InjectFenceProxy();
-    f.CopyOnWrite()->body = T(f->body);
+  static PrimFunc Apply(PrimFunc f) {
+    if (!f->body.defined()) {
+      return f;
+    }
+    ProxyFenceInjector injector;
+    f.CopyOnWrite()->body = injector.VisitStmt(f->body);
     return f;
   }
 private:
-  Proxy get_generic_proxy(const Stmt &stmt) {
-    auto marker = ProxyMarker();
-    marker(stmt);
-    return marker.GetProxy(stmt);
-  }
-  Stmt VisitStmt_(const SeqStmtNode *op) final {
-    ICHECK(!op->seq.empty());
-    Array<Stmt> new_body;
-    Proxy cur_proxy, prev_proxy;
-    auto fence_stmt =
-        Evaluate(Call(DataType::Handle(), fence_proxy_async(), {}));
-    prev_proxy = get_generic_proxy(op->seq[0]);
-    new_body.push_back(VisitStmt(op->seq[0]));
-    if (op->seq.size() > 1) {
-      for (int i = 1; i < static_cast<int>(op->seq.size()); i++) {
-        cur_proxy = get_generic_proxy(op->seq[i]);
-        if (cur_proxy == Proxy::kAsync && prev_proxy == Proxy::kGeneric) {
-          new_body.push_back(fence_stmt);
-        }
-        new_body.push_back(VisitStmt(op->seq[i]));
-        prev_proxy = cur_proxy;
-      }
-    }
-    ICHECK(!new_body.empty());
-    return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
-  }
-  // Stmt VisitStmt_(const ForNode* op) final {
-  //   std::cout << "ForNode:" << op->body->GetTypeKey() << std::endl;
-  //   return StmtExprMutator::VisitStmt_(op);
-  // }
-  InjectFenceProxy() = default;
+  Stmt VisitStmt_(const SeqStmtNode *op) final {
+    Array<Stmt> seq;
+    seq.reserve(op->seq.size());
+    ProxyKind sequence_kind = ProxyKind::kUnknown;
+    ProxyKind prev_kind = ProxyKind::kUnknown;
+    for (const Stmt &stmt : op->seq) {
+      Stmt new_stmt = VisitStmt(stmt);
+      ProxyKind current_kind = GetProxyKind(new_stmt);
+      if (!seq.empty() && NeedsFence(prev_kind, current_kind)) {
+        Stmt fence = MakeFenceStmt();
+        seq.push_back(fence);
+        prev_kind = GetProxyKind(fence);
+      }
+      seq.push_back(new_stmt);
+      sequence_kind = CombineProxy(sequence_kind, current_kind);
+      prev_kind = current_kind;
+    }
+    Stmt result = seq.size() == 1 ? seq[0] : SeqStmt(std::move(seq));
+    SetProxyKind(result, sequence_kind);
+    return result;
+  }
+  Stmt VisitStmt_(const EvaluateNode *op) final {
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    const auto *evaluate = stmt.as<EvaluateNode>();
+    ProxyKind kind = ProxyKind::kGeneric;
+    if (const auto *call = evaluate->value.as<CallNode>()) {
+      if (IsFenceCall(call)) {
+        kind = ProxyKind::kNeutral;
+      } else if (IsAsyncIntrinsic(call)) {
+        kind = ProxyKind::kAsync;
+      } else if (IsKnownGeneric(call)) {
+        kind = ProxyKind::kGeneric;
+      } else {
+        // Treat unknown externs as async to avoid missing required fences.
+        kind = ProxyKind::kAsync;
+      }
+    }
+    SetProxyKind(stmt, kind);
+    return stmt;
+  }
+  Stmt VisitStmt_(const BufferStoreNode *op) final {
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    SetProxyKind(stmt, ProxyKind::kGeneric);
+    return stmt;
+  }
+  Stmt VisitStmt_(const IfThenElseNode *op) final {
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    const auto *node = stmt.as<IfThenElseNode>();
+    ProxyKind kind = GetProxyKind(node->then_case);
+    if (node->else_case.defined()) {
+      kind = CombineProxy(kind, GetProxyKind(node->else_case.value()));
+    }
+    SetProxyKind(stmt, kind);
+    return stmt;
+  }
+  Stmt VisitStmt_(const AttrStmtNode *op) final {
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    const auto *node = stmt.as<AttrStmtNode>();
+    ProxyKind body_kind = GetProxyKind(node->body);
+    SetProxyKind(stmt, body_kind);
+    return stmt;
+  }
+  Stmt VisitStmt_(const BlockRealizeNode *op) final {
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    const auto *node = stmt.as<BlockRealizeNode>();
+    SetProxyKind(stmt, GetProxyKind(node->block));
+    return stmt;
+  }
+  Stmt VisitStmt_(const BlockNode *op) final {
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    const auto *node = stmt.as<BlockNode>();
+    ProxyKind kind = ProxyKind::kUnknown;
+    if (node->init.defined()) {
+      kind = CombineProxy(kind, GetProxyKind(node->init.value()));
+    }
+    kind = CombineProxy(kind, GetProxyKind(node->body));
+    SetProxyKind(stmt, kind);
+    return stmt;
+  }
+  Stmt VisitStmt_(const ForNode *op) final { return VisitSingleBody(op); }
+  Stmt VisitStmt_(const LetStmtNode *op) final { return VisitSingleBody(op); }
+  Stmt VisitStmt_(const AssertStmtNode *op) final {
+    return VisitSingleBody(op);
+  }
+  Stmt VisitStmt_(const WhileNode *op) final { return VisitSingleBody(op); }
+  template <typename NodeType> Stmt VisitSingleBody(const NodeType *op) {
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    const auto *node = stmt.as<NodeType>();
+    ProxyKind body_kind = GetProxyKind(node->body);
+    SetProxyKind(stmt, body_kind);
+    return stmt;
+  }
+  void SetProxyKind(const Stmt &stmt, ProxyKind kind) {
+    proxy_map_[stmt.get()] = kind;
+  }
+  ProxyKind GetProxyKind(const Stmt &stmt) const {
+    if (!stmt.defined()) {
+      return ProxyKind::kUnknown;
+    }
+    auto it = proxy_map_.find(stmt.get());
+    if (it == proxy_map_.end()) {
+      return ProxyKind::kUnknown;
+    }
+    return it->second;
+  }
+  Stmt MakeFenceStmt() {
+    Stmt fence = Evaluate(Call(DataType::Handle(), fence_proxy_async(), {}));
+    SetProxyKind(fence, ProxyKind::kNeutral);
+    return fence;
+  }
+  std::unordered_map<const StmtNode *, ProxyKind> proxy_map_;
 };
-using namespace tir::transform;
+} // namespace
 tvm::transform::Pass InjectFenceProxy() {
-  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
-    f = TMAStoreSyncInjector::Substitute(f);
-    return InjectFenceProxy::Substitute(f);
+  auto pass_func = [](PrimFunc f, const IRModule &, const PassContext &) {
+    f = TMAStoreSyncInjector::Apply(f);
+    f = ProxyFenceInjector::Apply(f);
+    return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", {});
+  return tir::transform::CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy",
+                                            {});
 }
 TVM_FFI_STATIC_INIT_BLOCK({
...
@@ -927,8 +927,8 @@ private:
       original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
     };
     for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
-      const auto *nested_block_realize =
-          pipeline_body_seq->seq[i].as<BlockRealizeNode>();
+      const Stmt &child = pipeline_body_seq->seq[i];
+      const auto *nested_block_realize = child.as<BlockRealizeNode>();
       if (nested_block_realize && is_one(nested_block_realize->predicate) &&
           nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
         const Block &nested_pipeline_block = nested_block_realize->block;
@@ -938,13 +938,8 @@ private:
           pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
         }
-        const auto *nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
-        for (size_t j = 0; j < nested_seq->seq.size(); j++) {
-          f_add_child(nested_seq->seq[j]);
-        }
-      } else {
-        f_add_child(pipeline_body_seq->seq[i]);
       }
+      f_add_child(child);
     }
     auto pipeline_stages = Downcast<Array<Integer>>(
...
@@ -53,5 +53,175 @@ def test_lower_fence_proxy():
     _check(before, after)
+
+
+def test_async_to_generic_no_double_fence():
+
+    @T.prim_func
+    def before():
+        with T.Kernel(8):
+            A_shared = T.decl_buffer((1024,), "uint8", scope="shared.dyn")
+            B_shared = T.decl_buffer((1024,), "uint8", scope="shared.dyn")
+            T.ptx_cp_async("uint8", A_shared.data, 0, B_shared.data, 0, 16)
+            T.fence_proxy_async()
+            T.call_extern("handle", "generic_op")
+
+    mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
+    mod = tvm.tir.transform.BindTarget(auto_target)(mod)
+    mod = tl.transform.InjectFenceProxy()(mod)
+
+    def _count_fences(stmt):
+        count = 0
+
+        def visit(node):
+            nonlocal count
+            if isinstance(node, tir.Evaluate):
+                call = node.value
+                if isinstance(call, tir.Call):
+                    op = call.op
+                    name = getattr(op, "name", None)
+                    if name == "tl.fence_proxy_async":
+                        count += 1
+
+        tir.stmt_functor.post_order_visit(stmt, visit)
+        return count
+
+    assert _count_fences(mod["main"].body) == 1
+
+
+def test_proxy_hint_override():
+
+    @T.prim_func
+    def before():
+        with T.Kernel(8):
+            T.evaluate(T.call_extern("handle", "custom_async"))
+            with T.attr("proxy_scope", "tl.proxy_hint", "neutral"):
+                T.evaluate(T.call_extern("handle", "custom_generic"))
+            T.evaluate(T.call_extern("handle", "custom_async_tail"))
+
+    mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
+    mod = tvm.tir.transform.BindTarget(auto_target)(mod)
+    mod = tl.transform.InjectFenceProxy()(mod)
+
+    def _has_fence(stmt):
+        result = False
+
+        def visit(node):
+            nonlocal result
+            if isinstance(node, tir.Evaluate):
+                call = node.value
+                if isinstance(call, tir.Call):
+                    op = call.op
+                    name = getattr(op, "name", None)
+                    if name == "tl.fence_proxy_async":
+                        result = True
+
+        tir.stmt_functor.post_order_visit(stmt, visit)
+        return result
+
+    assert not _has_fence(mod["main"].body)
+
+
+def test_tma_store_sync_injection():
+
+    @T.prim_func
+    def before():
+        with T.Kernel(8):
+            A_global = T.decl_buffer((128,), "float16", scope="global")
+            T.evaluate(T.call_intrin("handle", tir.op.Op.get("tl.tma_store"), A_global.data))
+
+    mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
+    mod = tvm.tir.transform.BindTarget(auto_target)(mod)
+    mod = tl.transform.InjectFenceProxy()(mod)
+
+    arrives = 0
+    waits = 0
+
+    def visit(node):
+        nonlocal arrives, waits
+        if isinstance(node, tir.Evaluate):
+            call = node.value
+            if isinstance(call, tir.Call):
+                name = getattr(call.op, "name", None)
+                if name == "tl.tma_store_arrive":
+                    arrives += 1
+                elif name in ("tl.tma_store_wait", "tl.tma_store_wait<0>"):
+                    waits += 1
+
+    tir.stmt_functor.post_order_visit(mod["main"].body, visit)
+    assert arrives == 1
+    assert waits == 1
+
+
+def test_wgmma_marked_async():
+
+    @T.prim_func
+    def before():
+        with T.Kernel(1):
+            A_shared = T.decl_buffer((1,), "float16", scope="shared")
+            desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor")
+            desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor")
+            C_local = T.decl_buffer((32,), "float16", scope="local")
+            A_shared[0] = T.float16(0)
+            T.warpgroup_arrive()
+            T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
+                           "fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
+                           T.int32(0), T.bool(True), 1, 1)
+
+    mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
+    mod = tvm.tir.transform.BindTarget(auto_target)(mod)
+    mod = tl.transform.InjectFenceProxy()(mod)
+
+    order = []
+
+    def visit(node):
+        if isinstance(node, tir.Evaluate):
+            call = node.value
+            if isinstance(call, tir.Call):
+                order.append(getattr(call.op, "name", ""))
+
+    tir.stmt_functor.post_order_visit(mod["main"].body, visit)
+    assert "tl.ptx_wgmma_ss" in order
+    assert "tl.fence_proxy_async" in order
+    assert order.index("tl.fence_proxy_async") < order.index("tl.ptx_wgmma_ss")
+
+
+def test_wgmma_after_descriptor():
+
+    @T.prim_func
+    def before():
+        with T.Kernel(1):
+            desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor")
+            desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor")
+            C_local = T.decl_buffer((32,), "float16", scope="local")
+            T.initialize_descriptor(desc_a, T.uint64(0), 2, 1, 32)
+            T.initialize_descriptor(desc_b, T.uint64(0), 2, 1, 32)
+            T.warpgroup_arrive()
+            T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
+                           "fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
+                           T.int32(0), T.bool(True), 1, 1)
+
+    mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
+    mod = tvm.tir.transform.BindTarget(auto_target)(mod)
+    mod = tl.transform.InjectFenceProxy()(mod)
+
+    fence_count = 0
+    order = []
+
+    def visit(node):
+        nonlocal fence_count
+        if isinstance(node, tir.Evaluate):
+            call = node.value
+            if isinstance(call, tir.Call):
+                name = getattr(call.op, "name", "")
+                order.append(name)
+                if name == "tl.fence_proxy_async":
+                    fence_count += 1
+
+    tir.stmt_functor.post_order_visit(mod["main"].body, visit)
+    assert fence_count >= 1
+    assert "tl.warpgroup_arrive" in order
+    assert order.index("tl.fence_proxy_async") < order.index("tl.warpgroup_arrive")
+
+
 if __name__ == "__main__":
-    test_lower_fence_proxy()
+    tilelang.testing.main()
@@ -156,7 +156,12 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     if allow_fence_proxy(target=target):
         # in hopper device, wgmma is an async proxy
         # so we need to inject a fence proxy before it
+        print("Before injectFenceProxy")
+        print(mod)
         mod = tilelang.transform.InjectFenceProxy()(mod)
+        print("After InjectFenceProxy")
+        print(mod)
     mod = tilelang.transform.LowerOpaqueBlock()(mod)
     mod = tir.transform.NarrowDataType(32)(mod)
     mod = tilelang.transform.FlattenBuffer()(mod)
...
@@ -242,12 +242,14 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
         @T.macro
         def _warp_mma(A_buf, B_buf, C_local_buf):
+            # TODO(lei): inject warpgroup_fence_operand for C_local_buf
             desc_a = T.alloc_descriptor()
             desc_b = T.alloc_descriptor()
             T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode,
                                     int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
             T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
                                     int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
+            T.warpgroup_arrive()
             for ki in T.serial(0, (k_dim // micro_size_k)):
                 for i in T.serial(m_dim // 64):
                     A_offset = (ki % ak_atom_size) * micro_size_k + i * 64 * a_swizzle_atom_elems + (
@@ -262,6 +264,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                         (A_offset * elems_in_bytes) >> 4, desc_b.data,
                         (B_offset * elems_in_bytes) >> 4, C_local_buf.data, C_offset,
                         scale_out, scale_in_a, scale_in_b)
+            T.warpgroup_commit_batch()
+            T.warpgroup_wait(0)
         return _warp_mma(A_buf, B_buf, C_local_buf)
...
@@ -249,6 +249,37 @@ def mbarrier_expect_tx(*args):
     return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_expect_tx"), *args)
+
+
+def warpgroup_arrive():
+    """Signal warpgroup readiness for subsequent WGMMA operations.
+
+    Returns:
+        tir.Call: A handle to the warpgroup arrive operation.
+    """
+    return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_arrive"))
+
+
+def warpgroup_commit_batch():
+    """Commit the current warpgroup batch for WGMMA operations.
+
+    Returns:
+        tir.Call: A handle to the warpgroup commit batch operation.
+    """
+    return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_commit_batch"))
+
+
+def warpgroup_wait(num_mma: int):
+    """Wait for completion of the specified warpgroup batch.
+
+    Args:
+        num_mma: int
+            Identifier of the warpgroup MMA batch to wait on.
+
+    Returns:
+        tir.Call: A handle to the warpgroup wait operation.
+    """
+    return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma)
+
+
 def wait_wgmma(id: int):
     """Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
...