Commit 6972aed7 authored by Lei Wang, committed by LeiWang1999

[Language] Support explicit programming for identified warp groups (#445)

* [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.

* [Refactor] Rename operations for consistency in lower_hopper_intrin and related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.

* [Refactor] Rename operations to snake_case for consistency

* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.

* [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.

* [Refactor] Clean up code formatting and improve readability

* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.

* lint fix

* [Refactor] Update mbarrier functions for improved clarity and consistency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.

* Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` for improved code readability and separation of logical sections.
parent 0fa03398
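
To make the headline feature concrete, below is a minimal sketch of explicit programming for identified warp groups, in the spirit of the new gemm_ws.py example. It is assembled from the APIs this commit introduces (T.ws, T.create_list_of_mbarrier, T.mbarrier_wait_parity, T.mbarrier_arrive); the shapes, tile sizes, and kernel body are illustrative assumptions, not a copy of the shipped example.

    import tilelang
    import tilelang.language as T

    # Minimal sketch, assuming two warp groups: group 0 produces (copies
    # tiles into shared memory), group 1 consumes (accumulates the GEMM).
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float16"),
             B: T.Buffer((1024, 1024), "float16"),
             C: T.Buffer((1024, 1024), "float32")):
        with T.Kernel(16, 16, threads=256) as (bx, by):
            A_shared = T.alloc_shared((64, 64), "float16")
            B_shared = T.alloc_shared((64, 64), "float16")
            C_local = T.alloc_fragment((64, 64), "float32")
            # Two barriers, each arrived at by one 128-thread warp group.
            T.create_list_of_mbarrier(128, 128)
            with T.ws(0):  # producer warp group (threads 0-127)
                for ko in range(16):
                    T.mbarrier_wait_parity(1, ko ^ 1)  # wait until buffer is free
                    T.copy(A[by * 64, ko * 64], A_shared)
                    T.copy(B[ko * 64, bx * 64], B_shared)
                    T.mbarrier_arrive(0)  # signal: data ready
            with T.ws(1):  # consumer warp group (threads 128-255)
                T.clear(C_local)
                for ko in range(16):
                    T.mbarrier_wait_parity(0, ko)  # wait for producer data
                    T.gemm(A_shared, B_shared, C_local)
                    T.mbarrier_arrive(1)  # signal: buffer free
                T.copy(C_local, C[by * 64, bx * 64])
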
@@ -295,9 +295,10 @@ private:
     auto const_int_bound = analyzer_->const_int_bound(thread_var_);
     auto min_value = const_int_bound->min_value;
     auto max_value = const_int_bound->max_value;
+    auto extent = max_value + 1 - min_value;
     thread_bounds =
         Range::FromMinExtent(IntImm(thread_var_->var.dtype(), min_value),
-                             IntImm(thread_var_->var.dtype(), max_value + 1));
+                             IntImm(thread_var_->var.dtype(), extent));
   } else {
     thread_bounds = Range::FromMinExtent(0, 1);
   }
...
@@ -53,8 +53,7 @@ public:
   void VisitStmt_(const EvaluateNode *op) final {
     Role role = Role::kConsumer;
     if (auto call = op->value.as<CallNode>()) {
-      if (call->op.same_as(TMALoadOp()) ||
-          call->op.same_as(TMALoadIm2ColOp())) {
+      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
         role = Role::kProducer;
         has_bulk_copy_ = true;
       }
...
@@ -328,8 +328,8 @@ public:
   if (partial_syncs_.count(stmt.get())) {
     auto iter = partial_syncs_.find(stmt.get());
     ICHECK(sync_scope_.rank == StorageRank::kShared);
-    barrier = Evaluate(Call(DataType::Int(32), tl::SyncThreadsPartialOp(),
-                            {iter->second}));
+    barrier = Evaluate(
+        Call(DataType::Int(32), tl::sync_thread_partial(), {iter->second}));
   } else {
     return StmtExprMutator::VisitStmt(stmt);
   }
...
@@ -43,7 +43,7 @@ public:
   void clear() { has_tma_load_ = false; }
   void VisitExpr_(const CallNode *call) final {
-    if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
+    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
       has_tma_load_ = true;
     }
   }
@@ -116,8 +116,7 @@ public:
   void VisitStmt_(const EvaluateNode *op) final {
     Role role = Role::kConsumer;
     if (auto call = op->value.as<CallNode>()) {
-      if (call->op.same_as(TMALoadOp()) ||
-          call->op.same_as(TMALoadIm2ColOp())) {
+      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
         role = Role::kProducer;
         has_bulk_copy_ = true;
       }
@@ -207,11 +206,11 @@ private:
 };

 static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
-  return Call(DataType::Handle(), GetMBarrierOp(), {barrier_id});
+  return Call(DataType::Handle(), get_mbarrier(), {barrier_id});
 }

 static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
-  auto call = Call(DataType::Handle(), MBarrierExpectTX(),
+  auto call = Call(DataType::Handle(), mbarrier_expect_tx(),
                    {makeGetBarrier(barrier_id), bytes});
   return Evaluate(call);
 }
@@ -229,7 +228,7 @@ static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
 }

 static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
-  auto call = Call(DataType::Handle(), MBarrierWaitParity(),
+  auto call = Call(DataType::Handle(), mbarrier_wait_parity(),
                    {makeGetBarrier(barrier_id), parity});
   return Evaluate(call);
 }
@@ -273,8 +272,7 @@ private:
   Stmt VisitStmt_(const EvaluateNode *op) final {
     if (const CallNode *call = op->value.as<CallNode>()) {
-      if (call->op.same_as(TMALoadOp()) ||
-          call->op.same_as(TMALoadIm2ColOp())) {
+      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
         contain_tma_load_ = true;
         if (insert_in_evaluate_) {
           Array<Stmt> new_seq = {expect_tx_, GetRef<Evaluate>(op)};
@@ -308,7 +306,7 @@ public:
 private:
   void VisitExpr_(const CallNode *call) final {
-    if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
+    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
       Call access_ptr = Downcast<Call>(call->args[2]);
       ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
       int type_bytes = access_ptr->args[0]->dtype.bytes();
@@ -361,7 +359,7 @@ public:
 private:
   PrimExpr VisitExpr_(const CallNode *op) final {
     auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
-    if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
+    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
       Call access_ptr = Downcast<Call>(call->args[2]);
       ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
       call.CopyOnWrite()->args.Set(1, makeGetBarrier(producer_barrier_idx_));
@@ -1082,7 +1080,7 @@ public:
 private:
   void VisitStmt_(const EvaluateNode *op) final {
     if (const CallNode *call = op->value.as<CallNode>()) {
-      if (call->op.same_as(SetMaxNReg())) {
+      if (call->op.same_as(set_max_nreg())) {
         int reg_hint = call->args[0].as<IntImmNode>()->value;
         int is_inc = call->args[1].as<IntImmNode>()->value;
         ICHECK(reg_hint <= 240 && reg_hint >= 24)
@@ -1092,7 +1090,7 @@ private:
       // producer should decrease register hint while consumer should increase
       // register hint
       nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
-    } else if (call->op.same_as(NoSetMaxNReg())) {
+    } else if (call->op.same_as(no_set_max_nreg())) {
       has_no_set_max_nreg_ = true;
     }
   }
@@ -1149,7 +1147,8 @@ private:
   Stmt VisitStmt_(const EvaluateNode *op) final {
     if (const CallNode *call = op->value.as<CallNode>()) {
-      if (call->op.same_as(SetMaxNReg()) || call->op.same_as(NoSetMaxNReg())) {
+      if (call->op.same_as(set_max_nreg()) ||
+          call->op.same_as(no_set_max_nreg())) {
         return Evaluate(0);
       }
     }
@@ -1202,7 +1201,7 @@ private:
     barrier_num_threads.push_back(arrive_thread_count);
   }
   Stmt init_barrier = Evaluate(Call(
-      DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
+      DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
   block.CopyOnWrite()->body = SeqStmt({init_barrier, code});
   block_realize.CopyOnWrite()->block = block;
   return block_realize;
@@ -1224,9 +1223,9 @@ private:
   auto inc_reg_stmt = Evaluate(0);
   auto dec_reg_stmt = Evaluate(0);
   if (dec_reg >= 0 && inc_reg >= 0) {
-    inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(),
+    inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
                                  {inc_reg == 0 ? 240 : inc_reg, 1}));
-    dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(),
+    dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
                                  {dec_reg == 0 ? 24 : dec_reg, 0}));
   }
@@ -1252,7 +1251,7 @@ private:
   }
   Stmt init_barrier = Evaluate(Call(
-      DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
+      DataType::Handle(), create_list_of_mbarrier(), barrier_num_threads));
   Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent),
                          producer_code, consumer_code);
   // Add an attr here to handle the partial thread count in ThreadSync pass.
...
@@ -44,7 +44,7 @@ def test_lower_fence_proxy():
     C_local = T.decl_buffer((32,), scope="local")
     for i in T.unroll(16):
         C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
-    T.FenceProxyAsyncOp()
+    T.fence_proxy_async()
     T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
                   T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
                   T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
...
@@ -29,7 +29,7 @@ def test_lower_hopper_intrin_barrier():
     def before():
         with T.Kernel(8):
             _ = T.launch_thread("threadIdx.x", 128)
-            T.CreateListofMBarrierOp(128, 128, 128, 128)
+            T.create_list_of_mbarrier(128, 128, 128, 128)

     @T.prim_func
     def after():
@@ -39,16 +39,16 @@ def test_lower_hopper_intrin_barrier():
         with T.If(v_1 == 0), T.Then():
             T.evaluate(
                 tir.Call("handle", "tir.ptx_init_barrier_thread_count",
-                         [T.GetMBarrierOp(0), 128]))
+                         [T.get_mbarrier(0), 128]))
             T.evaluate(
                 tir.Call("handle", "tir.ptx_init_barrier_thread_count",
-                         [T.GetMBarrierOp(1), 128]))
+                         [T.get_mbarrier(1), 128]))
             T.evaluate(
                 tir.Call("handle", "tir.ptx_init_barrier_thread_count",
-                         [T.GetMBarrierOp(2), 128]))
+                         [T.get_mbarrier(2), 128]))
             T.evaluate(
                 tir.Call("handle", "tir.ptx_init_barrier_thread_count",
-                         [T.GetMBarrierOp(3), 128]))
+                         [T.get_mbarrier(3), 128]))
         T.evaluate(tir.Call("handle", "tir.tvm_storage_sync", ["shared"]))

     _check(before, after)
...
@@ -48,14 +48,14 @@ def test_multi_version_buffer():
             C_local[i * 2 + vec] = T.float32(0)
         for k in T.serial(16, annotations={"num_stages": 3}):
             if v == 0:
-                T.TMALoadOp(
-                    T.CreateTMADescriptorOp(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
-                                            2, 0), 0,
+                T.tma_load(
+                    T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
+                                            2, 0), 0,
                     T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2),
                     k * 32, by * 64)
             if v == 0:
-                T.TMALoadOp(
-                    T.CreateTMADescriptorOp(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
-                                            2, 0), 0,
+                T.tma_load(
+                    T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
+                                            2, 0), 0,
                     T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2),
                     bx * 64, k * 32)
@@ -81,15 +81,15 @@ def test_multi_version_buffer():
            C_local[i * 2 + vec] = T.float32(0)
        for k in T.serial(16, annotations={"num_stages": 3}):
            if v == 0:
-                T.TMALoadOp(
-                    T.CreateTMADescriptorOp(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
-                                            2, 0), 0,
+                T.tma_load(
+                    T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
+                                            2, 0), 0,
                     T.tvm_access_ptr(
                         T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
                     k * 32, by * 64)
            if v == 0:
-                T.TMALoadOp(
-                    T.CreateTMADescriptorOp(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
-                                            2, 0), 0,
+                T.tma_load(
+                    T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
+                                            2, 0), 0,
                     T.tvm_access_ptr(
                         T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
...
@@ -46,15 +46,15 @@ def test_warp_specialized():
         C_local = T.alloc_buffer((32,), scope="local")
         for k in T.serial(16, annotations={"num_stages": 3}):
             if v == 0:
-                T.TMALoadOp(
-                    T.CreateTMADescriptorOp(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
-                                            2, 0), 0,
+                T.tma_load(
+                    T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
+                                            2, 0), 0,
                     T.tvm_access_ptr(
                         T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
                     k * 32, by * 64)
             if v == 0:
-                T.TMALoadOp(
-                    T.CreateTMADescriptorOp(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
-                                            2, 0), 0,
+                T.tma_load(
+                    T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
+                                            2, 0), 0,
                     T.tvm_access_ptr(
                         T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
@@ -75,35 +75,35 @@ def test_warp_specialized():
         A_shared = T.decl_buffer((3, 1, 8, 256), "float16", scope="shared.dyn")
         B_shared = T.decl_buffer((3, 1, 4, 512), "float16", scope="shared.dyn")
         C_local = T.decl_buffer((32,), scope="local")
-        T.CreateListofMBarrierOp(128, 128, 128, 128, 128, 128)
+        T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128)
         T.attr([128, 128], "kWarpSpecializationScope", 0)
         if v >= 128:
-            T.SetMaxNReg(24, 0)
+            T.set_max_nreg(24, 0)
             for k in range(16):
-                T.MBarrierWaitParity(T.GetMBarrierOp(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
+                T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
                 if v - 128 == 0:
-                    T.MBarrierExpectTX(T.GetMBarrierOp(k % 3), 4096)
+                    T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
                 if v - 128 == 0:
-                    T.TMALoadOp(
-                        T.CreateTMADescriptorOp(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
-                                                2, 0), T.GetMBarrierOp(k % 3),
+                    T.tma_load(
+                        T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
+                                                2, 0), T.get_mbarrier(k % 3),
                         T.tvm_access_ptr(
                             T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
                         k * 32, by * 64)
                 if v - 128 == 0:
-                    T.MBarrierExpectTX(T.GetMBarrierOp(k % 3), 4096)
+                    T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)
                 if v - 128 == 0:
-                    T.TMALoadOp(
-                        T.CreateTMADescriptorOp(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
-                                                2, 0), T.GetMBarrierOp(k % 3),
+                    T.tma_load(
+                        T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
+                                                2, 0), T.get_mbarrier(k % 3),
                         T.tvm_access_ptr(
                             T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
                         bx * 64, k * 32)
-                T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.GetMBarrierOp(k % 3)]))
+                T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
         else:
-            T.SetMaxNReg(240, 1)
+            T.set_max_nreg(240, 1)
             for k in range(16):
-                T.MBarrierWaitParity(T.GetMBarrierOp(k % 3), k // 3 % 2)
+                T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
                 T.call_extern(
                     "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
                     T.tvm_access_ptr(
@@ -112,7 +112,7 @@ def test_warp_specialized():
                         T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
                     T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
                 T.evaluate(
-                    tir.Call("handle", "tir.ptx_arrive_barrier", [T.GetMBarrierOp(k % 3 + 3)]))
+                    tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
    _check(before, after)
...
@@ -9,7 +9,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
                                    target: Optional[Target] = None) -> bool:
     if pass_ctx is None:
         pass_ctx = tilelang.transform.get_pass_context()
-    if target.arch not in {"sm_90"}:
+    if target.arch not in {"sm_90", "sm_90a"}:
         return False
     disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
     disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
@@ -17,6 +17,10 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
     return not (disable_tma_lower and disable_warp_specialized)


+def allow_fence_proxy(target: Optional[Target] = None) -> bool:
+    return target.arch in {"sm_90", "sm_90a"}
+
+
 def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
     if pass_ctx is None:
         pass_ctx = tilelang.transform.get_pass_context()
@@ -60,6 +64,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     # to get better performance with async copy
     mod = tilelang.transform.PipelinePlanning()(mod)
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
+    # warp_specialized pass will pack the if stmt into the block
+    # so we need to lower the opaque block first
     mod = tir.transform.LowerOpaqueBlock()(mod)
     mod = tilelang.transform.MergeIfStmt()(mod)
     mod = tilelang.transform.RewriteWgmmaSync()(mod)
@@ -71,6 +77,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
     mod = tilelang.transform.MergeIfStmt()(mod)
+    if allow_fence_proxy(target=target):
+        # in hopper device, wgmma is an async proxy
+        # so we need to inject a fence proxy before it
+        mod = tilelang.transform.InjectFenceProxy()(mod)
     mod = tir.transform.LowerOpaqueBlock()(mod)
     mod = tilelang.transform.FlattenBuffer()(mod)
     mod = tir.transform.NarrowDataType(32)(mod)
@@ -104,6 +115,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.ConfigIndexBitwidth()(mod)
     mod = tilelang.transform.ThreadSync("shared")(mod)
     mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
+    mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)
     mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
     mod = tilelang.transform.AnnotateDeviceRegions()(mod)
...
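
For reference, the two config keys read by allow_tma_and_warp_specialized above can be set through the standard TVM pass context. A hedged sketch; whether tilelang.compile picks up the ambient PassContext this way is an assumption:

    import tilelang
    from tvm import transform

    # Sketch: compile with warp specialization disabled but TMA lowering
    # enabled. The config keys are the ones read in the gating function above.
    with transform.PassContext(config={
            "tl.disable_warp_specialized": True,
            "tl.disable_tma_lower": False,
    }):
        kernel = tilelang.compile(main, target="cuda")  # hypothetical call site
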
@@ -259,6 +259,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
     result = self.lib.init()
     if result != 0:
         error_msg = self.lib.get_last_error().decode('utf-8')
+        error_msg += f"\n{self.lib_code}"
         raise RuntimeError(f"Initialization failed: {error_msg}")
     self.cython_wrapper = CythonKernelWrapper(self.result_idx, self.params, self.lib)
...
@@ -32,6 +32,7 @@ from .kernel import (
     get_block_binding,  # noqa: F401
     get_block_bindings,  # noqa: F401
 )
+from .warpgroup import ws  # noqa: F401
 from .allocate import (
     alloc_local,  # noqa: F401
     alloc_shared,  # noqa: F401
...
"""The language interface for tl programs.""" """The language interface for tl programs."""
from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier
from tvm import tir from tvm import tir
from typing import Union
from tvm.tir import PrimExpr, Var
def CreateListofMBarrierOp(*args): def create_list_of_mbarrier(*args):
"""Create a list of memory barrier operations. """Create a list of memory barrier operations.
Args: Args:
...@@ -12,10 +16,10 @@ def CreateListofMBarrierOp(*args): ...@@ -12,10 +16,10 @@ def CreateListofMBarrierOp(*args):
Returns: Returns:
tir.Call: A handle to the created list of memory barriers tir.Call: A handle to the created list of memory barriers
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.CreateListofMBarrierOp"), *args) return tir.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), *args)
def GetMBarrierOp(*args): def get_mbarrier(*args):
"""Retrieve a memory barrier operation. """Retrieve a memory barrier operation.
Args: Args:
...@@ -24,10 +28,10 @@ def GetMBarrierOp(*args): ...@@ -24,10 +28,10 @@ def GetMBarrierOp(*args):
Returns: Returns:
tir.Call: A handle to the requested memory barrier tir.Call: A handle to the requested memory barrier
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.GetMBarrierOp"), *args) return tir.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), *args)
def CreateTMADescriptorOp(*args): def create_tma_descriptor(*args):
"""Create a Tensor Memory Access (TMA) descriptor. """Create a Tensor Memory Access (TMA) descriptor.
Args: Args:
...@@ -36,10 +40,10 @@ def CreateTMADescriptorOp(*args): ...@@ -36,10 +40,10 @@ def CreateTMADescriptorOp(*args):
Returns: Returns:
tir.Call: A handle to the created TMA descriptor tir.Call: A handle to the created TMA descriptor
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.CreateTMADescriptorOp"), *args) return tir.call_intrin("handle", tir.op.Op.get("tl.create_tma_descriptor"), *args)
def TMALoadOp(*args): def tma_load(*args):
"""Perform a Tensor Memory Access (TMA) load operation. """Perform a Tensor Memory Access (TMA) load operation.
Args: Args:
...@@ -48,10 +52,10 @@ def TMALoadOp(*args): ...@@ -48,10 +52,10 @@ def TMALoadOp(*args):
Returns: Returns:
tir.Call: A handle to the TMA load operation tir.Call: A handle to the TMA load operation
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.TMALoadOp"), *args) return tir.call_intrin("handle", tir.op.Op.get("tl.tma_load"), *args)
def FenceProxyAsyncOp(*args): def fence_proxy_async(*args):
"""Create a fence for asynchronous proxy operations. """Create a fence for asynchronous proxy operations.
Args: Args:
...@@ -60,10 +64,10 @@ def FenceProxyAsyncOp(*args): ...@@ -60,10 +64,10 @@ def FenceProxyAsyncOp(*args):
Returns: Returns:
tir.Call: A handle to the fence operation tir.Call: A handle to the fence operation
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.FenceProxyAsyncOp"), *args) return tir.call_intrin("handle", tir.op.Op.get("tl.fence_proxy_async"), *args)
def TMAStoreArrive(*args): def tma_store_arrive(*args):
"""Signal the arrival of a TMA store operation. """Signal the arrival of a TMA store operation.
Args: Args:
...@@ -72,10 +76,10 @@ def TMAStoreArrive(*args): ...@@ -72,10 +76,10 @@ def TMAStoreArrive(*args):
Returns: Returns:
tir.Call: A handle to the store arrive operation tir.Call: A handle to the store arrive operation
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.TMAStoreArrive"), *args) return tir.call_intrin("handle", tir.op.Op.get("tl.tma_store_arrive"), *args)
def TMAStoreWait(*args): def tma_store_wait(*args):
"""Wait for completion of TMA store operations. """Wait for completion of TMA store operations.
Args: Args:
...@@ -84,10 +88,10 @@ def TMAStoreWait(*args): ...@@ -84,10 +88,10 @@ def TMAStoreWait(*args):
Returns: Returns:
tir.Call: A handle to the store wait operation tir.Call: A handle to the store wait operation
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.TMAStoreWait"), *args) return tir.call_intrin("handle", tir.op.Op.get("tl.tma_store_wait"), *args)
def SetMaxNReg(*args): def set_max_nreg(*args):
"""Set the maximum number of registers to use. """Set the maximum number of registers to use.
Args: Args:
...@@ -96,10 +100,10 @@ def SetMaxNReg(*args): ...@@ -96,10 +100,10 @@ def SetMaxNReg(*args):
Returns: Returns:
tir.Call: A handle to the register setting operation tir.Call: A handle to the register setting operation
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.SetMaxNReg"), *args) return tir.call_intrin("handle", tir.op.Op.get("tl.set_max_nreg"), *args)
def NoSetMaxNReg(*args): def no_set_max_nreg(*args):
"""Disable the maximum register limit setting. """Disable the maximum register limit setting.
Args: Args:
...@@ -108,22 +112,66 @@ def NoSetMaxNReg(*args): ...@@ -108,22 +112,66 @@ def NoSetMaxNReg(*args):
Returns: Returns:
tir.Call: A handle to the register limit disable operation tir.Call: A handle to the register limit disable operation
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.NoSetMaxNReg"), *args) return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"), *args)

-def MBarrierWaitParity(*args):
+def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr], parity: Union[int, Var]):
     """Wait for memory barrier parity condition.

     Args:
-        *args: Variable arguments specifying the parity wait condition
+        mbarrier: Union[int, PrimExpr]
+            The memory barrier to wait on
+        parity: Union[int, Var]
+            The parity value to wait for
+
+    Examples:
+        .. code-block:: python
+
+            # Wait for parity 0 on barrier 0
+            T.mbarrier_wait_parity(0, 0)
+
+            # Wait for parity value in variable ko on barrier 1
+            T.mbarrier_wait_parity(1, ko)
+
+            # Wait using a barrier handle
+            barrier = T.get_mbarrier(0)
+            T.mbarrier_wait_parity(barrier, 1)
+
+            # Common usage in pipelined kernels:
+            for ko in range(num_stages):
+                # Producer waits for consumer to finish the previous iteration
+                T.mbarrier_wait_parity(1, ko ^ 1)
+                # Producer copies data
+                T.copy(A_global, A_shared)
+                # Producer signals data ready
+                T.mbarrier_arrive(0)
+
+                # Consumer waits for producer data
+                T.mbarrier_wait_parity(0, ko)
+                # Consumer computes
+                T.gemm(A_shared, B_shared, C_local)
+                # Consumer signals completion
+                T.mbarrier_arrive(1)

     Returns:
         tir.Call: A handle to the barrier wait operation
     """
-    return tir.call_intrin("handle", tir.op.Op.get("tl.MBarrierWaitParity"), *args)
+    if isinstance(mbarrier, int):
+        mbarrier = get_mbarrier(mbarrier)
+    return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity)
+
+
+def mbarrier_arrive(mbarrier: Union[int, PrimExpr]):
+    """Arrive at a memory barrier.
+
+    Args:
+        mbarrier: Union[int, PrimExpr]
+            The memory barrier to arrive at
+    """
+    if isinstance(mbarrier, int):
+        mbarrier = get_mbarrier(mbarrier)
+    return ptx_arrive_barrier(mbarrier)

-def MBarrierExpectTX(*args):
+def mbarrier_expect_tx(*args):
     """Set expected transaction count for memory barrier.

     Args:
@@ -132,10 +180,10 @@ def MBarrierExpectTX(*args):
     Returns:
         tir.Call: A handle to the barrier expectation operation
     """
-    return tir.call_intrin("handle", tir.op.Op.get("tl.MBarrierExpectTX"), *args)
+    return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_expect_tx"), *args)


-def WaitWgmma(*args):
+def wait_wgmma(*args):
     """Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

     Args:
@@ -144,4 +192,4 @@ def WaitWgmma(*args):
     Returns:
         tir.Call: A handle to the WGMMA wait operation
     """
-    return tir.call_intrin("handle", tir.op.Op.get("tl.WaitWgmma"), *args)
+    return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), *args)
...
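
For orientation, the expect-tx/arrive protocol these intrinsics implement is the producer-side sequence from the warp-specialized test above, compressed into its essential steps. Here desc, dst_ptr, x, y, and the loop variable k are placeholders; the barrier index k % 3 and the 4096-byte transaction count are taken from that test:

    # Producer-side sketch of the mbarrier protocol.
    T.mbarrier_expect_tx(T.get_mbarrier(k % 3), 4096)        # declare expected bytes
    T.tma_load(desc, T.get_mbarrier(k % 3), dst_ptr, x, y)   # async TMA copy into shared
    T.mbarrier_arrive(k % 3)                                 # producer signals the consumer
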
@@ -128,6 +128,12 @@ class KernelLaunchFrame(TIRFrame):
         iter_var = self.frames[dim].iter_var
         return int(iter_var.dom.extent)

+    def get_block_extents(self) -> List[int]:
+        """
+        Returns the block extents for all three dimensions.
+        """
+        return [self.get_block_extent(dim) for dim in range(3)]
+
     def get_thread_extent(self, dim: int) -> int:
         """
         Returns the thread extent for the given dimension.
@@ -136,6 +142,12 @@ class KernelLaunchFrame(TIRFrame):
         iter_var = self.frames[-4 + dim].iter_var
         return int(iter_var.dom.extent)

+    def get_thread_extents(self) -> List[int]:
+        """
+        Returns the thread extents for all three dimensions.
+        """
+        return [self.get_thread_extent(dim) for dim in range(3)]
+
     def get_thread_binding(self, dim: int = 0) -> Var:
         """
         Returns the thread binding for the given dimension.
@@ -268,3 +280,27 @@ def get_block_bindings() -> List[Var]:
     """Returns all three block bindings.
     """
     return KernelLaunchFrame.Current().get_block_bindings()
+
+
+def get_thread_extent(dim: int = 0) -> int:
+    """Returns the thread extent for the given dimension.
+    """
+    return KernelLaunchFrame.Current().get_thread_extent(dim)
+
+
+def get_thread_extents() -> List[int]:
+    """Returns all three thread extents.
+    """
+    return KernelLaunchFrame.Current().get_thread_extents()
+
+
+def get_block_extent(dim: int = 0) -> int:
+    """Returns the block extent for the given dimension.
+    """
+    return KernelLaunchFrame.Current().get_block_extent(dim)
+
+
+def get_block_extents() -> List[int]:
+    """Returns all three block extents.
+    """
+    return KernelLaunchFrame.Current().get_block_extents()
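
These helpers are what the new warpgroup.py (next file) uses to derive a linear thread id; the same computation can be written directly in kernel scope. A small sketch; whether the extent helpers are re-exported on the T namespace is an assumption, since warpgroup.py imports them from .kernel directly:

    import tilelang.language as T

    with T.Kernel(8, threads=256):
        id_x, id_y, id_z = T.get_thread_bindings()
        ex_x, ex_y, ex_z = T.get_thread_extents()
        # Flatten the 3-D thread index, exactly as WarpSpecialize does below.
        tid = id_z * (ex_y * ex_x) + id_y * ex_x + id_x
        warp_group = tid // 128  # 128 threads per warp group on NVIDIA GPUs
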
"""The language interface for tl programs."""
from tvm.script.ir_builder.tir.frame import TIRFrame
from tvm._ffi import register_object
from tilelang import _ffi_api
from .kernel import get_thread_bindings, get_thread_extents
@register_object("tl.WarpSpecializeFrame")
class WarpSpecializeFrame(TIRFrame):
"""
WarpSpecializeFrame is a custom TIRFrame that manages warp group indices
and handles the entry and exit of the kernel launch scope.
"""
def WarpSpecialize(warp_group_idx: int,):
"""Tools to construct a warp group frame.
Parameters
----------
warp_group_idx : int
A integer representing warp group index
Or a list of integers representing blockDim.(x|y|z)
if the value is -1, we skip the threadIdx.x binding.
Returns
-------
res : Tuple[frame.LaunchThreadFrame]
The result LaunchThreadFrame.
"""
id_x, id_y, id_z = get_thread_bindings()
ex_x, ex_y, _ = get_thread_extents()
tid = id_z * (ex_y * ex_x) + id_y * ex_x + id_x
# only available for nvidia gpus.
warp_group_size = 128
return _ffi_api.WarpSpecialize(warp_group_idx, tid, warp_group_size)
# Alias for WarpSpecialize for more concise usage
ws = WarpSpecialize
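
Reading the FFI call above: the frame receives the warp group index, the flattened thread id, and a fixed warp group size of 128, so the body is presumably predicated on the 128-thread group that owns it. A conceptual sketch of that predicate (an assumption about the lowering, not the actual FFI output):

    # with T.ws(i): <body>    is expected to lower to roughly:
    # if tid // warp_group_size == i:
    #     <body>
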
@@ -314,3 +314,9 @@ def FlattenBuffer():
         The result pass
     """
     return _ffi_api.FlattenBuffer()  # type: ignore
+
+
+def EliminateStorageSyncForMBarrier():
+    """EliminateStorageSyncForMBarrier
+    """
+    return _ffi_api.EliminateStorageSyncForMBarrier()  # type: ignore
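
Usage matches the call added to OptimizeForTarget above; the pass is applied to an IRModule like any other tilelang transform:

    import tilelang

    # Run after the ThreadSync passes, as in OptimizeForTarget above.
    mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)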