"tests/L0/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "3490b9e1a26ba94d6939c63171dc6a0f083793aa"
Commit 6972aed7 authored by Lei Wang, committed by LeiWang1999

[Language] Support explicit programming for identified warp groups (#445)

* [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.

* [Refactor] Rename operations for consistency in lower_hopper_intrin and related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.

* [Refactor] Rename operations to snake_case for consistency

* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others (a before/after sketch follows this list).
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.
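For illustration, a minimal before/after sketch of the renaming (assuming a TileLang build that includes this change; the barrier counts are placeholders):

```python
import tilelang.language as T

# Before this change (CamelCase wrappers):
#   T.CreateListofMBarrierOp(128, 128)
#   T.GetMBarrierOp(0)
#   T.FenceProxyAsyncOp()

# After this change (snake_case wrappers, same arguments):
T.create_list_of_mbarrier(128, 128)  # one mbarrier per argument, each arrived at by 128 threads
T.get_mbarrier(0)                    # handle to mbarrier #0
T.fence_proxy_async()                # fence for asynchronous proxy operations
```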

* [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang (a condensed sketch follows this list).
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.
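A condensed sketch of the pattern the new `gemm_ws.py` example demonstrates (the context-manager form of `T.ws` is inferred from the `WarpSpecializeFrame` definition further down; buffer shapes, stage counts, and the copy/gemm bodies are illustrative placeholders, not the example's exact code):

```python
import tilelang.language as T

@T.prim_func
def gemm_ws(A: T.Buffer((512, 512), "float16"),
            B: T.Buffer((512, 512), "float16"),
            C: T.Buffer((512, 512), "float32")):
    # Two warp groups of 128 threads each: group 0 produces, group 1 consumes.
    with T.Kernel(8, 8, threads=256) as (bx, by):
        A_shared = T.alloc_shared((64, 32), "float16")
        B_shared = T.alloc_shared((32, 64), "float16")
        C_local = T.alloc_fragment((64, 64), "float32")
        T.create_list_of_mbarrier(128, 128)  # barrier 0: data ready; barrier 1: buffer free
        with T.ws(0):  # producer warp group
            for ko in range(16):
                T.mbarrier_wait_parity(1, ko ^ 1)  # wait until consumer freed the buffer
                T.copy(A[by * 64, ko * 32], A_shared)
                T.copy(B[ko * 32, bx * 64], B_shared)
                T.mbarrier_arrive(0)  # signal: data ready
        with T.ws(1):  # consumer warp group
            for ko in range(16):
                T.mbarrier_wait_parity(0, ko)  # wait until producer data arrived
                T.gemm(A_shared, B_shared, C_local)
                T.mbarrier_arrive(1)  # signal: buffer may be reused
            T.copy(C_local, C[by * 64, bx * 64])
```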

* [Refactor] Clean up code formatting and improve readability

* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.

* lint fix

* [Refactor] Update mbarrier functions for improved clarity and consistency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability (see the usage sketch after this list).
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.
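For reference, the new signatures in use (adapted from the docstring examples added to `builtin.py` in this commit; meant to appear inside a kernel body):

```python
import tilelang.language as T

T.mbarrier_wait_parity(0, 0)       # wait for parity 0 on barrier 0 (ints resolve via get_mbarrier)
barrier = T.get_mbarrier(1)        # or pass a barrier handle explicitly
T.mbarrier_wait_parity(barrier, 1)
T.mbarrier_arrive(0)               # arrive at barrier 0
```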

* Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` for improved code readability and separation of logical sections.
parent 0fa03398
@@ -295,9 +295,10 @@ private:
auto const_int_bound = analyzer_->const_int_bound(thread_var_);
auto min_value = const_int_bound->min_value;
auto max_value = const_int_bound->max_value;
auto extent = max_value + 1 - min_value;
thread_bounds =
Range::FromMinExtent(IntImm(thread_var_->var.dtype(), min_value),
IntImm(thread_var_->var.dtype(), max_value + 1));
IntImm(thread_var_->var.dtype(), extent));
} else {
thread_bounds = Range::FromMinExtent(0, 1);
}
@@ -53,8 +53,7 @@ public:
void VisitStmt_(const EvaluateNode *op) final {
Role role = Role::kConsumer;
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(TMALoadOp()) ||
call->op.same_as(TMALoadIm2ColOp())) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
role = Role::kProducer;
has_bulk_copy_ = true;
}
@@ -328,8 +328,8 @@ public:
if (partial_syncs_.count(stmt.get())) {
auto iter = partial_syncs_.find(stmt.get());
ICHECK(sync_scope_.rank == StorageRank::kShared);
barrier = Evaluate(Call(DataType::Int(32), tl::SyncThreadsPartialOp(),
{iter->second}));
barrier = Evaluate(
Call(DataType::Int(32), tl::sync_thread_partial(), {iter->second}));
} else {
return StmtExprMutator::VisitStmt(stmt);
}
@@ -43,7 +43,7 @@ public:
void clear() { has_tma_load_ = false; }
void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
has_tma_load_ = true;
}
}
@@ -116,8 +116,7 @@ public:
void VisitStmt_(const EvaluateNode *op) final {
Role role = Role::kConsumer;
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(TMALoadOp()) ||
call->op.same_as(TMALoadIm2ColOp())) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
role = Role::kProducer;
has_bulk_copy_ = true;
}
@@ -207,11 +206,11 @@ private:
};
static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
return Call(DataType::Handle(), GetMBarrierOp(), {barrier_id});
return Call(DataType::Handle(), get_mbarrier(), {barrier_id});
}
static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
auto call = Call(DataType::Handle(), MBarrierExpectTX(),
auto call = Call(DataType::Handle(), mbarrier_expect_tx(),
{makeGetBarrier(barrier_id), bytes});
return Evaluate(call);
}
@@ -229,7 +228,7 @@ static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
}
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
auto call = Call(DataType::Handle(), MBarrierWaitParity(),
auto call = Call(DataType::Handle(), mbarrier_wait_parity(),
{makeGetBarrier(barrier_id), parity});
return Evaluate(call);
}
@@ -273,8 +272,7 @@ private:
Stmt VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(TMALoadOp()) ||
call->op.same_as(TMALoadIm2ColOp())) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
contain_tma_load_ = true;
if (insert_in_evaluate_) {
Array<Stmt> new_seq = {expect_tx_, GetRef<Evaluate>(op)};
@@ -308,7 +306,7 @@ public:
private:
void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
int type_bytes = access_ptr->args[0]->dtype.bytes();
@@ -361,7 +359,7 @@ public:
private:
PrimExpr VisitExpr_(const CallNode *op) final {
auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
call.CopyOnWrite()->args.Set(1, makeGetBarrier(producer_barrier_idx_));
@@ -1082,7 +1080,7 @@ public:
private:
void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(SetMaxNReg())) {
if (call->op.same_as(set_max_nreg())) {
int reg_hint = call->args[0].as<IntImmNode>()->value;
int is_inc = call->args[1].as<IntImmNode>()->value;
ICHECK(reg_hint <= 240 && reg_hint >= 24)
@@ -1092,7 +1090,7 @@ private:
// producer should decrease register hint while consumer should increase
// register hint
nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
} else if (call->op.same_as(NoSetMaxNReg())) {
} else if (call->op.same_as(no_set_max_nreg())) {
has_no_set_max_nreg_ = true;
}
}
@@ -1149,7 +1147,8 @@ private:
Stmt VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(SetMaxNReg()) || call->op.same_as(NoSetMaxNReg())) {
if (call->op.same_as(set_max_nreg()) ||
call->op.same_as(no_set_max_nreg())) {
return Evaluate(0);
}
}
@@ -1202,7 +1201,7 @@ private:
barrier_num_threads.push_back(arrive_thread_count);
}
Stmt init_barrier = Evaluate(Call(
DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
block.CopyOnWrite()->body = SeqStmt({init_barrier, code});
block_realize.CopyOnWrite()->block = block;
return block_realize;
@@ -1224,9 +1223,9 @@ private:
auto inc_reg_stmt = Evaluate(0);
auto dec_reg_stmt = Evaluate(0);
if (dec_reg >= 0 && inc_reg >= 0) {
inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(),
inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
{inc_reg == 0 ? 240 : inc_reg, 1}));
dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(),
dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
{dec_reg == 0 ? 24 : dec_reg, 0}));
}
@@ -1252,7 +1251,7 @@ private:
}
Stmt init_barrier = Evaluate(Call(
DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent),
producer_code, consumer_code);
// Add an attr here to handle the partial thread count in ThreadSync pass.
@@ -44,7 +44,7 @@ def test_lower_fence_proxy():
C_local = T.decl_buffer((32,), scope="local")
for i in T.unroll(16):
C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
T.FenceProxyAsyncOp()
T.fence_proxy_async()
T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
@@ -29,7 +29,7 @@ def test_lower_hopper_intrin_barrier():
def before():
with T.Kernel(8):
_ = T.launch_thread("threadIdx.x", 128)
T.CreateListofMBarrierOp(128, 128, 128, 128)
T.create_list_of_mbarrier(128, 128, 128, 128)
@T.prim_func
def after():
@@ -39,16 +39,16 @@ def test_lower_hopper_intrin_barrier():
with T.If(v_1 == 0), T.Then():
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.GetMBarrierOp(0), 128]))
[T.get_mbarrier(0), 128]))
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.GetMBarrierOp(1), 128]))
[T.get_mbarrier(1), 128]))
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.GetMBarrierOp(2), 128]))
[T.get_mbarrier(2), 128]))
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.GetMBarrierOp(3), 128]))
[T.get_mbarrier(3), 128]))
T.evaluate(tir.Call("handle", "tir.tvm_storage_sync", ["shared"]))
_check(before, after)
@@ -48,14 +48,14 @@ def test_multi_version_buffer():
C_local[i * 2 + vec] = T.float32(0)
for k in T.serial(16, annotations={"num_stages": 3}):
if v == 0:
T.TMALoadOp(
T.CreateTMADescriptorOp(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
2, 0), 0,
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2),
k * 32, by * 64)
if v == 0:
T.TMALoadOp(
T.CreateTMADescriptorOp(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
2, 0), 0,
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2),
bx * 64, k * 32)
@@ -81,15 +81,15 @@ def test_multi_version_buffer():
C_local[i * 2 + vec] = T.float32(0)
for k in T.serial(16, annotations={"num_stages": 3}):
if v == 0:
T.TMALoadOp(
T.CreateTMADescriptorOp(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
2, 0), 0,
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, by * 64)
if v == 0:
T.TMALoadOp(
T.CreateTMADescriptorOp(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
2, 0), 0,
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
@@ -46,15 +46,15 @@ def test_warp_specialized():
C_local = T.alloc_buffer((32,), scope="local")
for k in T.serial(16, annotations={"num_stages": 3}):
if v == 0:
T.TMALoadOp(
T.CreateTMADescriptorOp(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
2, 0), 0,
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, by * 64)
if v == 0:
T.TMALoadOp(
T.CreateTMADescriptorOp(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
2, 0), 0,
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
@@ -75,35 +75,35 @@ def test_warp_specialized():
A_shared = T.decl_buffer((3, 1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.decl_buffer((3, 1, 4, 512), "float16", scope="shared.dyn")
C_local = T.decl_buffer((32,), scope="local")
T.CreateListofMBarrierOp(128, 128, 128, 128, 128, 128)
T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128)
T.attr([128, 128], "kWarpSpecializationScope", 0)
if v >= 128:
T.SetMaxNReg(24, 0)
T.set_max_nreg(24, 0)
for k in range(16):
T.MBarrierWaitParity(T.GetMBarrierOp(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
if v - 128 == 0:
T.MBarrierExpectTX(T.GetMBarrierOp(k % 3), 4096)
T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
if v - 128 == 0:
T.TMALoadOp(
T.CreateTMADescriptorOp(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
2, 0), T.GetMBarrierOp(k % 3),
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
2, 0), T.get_mbarrier(k % 3),
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, by * 64)
if v - 128 == 0:
T.MBarrierExpectTX(T.GetMBarrierOp(k % 3), 4096)
T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
if v - 128 == 0:
T.TMALoadOp(
T.CreateTMADescriptorOp(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
2, 0), T.GetMBarrierOp(k % 3),
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
2, 0), T.get_mbarrier(k % 3),
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64, k * 32)
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.GetMBarrierOp(k % 3)]))
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
else:
T.SetMaxNReg(240, 1)
T.set_max_nreg(240, 1)
for k in range(16):
T.MBarrierWaitParity(T.GetMBarrierOp(k % 3), k // 3 % 2)
T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
T.call_extern(
"handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(
@@ -112,7 +112,7 @@ def test_warp_specialized():
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
T.evaluate(
tir.Call("handle", "tir.ptx_arrive_barrier", [T.GetMBarrierOp(k % 3 + 3)]))
tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
_check(before, after)
@@ -9,7 +9,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
target: Optional[Target] = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
if target.arch not in {"sm_90"}:
if target.arch not in {"sm_90", "sm_90a"}:
return False
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
@@ -17,6 +17,10 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
return not (disable_tma_lower and disable_warp_specialized)
def allow_fence_proxy(target: Optional[Target] = None) -> bool:
return target.arch in {"sm_90", "sm_90a"}
def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
@@ -60,6 +64,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# to get better performance with async copy
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
# warp_specialized pass will pack the if stmt into the block
# so we need to lower the opaque block first
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
mod = tilelang.transform.RewriteWgmmaSync()(mod)
@@ -71,6 +77,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
if allow_fence_proxy(target=target):
# in hopper device, wgmma is an async proxy
# so we need to inject a fence proxy before it
mod = tilelang.transform.InjectFenceProxy()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
@@ -104,6 +115,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
@@ -259,6 +259,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
result = self.lib.init()
if result != 0:
error_msg = self.lib.get_last_error().decode('utf-8')
error_msg += f"\n{self.lib_code}"
raise RuntimeError(f"Initialization failed: {error_msg}")
self.cython_wrapper = CythonKernelWrapper(self.result_idx, self.params, self.lib)
@@ -32,6 +32,7 @@ from .kernel import (
get_block_binding, # noqa: F401
get_block_bindings, # noqa: F401
)
from .warpgroup import ws # noqa: F401
from .allocate import (
alloc_local, # noqa: F401
alloc_shared, # noqa: F401
"""The language interface for tl programs."""
from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier
from tvm import tir
from typing import Union
from tvm.tir import PrimExpr, Var
def CreateListofMBarrierOp(*args):
def create_list_of_mbarrier(*args):
"""Create a list of memory barrier operations.
Args:
@@ -12,10 +16,10 @@ def CreateListofMBarrierOp(*args):
Returns:
tir.Call: A handle to the created list of memory barriers
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.CreateListofMBarrierOp"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), *args)
def GetMBarrierOp(*args):
def get_mbarrier(*args):
"""Retrieve a memory barrier operation.
Args:
@@ -24,10 +28,10 @@ def GetMBarrierOp(*args):
Returns:
tir.Call: A handle to the requested memory barrier
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.GetMBarrierOp"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), *args)
def CreateTMADescriptorOp(*args):
def create_tma_descriptor(*args):
"""Create a Tensor Memory Access (TMA) descriptor.
Args:
@@ -36,10 +40,10 @@ def CreateTMADescriptorOp(*args):
Returns:
tir.Call: A handle to the created TMA descriptor
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.CreateTMADescriptorOp"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.create_tma_descriptor"), *args)
def TMALoadOp(*args):
def tma_load(*args):
"""Perform a Tensor Memory Access (TMA) load operation.
Args:
@@ -48,10 +52,10 @@ def TMALoadOp(*args):
Returns:
tir.Call: A handle to the TMA load operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.TMALoadOp"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.tma_load"), *args)
def FenceProxyAsyncOp(*args):
def fence_proxy_async(*args):
"""Create a fence for asynchronous proxy operations.
Args:
@@ -60,10 +64,10 @@ def FenceProxyAsyncOp(*args):
Returns:
tir.Call: A handle to the fence operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.FenceProxyAsyncOp"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.fence_proxy_async"), *args)
def TMAStoreArrive(*args):
def tma_store_arrive(*args):
"""Signal the arrival of a TMA store operation.
Args:
@@ -72,10 +76,10 @@ def TMAStoreArrive(*args):
Returns:
tir.Call: A handle to the store arrive operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.TMAStoreArrive"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.tma_store_arrive"), *args)
def TMAStoreWait(*args):
def tma_store_wait(*args):
"""Wait for completion of TMA store operations.
Args:
@@ -84,10 +88,10 @@ def TMAStoreWait(*args):
Returns:
tir.Call: A handle to the store wait operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.TMAStoreWait"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.tma_store_wait"), *args)
def SetMaxNReg(*args):
def set_max_nreg(*args):
"""Set the maximum number of registers to use.
Args:
@@ -96,10 +100,10 @@ def SetMaxNReg(*args):
Returns:
tir.Call: A handle to the register setting operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.SetMaxNReg"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.set_max_nreg"), *args)
def NoSetMaxNReg(*args):
def no_set_max_nreg(*args):
"""Disable the maximum register limit setting.
Args:
@@ -108,22 +112,66 @@ def NoSetMaxNReg(*args):
Returns:
tir.Call: A handle to the register limit disable operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.NoSetMaxNReg"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"), *args)
def MBarrierWaitParity(*args):
def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr], parity: Union[int, Var]):
"""Wait for memory barrier parity condition.
Args:
*args: Variable arguments specifying the parity wait condition
mbarrier: Union[int, PrimExpr]
The memory barrier to wait on
parity: Union[int, Var]
The parity value to wait for
Examples:
.. code-block:: python
# Wait for parity 0 on barrier 0
T.mbarrier_wait_parity(0, 0)
# Wait for parity value in variable ko on barrier 1
T.mbarrier_wait_parity(1, ko)
# Wait using barrier handle
barrier = T.get_mbarrier(0)
T.mbarrier_wait_parity(barrier, 1)
# Common usage in pipelined kernels:
for ko in range(num_stages):
# Producer waits for consumer to finish previous iteration
T.mbarrier_wait_parity(1, ko ^ 1)
# Producer copies data
T.copy(A_global, A_shared)
# Producer signals data ready
T.mbarrier_arrive(0)
# Consumer waits for producer data
T.mbarrier_wait_parity(0, ko)
# Consumer computes
T.gemm(A_shared, B_shared, C_local)
# Consumer signals completion
T.mbarrier_arrive(1)
Returns:
tir.Call: A handle to the barrier wait operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.MBarrierWaitParity"), *args)
if isinstance(mbarrier, int):
mbarrier = get_mbarrier(mbarrier)
return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity)
def mbarrier_arrive(mbarrier: Union[int, PrimExpr]):
"""Arrive at memory barrier.
Args:
mbarrier: Union[int, PrimExpr]
The memory barrier to arrive at
"""
if isinstance(mbarrier, int):
mbarrier = get_mbarrier(mbarrier)
return ptx_arrive_barrier(mbarrier)
def MBarrierExpectTX(*args):
def mbarrier_expect_tx(*args):
"""Set expected transaction count for memory barrier.
Args:
@@ -132,10 +180,10 @@ def MBarrierExpectTX(*args):
Returns:
tir.Call: A handle to the barrier expectation operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.MBarrierExpectTX"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_expect_tx"), *args)
def WaitWgmma(*args):
def wait_wgmma(*args):
"""Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
Args:
@@ -144,4 +192,4 @@ def WaitWgmma(*args):
Returns:
tir.Call: A handle to the WGMMA wait operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.WaitWgmma"), *args)
return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), *args)
@@ -128,6 +128,12 @@ class KernelLaunchFrame(TIRFrame):
iter_var = self.frames[dim].iter_var
return int(iter_var.dom.extent)
def get_block_extents(self) -> List[int]:
"""
Returns the block extents for all three dimensions.
"""
return [self.get_block_extent(dim) for dim in range(3)]
def get_thread_extent(self, dim: int) -> int:
"""
Returns the thread extent for the given dimension.
@@ -136,6 +142,12 @@
iter_var = self.frames[-4 + dim].iter_var
return int(iter_var.dom.extent)
def get_thread_extents(self) -> List[int]:
"""
Returns the thread extents for all three dimensions.
"""
return [self.get_thread_extent(dim) for dim in range(3)]
def get_thread_binding(self, dim: int = 0) -> Var:
"""
Returns the thread binding for the given dimension.
@@ -268,3 +280,27 @@ def get_block_bindings() -> List[Var]:
"""Returns all three block bindings.
"""
return KernelLaunchFrame.Current().get_block_bindings()
def get_thread_extent(dim: int = 0) -> int:
"""Returns the thread extent for the given dimension.
"""
return KernelLaunchFrame.Current().get_thread_extent(dim)
def get_thread_extents() -> List[int]:
"""Returns all three thread extents.
"""
return KernelLaunchFrame.Current().get_thread_extents()
def get_block_extent(dim: int = 0) -> int:
"""Returns the block extent for the given dimension.
"""
return KernelLaunchFrame.Current().get_block_extent(dim)
def get_block_extents() -> List[int]:
"""Returns all three block extents.
"""
return KernelLaunchFrame.Current().get_block_extents()
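A minimal sketch of the new extent helpers in use (assuming they are re-exported from `tilelang.language` like the existing binding helpers; the kernel dimensions are placeholders):

```python
import tilelang.language as T

with T.Kernel(8, 4, threads=128):
    # Grid extents per dimension, here [8, 4, 1]
    bx_ext, by_ext, bz_ext = T.get_block_extents()
    # Thread extents per dimension, here [128, 1, 1]
    tx_ext, ty_ext, tz_ext = T.get_thread_extents()
```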
"""The language interface for tl programs."""
from tvm.script.ir_builder.tir.frame import TIRFrame
from tvm._ffi import register_object
from tilelang import _ffi_api
from .kernel import get_thread_bindings, get_thread_extents
@register_object("tl.WarpSpecializeFrame")
class WarpSpecializeFrame(TIRFrame):
"""
WarpSpecializeFrame is a custom TIRFrame that manages warp group indices
and handles the entry and exit of the kernel launch scope.
"""
def WarpSpecialize(warp_group_idx: int):
"""Construct a warp group frame that restricts the enclosed statements to one warp group.
Parameters
----------
warp_group_idx : int
An integer identifying the warp group to specialize for.
Returns
-------
res : WarpSpecializeFrame
The frame scoping the following statements to the given warp group.
"""
id_x, id_y, id_z = get_thread_bindings()
ex_x, ex_y, _ = get_thread_extents()
tid = id_z * (ex_y * ex_x) + id_y * ex_x + id_x
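# e.g. with extents (ex_x, ex_y) = (128, 2) and bindings (id_x, id_y, id_z) = (5, 1, 0),
# tid = 0 * (2 * 128) + 1 * 128 + 5 = 133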
# A warp group is 128 threads; this is only available on NVIDIA GPUs.
warp_group_size = 128
return _ffi_api.WarpSpecialize(warp_group_idx, tid, warp_group_size)
# Alias for WarpSpecialize for more concise usage
ws = WarpSpecialize
@@ -314,3 +314,9 @@ def FlattenBuffer():
The result pass
"""
return _ffi_api.FlattenBuffer() # type: ignore
def EliminateStorageSyncForMBarrier():
"""EliminateStorageSyncForMBarrier
"""
return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore
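Invocation mirrors the other transform passes, as in the `OptimizeForTarget` pipeline above:

```python
mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)
```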