Unverified Commit ddfaac36 authored by Lei Wang's avatar Lei Wang Committed by GitHub

[Refactor] Refactor Pass `InjectFenceProxy` and expose some warp group...

[Refactor] Refactor Pass `InjectFenceProxy` and expose some warp group primitives in frontend (#977)

* • InjectFenceProxy docs and tests

  - annotate proxy fence injector with context comments for async/generic detection
  - add compiler internals doc covering the pass mechanics and link it in docs index
  - repair fence proxy test by fixing descriptor init usage and fence counter logic

* do not consider call_extern as async.

* doc update.

* reduce test size for sparse mla
parent 77e31e52
# InjectFenceProxy Pass
`tl.InjectFenceProxy` is a TIR-level transform that keeps the GPU proxy state consistent on NVIDIA Hopper (SM90+) by inserting `fence.proxy.async` instructions when control flow switches from generic memory operations to asynchronous proxy operations.
## Why Fences Are Needed
Hopper separates memory instructions into generic and asynchronous proxy paths. When an asynchronous instruction (for example, `cp.async` or `tma.load`) issues after generic traffic (like `ldmatrix` or plain buffer stores), the hardware requires a `fence.proxy.async` to guarantee ordering. Missing fences can lead to race conditions or undefined behaviour.
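In frontend terms, this is the fence a kernel would otherwise have to place by hand. A minimal sketch, modelled on the tests added in this commit (buffer shapes and the `ptx_cp_async` arguments are illustrative only):
```python
@T.prim_func
def manual_fence():
    with T.Kernel(8):
        A_global = T.decl_buffer((1024,), "uint8", scope="global")
        A_shared = T.decl_buffer((1024,), "uint8", scope="shared.dyn")
        A_shared[0] = T.uint8(0)  # generic proxy: plain shared-memory store
        T.fence_proxy_async()     # required before switching to the async proxy
        T.ptx_cp_async("uint8", A_shared.data, 0, A_global.data, 0, 16)
```
`InjectFenceProxy` exists so kernels do not have to spell this fence out explicitly.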
## What the Pass Does
- Walks every statement in the `PrimFunc`, tracking whether it behaves as a **generic**, **async**, or **neutral** proxy (neutral statements, such as an explicit fence, reset the tracked state).
- Automatically follows each `tma_store` intrinsic with the required `tma_store_arrive`/`tma_store_wait` handshake so that TMA stores participate correctly in synchronization.
- Injects an explicit `fence.proxy.async` whenever a generic statement is followed by an async statement without an intervening neutral barrier.
The pass is conservative: unknown extern calls are treated as async so that the fence is inserted rather than accidentally omitted.
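The generic-to-async rule itself is tiny. The sketch below is a simplified Python mirror of the decision, not the pass implementation (which walks TIR statements rather than a flat list of kinds):
```python
from enum import Enum

class Proxy(Enum):
    GENERIC = "generic"
    ASYNC = "async"
    NEUTRAL = "neutral"   # e.g. an explicit fence: resets the tracked state

def inject_fences(kinds):
    """Insert a fence at every generic -> async transition."""
    out, prev = [], None
    for kind in kinds:
        if prev == Proxy.GENERIC and kind == Proxy.ASYNC:
            out.append("fence.proxy.async")
        out.append(kind)
        prev = kind
    return out

# a generic store followed by an async wgmma needs exactly one fence
assert inject_fences([Proxy.GENERIC, Proxy.ASYNC]) == [
    Proxy.GENERIC, "fence.proxy.async", Proxy.ASYNC
]
```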
### Timeline View
```
initialize_descriptor ──► shared-memory store ──► wgmma
      (generic)                 (generic)         (async)
                                    │                ▲
                                    └─ fence.proxy.async inserted here
```
The proxy tracker scans the sequence from left to right. The moment it detects a transition from generic to async (between the shared store and the `wgmma` above), it synthesizes a `fence.proxy.async` to reset the hardware proxy state before the async path runs.
## Coverage of Intrinsics
The tracker understands the TileLang intrinsics for TMA load/store, shared-memory MMA (`wgmma`), and TVM/PTX async copy intrinsics (`cp.async` variants). Generic operations currently include `ldmatrix`, `stmatrix`, and descriptor initialization. Other IR nodes (loops, blocks, attributes) receive a proxy kind derived from their bodies so that the analysis survives structured control flow.
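Structured nodes aggregate the kinds of their children; the helper below mirrors `CombineProxy` from the C++ source (the `Kind` enum here is just a stand-in for the pass's internal `ProxyKind`):
```python
from enum import Enum, auto

class Kind(Enum):
    UNKNOWN = auto()
    GENERIC = auto()
    ASYNC = auto()
    MIXED = auto()
    NEUTRAL = auto()

def combine(a: Kind, b: Kind) -> Kind:
    """Aggregate two child kinds into the kind of the enclosing node."""
    if a == Kind.UNKNOWN:
        return b
    if b == Kind.UNKNOWN:
        return a
    if a == Kind.NEUTRAL:
        return b
    if b == Kind.NEUTRAL:
        return a
    return a if a == b else Kind.MIXED

# a loop of shared stores stays generic; a branch mixing generic and async is mixed
assert combine(Kind.GENERIC, Kind.GENERIC) == Kind.GENERIC
assert combine(Kind.GENERIC, Kind.ASYNC) == Kind.MIXED
```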
## Usage
The pass is part of the default TileLang lowering pipeline. To apply it manually:
```python
import tvm
from tvm import IRModule

from tilelang import tl

mod = IRModule({"main": prim_func})
with tvm.transform.PassContext():
    mod = tl.transform.InjectFenceProxy()(mod)
```
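To check that the pass actually fired, you can count the injected fences with the same kind of visitor the commit's tests use:
```python
from tvm import tir

def count_fences(stmt):
    """Count tl.fence_proxy_async evaluates inside a TIR statement."""
    count = 0

    def visit(node):
        nonlocal count
        if isinstance(node, tir.Evaluate) and isinstance(node.value, tir.Call):
            if getattr(node.value.op, "name", None) == "tl.fence_proxy_async":
                count += 1

    tir.stmt_functor.post_order_visit(stmt, visit)
    return count

print(count_fences(mod["main"].body))
```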
## End-to-End Example
Before the pass:
```python
@T.prim_func
def kernel():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
        smem[0] = T.float16(0)
        T.ptx_wgmma_ss(
            "float16",
            "m64n64k16",
            T.bool(True),
            T.bool(True),
            "fp16",
            "fp16",
            "fp16",
            desc.data,
            T.int32(0),
            desc.data,
            T.int32(0),
            smem.data,
            T.int32(0),
            T.bool(True),
            1,
            1,
        )
```
After `tl.transform.InjectFenceProxy`:
```python
@T.prim_func
def kernel():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
        smem[0] = T.float16(0)
        T.fence_proxy_async()
        T.ptx_wgmma_ss(
            "float16",
            "m64n64k16",
            T.bool(True),
            T.bool(True),
            "fp16",
            "fp16",
            "fp16",
            desc.data,
            T.int32(0),
            desc.data,
            T.int32(0),
            smem.data,
            T.int32(0),
            T.bool(True),
            1,
            1,
        )
```
The only change is the `fence_proxy_async` between the generic descriptor setup / shared-memory write and the async `wgmma`. In larger kernels the pass performs the same operation across nested blocks, loops, and conditional branches.
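For instance, when the generic work lives inside a loop, the loop as a whole inherits the generic kind from its body, so the fence still lands immediately before the async statement. A sketch (reusing the buffers from the example above; the expected fence location is marked with a comment rather than shown as pass output):
```python
@T.prim_func
def kernel_with_loop():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
        for i in T.serial(128):
            smem[i] = T.float16(0)  # generic proxy inside the loop body
        # after InjectFenceProxy, a T.fence_proxy_async() is expected here,
        # between the generic loop and the async wgmma below
        T.ptx_wgmma_ss(
            "float16", "m64n64k16", T.bool(True), T.bool(True),
            "fp16", "fp16", "fp16",
            desc.data, T.int32(0), desc.data, T.int32(0),
            smem.data, T.int32(0), T.bool(True), 1, 1)
```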
## Extending the Pass
If you introduce a new intrinsic that behaves like an async proxy, add it to `IsAsyncIntrinsic` in `src/transform/inject_fence_proxy.cc`. Likewise, extend `IsKnownGeneric` for additional generic operations. When adding new neutral barriers, make sure they set the proxy kind to `kNeutral` so the state resets correctly.
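When you add a new async intrinsic, a regression test in the style of this commit's `test_wgmma_marked_async` keeps the classification honest. A hedged helper sketch (`op_name` is whatever registered name your new op uses):
```python
from tvm import tir

def assert_fence_before(mod, op_name):
    """Assert a tl.fence_proxy_async appears before the first `op_name` call."""
    order = []

    def visit(node):
        if isinstance(node, tir.Evaluate) and isinstance(node.value, tir.Call):
            order.append(getattr(node.value.op, "name", ""))

    tir.stmt_functor.post_order_visit(mod["main"].body, visit)
    assert "tl.fence_proxy_async" in order and op_name in order
    assert order.index("tl.fence_proxy_async") < order.index(op_name)

# e.g. assert_fence_before(mod, "tl.my_new_async_op")  # hypothetical op name
```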
......@@ -40,6 +40,7 @@ deeplearning_operators/deepseek_mla
:caption: COMPILER INTERNALS
compiler_internals/letstmt_inline
compiler_internals/inject_fence_proxy
:::
:::{toctree}
......
......@@ -21,7 +21,7 @@ def test_example_fp8_lighting_indexer():
def test_example_sparse_mla_fwd():
# small shapes for testing
test_sparse_mla_fwd(
S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@tilelang.testing.requires_cuda
......@@ -29,14 +29,14 @@ def test_example_sparse_mla_fwd():
def test_example_sparse_mla_fwd_pipelined():
# small shapes for testing
test_sparse_mla_fwd_pipelined(
S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_bwd():
test_sparse_mla_bwd(
S=1024, SKV=2048, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
if __name__ == "__main__":
......
......@@ -203,6 +203,21 @@ TIR_DEFINE_TL_BUILTIN(no_set_max_nreg)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warpgroup_arrive)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warpgroup_commit_batch)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(wait_wgmma)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
......
......@@ -334,6 +334,30 @@ TVM_DLL const Op &set_max_nreg();
*/
TVM_DLL const Op &no_set_max_nreg();
/*!
* \brief Arrive at a warpgroup fence for WGMMA sequences
*
* warpgroup_arrive()
*
*/
TVM_DLL const Op &warpgroup_arrive();
/*!
* \brief Commit the current warpgroup batch for WGMMA sequences
*
* warpgroup_commit_batch()
*
*/
TVM_DLL const Op &warpgroup_commit_batch();
/*!
* \brief Wait for the warpgroup batch identified by num_mma
*
* warpgroup_wait(num_mma)
*
*/
TVM_DLL const Op &warpgroup_wait();
/*!
* \brief Wait the previous wgmma to finish
*
......
......@@ -1374,6 +1374,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
print_extern_call_stmt("tl::tma_store_arrive");
} else if (op->op.same_as(tl::tma_store_wait())) {
print_extern_call_stmt("tl::tma_store_wait<0>");
} else if (op->op.same_as(tl::warpgroup_arrive())) {
print_extern_call_stmt("tl::warpgroup_arrive");
} else if (op->op.same_as(tl::warpgroup_commit_batch())) {
print_extern_call_stmt("tl::warpgroup_commit_batch");
} else if (op->op.same_as(tl::warpgroup_wait())) {
this->PrintIndent();
int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma)
<< ">();\n";
} else if (op->op.same_as(tl::set_max_nreg())) {
this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
......
......@@ -2,9 +2,18 @@
#if __CUDA_ARCH_LIST__ >= 900
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/mma_sm90_gmma.hpp"
#include "cutlass/cutlass.h"
namespace tl {
TL_DEVICE void warpgroup_arrive() { cute::warpgroup_arrive(); }
TL_DEVICE void warpgroup_commit_batch() { cute::warpgroup_commit_batch(); }
template <int NumMma> TL_DEVICE void warpgroup_wait() {
cute::warpgroup_wait<NumMma>();
}
// Template parameter:
// thread_extent: the logical size (in number of threads) of each "group"
// within which we want to elect exactly ONE representative
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_fence_proxy.cc
* \brief Inject fence between generic and async proxies (sm90+)
* \brief Inject proxy fences between generic and async proxies (sm90+)
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <utility>
#include "../op/builtin.h"
namespace tvm {
namespace tl {
using namespace tir;
using tvm::transform::PassContext;
enum class Proxy : uint8_t { kGeneric, kAsync, kBoth };
// Tracks what kind of proxy activity a statement performs so we can decide when
// to inject fences while traversing the IR.
enum class ProxyKind : uint8_t {
kUnknown,
kGeneric,
kAsync,
kMixed,
kNeutral, // Acts as a barrier and resets proxy state (e.g., fence
// instructions)
};
class ProxyMarker : public StmtVisitor {
public:
ProxyMarker() = default;
namespace {
Proxy GetProxy(const StmtNode *stmt) const {
auto it = map_.find(stmt);
// ICHECK(it != map_.end());
// TODO: This is a hack implementation to avoid the ICHECK failure.
if (it == map_.end()) {
return Proxy::kGeneric;
}
return it->second;
}
inline bool IsAsync(ProxyKind kind) { return kind == ProxyKind::kAsync; }
inline bool IsGeneric(ProxyKind kind) { return kind == ProxyKind::kGeneric; }
Proxy GetProxy(const Stmt &stmt) const { return GetProxy(stmt.get()); }
// Merge two proxy kinds to represent the aggregate behaviour of a compound
// node.
inline ProxyKind CombineProxy(ProxyKind a, ProxyKind b) {
if (a == ProxyKind::kUnknown)
return b;
if (b == ProxyKind::kUnknown)
return a;
if (a == ProxyKind::kNeutral)
return b;
if (b == ProxyKind::kNeutral)
return a;
if (a == b)
return a;
return ProxyKind::kMixed;
}
void VisitStmt_(const EvaluateNode *op) final {
Proxy proxy = Proxy::kAsync;
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(ptx_ldmatrix()) ||
call->op.same_as(ptx_stmatrix())) {
proxy = Proxy::kGeneric;
}
}
SetProxy(op, proxy);
}
// We only need a fence when transitioning from generic operations to async
// ones.
inline bool NeedsFence(ProxyKind prev, ProxyKind curr) {
if (prev == ProxyKind::kUnknown || curr == ProxyKind::kUnknown)
return false;
if (prev == ProxyKind::kNeutral || curr == ProxyKind::kNeutral)
return false;
if (prev == ProxyKind::kMixed || curr == ProxyKind::kMixed)
return false;
return IsGeneric(prev) && IsAsync(curr);
}
void VisitStmt_(const BufferStoreNode *op) final {
Proxy proxy = Proxy::kGeneric;
SetProxy(op, proxy);
inline bool IsFenceCall(const CallNode *call) {
return call && call->op.same_as(fence_proxy_async());
}
// Identify async intrinsics emitted by TileLang or TVM that require a fence
// when they follow generic proxies.
bool IsAsyncIntrinsic(const CallNode *call) {
if (call == nullptr) {
return false;
}
void VisitStmt_(const SeqStmtNode *op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetProxy(op->seq[0]);
for (auto stmt : op->seq) {
if (role != GetProxy(stmt)) {
role = Proxy::kBoth;
break;
// TileLang async intrinsics
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) ||
call->op.same_as(tma_store()) || call->op.same_as(tma_store_arrive()) ||
call->op.same_as(tma_store_wait()) ||
call->op.same_as(ptx_cp_async_barrier_noinc()) ||
call->op.same_as(ptx_wgmma_ss()) || call->op.same_as(ptx_wgmma_rs())) {
return true;
}
// PTX async copy intrinsics
if (call->op.same_as(builtin::ptx_cp_async()) ||
call->op.same_as(builtin::ptx_cp_async_barrier()) ||
call->op.same_as(builtin::ptx_cp_async_bulk())) {
return true;
}
SetProxy(op, role);
return false;
}
// Known ops that must be treated as generic proxies (e.g. ldmatrix/stmatrix).
bool IsKnownGeneric(const CallNode *call) {
if (call == nullptr) {
return false;
}
return call->op.same_as(ptx_ldmatrix()) || call->op.same_as(ptx_stmatrix()) ||
call->op.same_as(initialize_descriptor());
}
void VisitStmt_(const IfThenElseNode *op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetProxy(op->then_case);
if (op->else_case.defined()) {
auto role_else = GetProxy(op->else_case.value());
if (role != role_else)
role = Proxy::kBoth;
ProxyKind ProxyFromAttrValue(const ObjectRef &value) {
if (const auto *str = value.as<StringImmNode>()) {
if (str->value == "async") {
return ProxyKind::kAsync;
}
SetProxy(op, role);
if (str->value == "generic") {
return ProxyKind::kGeneric;
}
void VisitStmt_(const BlockRealizeNode *op) final {
StmtVisitor::VisitStmt_(op);
SetProxy(op, GetProxy(op->block));
if (str->value == "neutral") {
return ProxyKind::kNeutral;
}
template <class NodeType> void HandleBodyStmt(const NodeType *op) {
StmtVisitor::VisitStmt_(op);
SetProxy(op, GetProxy(op->body));
}
return ProxyKind::kUnknown;
}
void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
private:
void SetProxy(const StmtNode *stmt, Proxy proxy) { map_[stmt] = proxy; }
std::unordered_map<const StmtNode *, Proxy> map_;
};
// TMA stores must be followed by the arrive/wait pair. We rewrite them as part
// of the pass to guarantee the proper synchronization semantics.
class TMAStoreSyncInjector : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc f) {
auto T = TMAStoreSyncInjector();
f.CopyOnWrite()->body = T(f->body);
static PrimFunc Apply(PrimFunc f) {
if (!f->body.defined()) {
return f;
}
auto injector = TMAStoreSyncInjector();
f.CopyOnWrite()->body = injector(f->body);
return f;
}
private:
Stmt operator()(const Stmt &stmt) { return StmtExprMutator::VisitStmt(stmt); }
Stmt VisitStmt_(const EvaluateNode *op) final {
if (auto call = op->value.as<CallNode>()) {
Stmt mutated = StmtExprMutator::VisitStmt_(op);
const auto *node = mutated.as<EvaluateNode>();
if (const auto *call = node->value.as<CallNode>()) {
if (call->op.same_as(tma_store())) {
Array<Stmt> new_body;
new_body.push_back(GetRef<Evaluate>(op));
new_body.push_back(
Array<Stmt> seq;
seq.push_back(mutated);
seq.push_back(
Evaluate(Call(DataType::Handle(), tma_store_arrive(), {})));
new_body.push_back(
Evaluate(Call(DataType::Handle(), tma_store_wait(), {})));
return SeqStmt(std::move(new_body));
seq.push_back(Evaluate(Call(DataType::Handle(), tma_store_wait(), {})));
return SeqStmt(std::move(seq));
}
}
return StmtExprMutator::VisitStmt_(op);
return mutated;
}
};
class InjectFenceProxy : public StmtExprMutator {
// Main pass: track the proxy state while walking the IR and inject fences when
// switching from generic to async proxies.
class ProxyFenceInjector : public StmtMutator {
public:
static PrimFunc Substitute(PrimFunc f) {
auto T = InjectFenceProxy();
f.CopyOnWrite()->body = T(f->body);
static PrimFunc Apply(PrimFunc f) {
if (!f->body.defined()) {
return f;
}
ProxyFenceInjector injector;
f.CopyOnWrite()->body = injector.VisitStmt(f->body);
return f;
}
private:
Proxy get_generic_proxy(const Stmt &stmt) {
auto marker = ProxyMarker();
marker(stmt);
return marker.GetProxy(stmt);
Stmt VisitStmt_(const SeqStmtNode *op) final {
Array<Stmt> seq;
seq.reserve(op->seq.size());
ProxyKind sequence_kind = ProxyKind::kUnknown;
ProxyKind prev_kind = ProxyKind::kUnknown;
for (const Stmt &stmt : op->seq) {
Stmt new_stmt = VisitStmt(stmt);
ProxyKind current_kind = GetProxyKind(new_stmt);
if (!seq.empty() && NeedsFence(prev_kind, current_kind)) {
Stmt fence = MakeFenceStmt();
seq.push_back(fence);
prev_kind = GetProxyKind(fence);
}
Stmt VisitStmt_(const SeqStmtNode *op) final {
ICHECK(!op->seq.empty());
Array<Stmt> new_body;
Proxy cur_proxy, prev_proxy;
auto fence_stmt =
Evaluate(Call(DataType::Handle(), fence_proxy_async(), {}));
prev_proxy = get_generic_proxy(op->seq[0]);
new_body.push_back(VisitStmt(op->seq[0]));
if (op->seq.size() > 1) {
for (int i = 1; i < static_cast<int>(op->seq.size()); i++) {
cur_proxy = get_generic_proxy(op->seq[i]);
if (cur_proxy == Proxy::kAsync && prev_proxy == Proxy::kGeneric) {
new_body.push_back(fence_stmt);
seq.push_back(new_stmt);
sequence_kind = CombineProxy(sequence_kind, current_kind);
prev_kind = current_kind;
}
Stmt result = seq.size() == 1 ? seq[0] : SeqStmt(std::move(seq));
SetProxyKind(result, sequence_kind);
return result;
}
Stmt VisitStmt_(const EvaluateNode *op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
const auto *evaluate = stmt.as<EvaluateNode>();
ProxyKind kind = ProxyKind::kGeneric;
if (const auto *call = evaluate->value.as<CallNode>()) {
if (IsFenceCall(call)) {
kind = ProxyKind::kNeutral;
} else if (IsAsyncIntrinsic(call)) {
kind = ProxyKind::kAsync;
} else if (IsKnownGeneric(call)) {
kind = ProxyKind::kGeneric;
} else {
// Treat unknown externs as async to avoid missing required fences.
kind = ProxyKind::kAsync;
}
}
SetProxyKind(stmt, kind);
return stmt;
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
SetProxyKind(stmt, ProxyKind::kGeneric);
return stmt;
}
Stmt VisitStmt_(const IfThenElseNode *op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
const auto *node = stmt.as<IfThenElseNode>();
ProxyKind kind = GetProxyKind(node->then_case);
if (node->else_case.defined()) {
kind = CombineProxy(kind, GetProxyKind(node->else_case.value()));
}
SetProxyKind(stmt, kind);
return stmt;
}
new_body.push_back(VisitStmt(op->seq[i]));
prev_proxy = cur_proxy;
Stmt VisitStmt_(const AttrStmtNode *op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
const auto *node = stmt.as<AttrStmtNode>();
ProxyKind body_kind = GetProxyKind(node->body);
SetProxyKind(stmt, body_kind);
return stmt;
}
Stmt VisitStmt_(const BlockRealizeNode *op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
const auto *node = stmt.as<BlockRealizeNode>();
SetProxyKind(stmt, GetProxyKind(node->block));
return stmt;
}
ICHECK(!new_body.empty());
return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
Stmt VisitStmt_(const BlockNode *op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
const auto *node = stmt.as<BlockNode>();
ProxyKind kind = ProxyKind::kUnknown;
if (node->init.defined()) {
kind = CombineProxy(kind, GetProxyKind(node->init.value()));
}
kind = CombineProxy(kind, GetProxyKind(node->body));
SetProxyKind(stmt, kind);
return stmt;
}
// Stmt VisitStmt_(const ForNode* op) final {
// std::cout << "ForNode:" << op->body->GetTypeKey() << std::endl;
// return StmtExprMutator::VisitStmt_(op);
// }
Stmt VisitStmt_(const ForNode *op) final { return VisitSingleBody(op); }
Stmt VisitStmt_(const LetStmtNode *op) final { return VisitSingleBody(op); }
Stmt VisitStmt_(const AssertStmtNode *op) final {
return VisitSingleBody(op);
}
Stmt VisitStmt_(const WhileNode *op) final { return VisitSingleBody(op); }
template <typename NodeType> Stmt VisitSingleBody(const NodeType *op) {
Stmt stmt = StmtMutator::VisitStmt_(op);
const auto *node = stmt.as<NodeType>();
ProxyKind body_kind = GetProxyKind(node->body);
SetProxyKind(stmt, body_kind);
return stmt;
}
void SetProxyKind(const Stmt &stmt, ProxyKind kind) {
proxy_map_[stmt.get()] = kind;
}
InjectFenceProxy() = default;
ProxyKind GetProxyKind(const Stmt &stmt) const {
if (!stmt.defined()) {
return ProxyKind::kUnknown;
}
auto it = proxy_map_.find(stmt.get());
if (it == proxy_map_.end()) {
return ProxyKind::kUnknown;
}
return it->second;
}
Stmt MakeFenceStmt() {
Stmt fence = Evaluate(Call(DataType::Handle(), fence_proxy_async(), {}));
SetProxyKind(fence, ProxyKind::kNeutral);
return fence;
}
std::unordered_map<const StmtNode *, ProxyKind> proxy_map_;
};
using namespace tir::transform;
} // namespace
tvm::transform::Pass InjectFenceProxy() {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
f = TMAStoreSyncInjector::Substitute(f);
return InjectFenceProxy::Substitute(f);
auto pass_func = [](PrimFunc f, const IRModule &, const PassContext &) {
f = TMAStoreSyncInjector::Apply(f);
f = ProxyFenceInjector::Apply(f);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy", {});
return tir::transform::CreatePrimFuncPass(pass_func, 0, "tl.InjectFenceProxy",
{});
}
TVM_FFI_STATIC_INIT_BLOCK({
......
......@@ -927,8 +927,8 @@ private:
original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
};
for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
const auto *nested_block_realize =
pipeline_body_seq->seq[i].as<BlockRealizeNode>();
const Stmt &child = pipeline_body_seq->seq[i];
const auto *nested_block_realize = child.as<BlockRealizeNode>();
if (nested_block_realize && is_one(nested_block_realize->predicate) &&
nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
const Block &nested_pipeline_block = nested_block_realize->block;
......@@ -938,13 +938,8 @@ private:
pipeline_allocs.push_back(buffer);
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
const auto *nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
for (size_t j = 0; j < nested_seq->seq.size(); j++) {
f_add_child(nested_seq->seq[j]);
}
} else {
f_add_child(pipeline_body_seq->seq[i]);
}
f_add_child(child);
}
auto pipeline_stages = Downcast<Array<Integer>>(
......
......@@ -53,5 +53,175 @@ def test_lower_fence_proxy():
_check(before, after)
def test_async_to_generic_no_double_fence():
@T.prim_func
def before():
with T.Kernel(8):
A_shared = T.decl_buffer((1024,), "uint8", scope="shared.dyn")
B_shared = T.decl_buffer((1024,), "uint8", scope="shared.dyn")
T.ptx_cp_async("uint8", A_shared.data, 0, B_shared.data, 0, 16)
T.fence_proxy_async()
T.call_extern("handle", "generic_op")
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
mod = tl.transform.InjectFenceProxy()(mod)
def _count_fences(stmt):
count = 0
def visit(node):
nonlocal count
if isinstance(node, tir.Evaluate):
call = node.value
if isinstance(call, tir.Call):
op = call.op
name = getattr(op, "name", None)
if name == "tl.fence_proxy_async":
count += 1
tir.stmt_functor.post_order_visit(stmt, visit)
return count
assert _count_fences(mod["main"].body) == 1
def test_proxy_hint_override():
@T.prim_func
def before():
with T.Kernel(8):
T.evaluate(T.call_extern("handle", "custom_async"))
with T.attr("proxy_scope", "tl.proxy_hint", "neutral"):
T.evaluate(T.call_extern("handle", "custom_generic"))
T.evaluate(T.call_extern("handle", "custom_async_tail"))
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
mod = tl.transform.InjectFenceProxy()(mod)
def _has_fence(stmt):
result = False
def visit(node):
nonlocal result
if isinstance(node, tir.Evaluate):
call = node.value
if isinstance(call, tir.Call):
op = call.op
name = getattr(op, "name", None)
if name == "tl.fence_proxy_async":
result = True
tir.stmt_functor.post_order_visit(stmt, visit)
return result
assert not _has_fence(mod["main"].body)
def test_tma_store_sync_injection():
@T.prim_func
def before():
with T.Kernel(8):
A_global = T.decl_buffer((128,), "float16", scope="global")
T.evaluate(T.call_intrin("handle", tir.op.Op.get("tl.tma_store"), A_global.data))
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
mod = tl.transform.InjectFenceProxy()(mod)
arrives = 0
waits = 0
def visit(node):
nonlocal arrives, waits
if isinstance(node, tir.Evaluate):
call = node.value
if isinstance(call, tir.Call):
name = getattr(call.op, "name", None)
if name == "tl.tma_store_arrive":
arrives += 1
elif name in ("tl.tma_store_wait", "tl.tma_store_wait<0>"):
waits += 1
tir.stmt_functor.post_order_visit(mod["main"].body, visit)
assert arrives == 1
assert waits == 1
def test_wgmma_marked_async():
@T.prim_func
def before():
with T.Kernel(1):
A_shared = T.decl_buffer((1,), "float16", scope="shared")
desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor")
desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor")
C_local = T.decl_buffer((32,), "float16", scope="local")
A_shared[0] = T.float16(0)
T.warpgroup_arrive()
T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
"fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
T.int32(0), T.bool(True), 1, 1)
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
mod = tl.transform.InjectFenceProxy()(mod)
order = []
def visit(node):
if isinstance(node, tir.Evaluate):
call = node.value
if isinstance(call, tir.Call):
order.append(getattr(call.op, "name", ""))
tir.stmt_functor.post_order_visit(mod["main"].body, visit)
assert "tl.ptx_wgmma_ss" in order
assert "tl.fence_proxy_async" in order
assert order.index("tl.fence_proxy_async") < order.index("tl.ptx_wgmma_ss")
def test_wgmma_after_descriptor():
@T.prim_func
def before():
with T.Kernel(1):
desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor")
desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor")
C_local = T.decl_buffer((32,), "float16", scope="local")
T.initialize_descriptor(desc_a, T.uint64(0), 2, 1, 32)
T.initialize_descriptor(desc_b, T.uint64(0), 2, 1, 32)
T.warpgroup_arrive()
T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
"fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
T.int32(0), T.bool(True), 1, 1)
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
mod = tl.transform.InjectFenceProxy()(mod)
fence_count = 0
order = []
def visit(node):
nonlocal fence_count
if isinstance(node, tir.Evaluate):
call = node.value
if isinstance(call, tir.Call):
name = getattr(call.op, "name", "")
order.append(name)
if name == "tl.fence_proxy_async":
fence_count += 1
tir.stmt_functor.post_order_visit(mod["main"].body, visit)
assert fence_count >= 1
assert "tl.warpgroup_arrive" in order
assert order.index("tl.fence_proxy_async") < order.index("tl.warpgroup_arrive")
if __name__ == "__main__":
test_lower_fence_proxy()
tilelang.testing.main()
......@@ -156,7 +156,12 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
if allow_fence_proxy(target=target):
# in hopper device, wgmma is an async proxy
# so we need to inject a fence proxy before it
print("Before injectFenceProxy")
print(mod)
mod = tilelang.transform.InjectFenceProxy()(mod)
print("After InjectFenceProxy")
print(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tilelang.transform.FlattenBuffer()(mod)
......
......@@ -242,12 +242,14 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
@T.macro
def _warp_mma(A_buf, B_buf, C_local_buf):
# TODO(lei): inject warpgroup_fence_operand for C_local_buf
desc_a = T.alloc_descriptor()
desc_b = T.alloc_descriptor()
T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode,
int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
T.warpgroup_arrive()
for ki in T.serial(0, (k_dim // micro_size_k)):
for i in T.serial(m_dim // 64):
A_offset = (ki % ak_atom_size) * micro_size_k + i * 64 * a_swizzle_atom_elems + (
......@@ -262,6 +264,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
(A_offset * elems_in_bytes) >> 4, desc_b.data,
(B_offset * elems_in_bytes) >> 4, C_local_buf.data, C_offset,
scale_out, scale_in_a, scale_in_b)
T.warpgroup_commit_batch()
T.warpgroup_wait(0)
return _warp_mma(A_buf, B_buf, C_local_buf)
......
......@@ -249,6 +249,37 @@ def mbarrier_expect_tx(*args):
return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_expect_tx"), *args)
def warpgroup_arrive():
"""Signal warpgroup readiness for subsequent WGMMA operations.
Returns:
tir.Call: A handle to the warpgroup arrive operation.
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_arrive"))
def warpgroup_commit_batch():
"""Commit the current warpgroup batch for WGMMA operations.
Returns:
tir.Call: A handle to the warpgroup commit batch operation.
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_commit_batch"))
def warpgroup_wait(num_mma: int):
"""Wait for completion of the specified warpgroup batch.
Args:
num_mma: int
Identifier of the warpgroup MMA batch to wait on.
Returns:
tir.Call: A handle to the warpgroup wait operation.
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma)
def wait_wgmma(id: int):
"""Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
......