Unverified Commit cb37bfef authored by Lei Wang, committed by GitHub

[Refactor] Refactor barrier management (#744)

* Introduce Barrier

* Enhance CUDA kernel with new barrier management and post-processing support

- Added a new CUDA kernel implementation in `example_mla_decode.py` for improved performance with shared memory barriers.
- Refactored barrier handling in `codegen_cuda.cc` and `codegen_hip.cc` to use a more flexible mbarrier structure (see the sketch after this list).
- Renamed the misspelled intrinsic `ptx_stmatirx` to `ptx_stmatrix` across multiple files for consistency.
- Introduced additional print statements for debugging in the lowering phase of the TileLang engine.
- Enhanced the overall structure and readability of the codebase.
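
A minimal sketch of the new barrier layout, assuming CUTLASS is available and using the `Barrier` alias introduced in the new `barrier.h` template later in this diff; the kernel name and barrier count are illustrative only:

```cuda
#include <cstdint>
#include <cutlass/arch/barrier.h>

// Reuse the CUTLASS cluster transaction barrier, as the new barrier.h does.
using Barrier = cutlass::arch::ClusterTransactionBarrier;

__global__ void barrier_layout_sketch() {
  // Raw shared-memory backing store, plus a typed view over it; this mirrors
  // the code emitted for create_barriers in codegen_cuda.cc.
  __shared__ uint64_t mbarrier_mem[2];
  auto mbarrier = reinterpret_cast<Barrier *>(mbarrier_mem);
  if (threadIdx.x == 0) {
    mbarrier[0].init(128);  // arrive count is illustrative
    mbarrier[1].init(128);
  }
  __syncthreads();
}
```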

* Remove unused barrier handling code in CUDA and HIP code generators to streamline the implementation. This change enhances code clarity and reduces complexity in the barrier management logic.

* Enhance barrier management in TileLang

- Introduced a new intrinsic `allocate_barrier` for dynamic barrier allocation in the TileLang framework.
- Updated CUDA code generation to support the new barrier structure, allowing improved synchronization through shared memory (see the sketch after this list).
- Refactored existing barrier handling logic to accommodate the new intrinsic and streamline code.
- Added print statements for debugging purposes in various examples and the lowering phase of the TileLang engine.
- Removed deprecated memory scope handling code to enhance clarity and maintainability.
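
As a hedged illustration, the updated `CodeGenTileLangCUDA` visitors now emit method calls on a `Barrier` object instead of the previous `tl::mbarrier_*` wrapper calls; the mapping below is a sketch with placeholder variable names:

```cuda
#include <cstdint>
#include <cutlass/arch/barrier.h>

using Barrier = cutlass::arch::ClusterTransactionBarrier;

// Illustrative mapping from TIR barrier intrinsics to the emitted statements.
__device__ void barrier_codegen_sketch(Barrier *mbarrier, uint32_t bytes,
                                       uint32_t cta_id, uint32_t pred,
                                       uint32_t phase) {
  mbarrier[0].init(128);                    // ptx_init_barrier_thread_count
  mbarrier[0].arrive();                     // ptx_arrive_barrier, 1-argument form
  mbarrier[0].arrive(cta_id, pred);         // ptx_arrive_barrier, 3-argument form
  mbarrier[0].arrive_and_expect_tx(bytes);  // ptx_arrive_barrier_expect_tx
  mbarrier[0].expect_transaction(bytes);    // tl::mbarrier_expect_tx
  mbarrier[0].wait(phase);                  // tl::mbarrier_wait_parity
}
```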

* lint fix

* lint fix

* Remove `allocate_barrier` intrinsic and related code from TileLang to streamline barrier management. This includes updates to CUDA code generation and the removal of associated Python wrappers, enhancing code clarity and maintainability.

* Refactor logging in JITKernel to improve kernel compilation tracking

- Removed unused import of `torch.backends` in the example file.
- Introduced logging for kernel compilation in `JITKernel`, replacing print statements with structured logging for better traceability and debugging.
- Added an assertion to ensure the presence of the `global_symbol` attribute in the kernel function.

* Refactor dequantization tests and update barrier function

- Removed the test for `example_dequant_gemm_bf16_fp4_hopper_serial` to streamline the testing suite.
- Updated the `mbarrier_cp_async_arrive` function to accept both pointer and non-pointer barrier types, adding flexibility to barrier management (a usage sketch follows below).
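
A small usage sketch of the updated helper, assuming the `Barrier` alias and the templated `mbarrier_cp_async_arrive` from the new `barrier.h` in this diff (the caller is hypothetical):

```cuda
#include "barrier.h"  // defines Barrier and tl::mbarrier_cp_async_arrive

__device__ void cp_async_arrive_sketch() {
  __shared__ uint64_t raw_barrier;             // legacy raw uint64_t barrier
  tl::mbarrier_cp_async_arrive(raw_barrier);   // BarrierType deduced as uint64_t

  Barrier *bar = reinterpret_cast<Barrier *>(&raw_barrier);
  tl::mbarrier_cp_async_arrive(bar);           // BarrierType deduced as Barrier*
}
```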

* Update CI configuration to increase pytest parallelism from 4 to 8 threads for improved test execution speed.

* Fix typos in rasterization parameters and update import path for cached module

- Corrected the spelling of `enable_rasteration` to `enable_rasterization` in the matmul function and its usage.
- Updated the import statement for the `cached` module to reflect the new path in the cache submodule.
- Added `StridedTensor` import in the language module for enhanced tensor functionality.

* Update ci.yml
parent eccdfe17
......@@ -2,7 +2,6 @@ import tilelang.testing
import example_dequant_gemv_fp16xint4
import example_dequant_gemm_fp4_hopper
import example_dequant_gemm_bf16_fp4_hopper_serial
@tilelang.testing.requires_cuda
......@@ -16,11 +15,5 @@ def test_example_dequant_gemm_fp4_hopper():
example_dequant_gemm_fp4_hopper.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_bf16_fp4_hopper_serial():
example_dequant_gemm_bf16_fp4_hopper_serial.main()
if __name__ == "__main__":
tilelang.testing.main()
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
......
......@@ -391,6 +391,7 @@ def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
num_split = 1
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
......
......@@ -66,7 +66,6 @@ def main():
# Run the kernel through the Profiler
c = jit_kernel(a, b)
# Reference multiplication using PyTorch
ref_c = a @ b
......
......@@ -83,7 +83,7 @@ TIR_DEFINE_TL_BUILTIN(ptx_ldmatirx)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_stmatirx)
TIR_DEFINE_TL_BUILTIN(ptx_stmatrix)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
......
......@@ -62,7 +62,7 @@ static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
* swizzle, l2_promotion, oob_fill)
*
*/
const Op &create_tma_descriptor();
TVM_DLL const Op &create_tma_descriptor();
/*!
* \brief tvm intrinsics for TMADescriptor creation for image to column load
......@@ -73,7 +73,7 @@ const Op &create_tma_descriptor();
* l2_promotion, oob_fill)
*
*/
const Op &create_tma_im2col_descriptor();
TVM_DLL const Op &create_tma_im2col_descriptor();
/*!
* \brief Create a list of mbarrier with num_threads
......@@ -81,7 +81,7 @@ const Op &create_tma_im2col_descriptor();
* create_list_of_mbarrier(num_threads0, num_threads1, ...)
*
*/
const Op &create_list_of_mbarrier();
TVM_DLL const Op &create_list_of_mbarrier();
/*!
* \brief Get the mbarrier with barrier_id
......@@ -89,7 +89,7 @@ const Op &create_list_of_mbarrier();
* int64_t* GetMBarrier(barrier_id)
*
*/
const Op &get_mbarrier();
TVM_DLL const Op &get_mbarrier();
/*!
* \brief tvm intrinsics for loading data from global tensor descriptor to
......@@ -98,7 +98,7 @@ const Op &get_mbarrier();
* tma_load(descriptor, mbarrier, smem_data, coord_0, coord_1, ...)
*
*/
const Op &tma_load();
TVM_DLL const Op &tma_load();
/*!
* \brief tvm intrinsics for loading image from global tensor to columns in
......@@ -108,7 +108,7 @@ const Op &tma_load();
* image_offset, ...)
*
*/
const Op &tma_load_im2col();
TVM_DLL const Op &tma_load_im2col();
/*!
* \brief tvm intrinsics for storing data from shared memory to global tensor
......@@ -117,7 +117,7 @@ const Op &tma_load_im2col();
* tma_store(descriptor, smem_data, coord_0, coord_1, ...)
*
*/
const Op &tma_store();
TVM_DLL const Op &tma_store();
/*!
* \brief tvm intrinsics for mbarrier wait with parity bit
......@@ -125,7 +125,7 @@ const Op &tma_store();
* mbarrier_wait_parity(mbarrier, parity)
*
*/
const Op &mbarrier_wait_parity();
TVM_DLL const Op &mbarrier_wait_parity();
/*!
* \brief tvm intrinsics for mbarrier expect tx
......@@ -133,7 +133,7 @@ const Op &mbarrier_wait_parity();
* mbarrier_expect_tx(mbarrier, transaction_bytes)
*
*/
const Op &mbarrier_expect_tx();
TVM_DLL const Op &mbarrier_expect_tx();
/*!
* \brief tvm intrinsics for ldmatrix
......@@ -141,7 +141,7 @@ const Op &mbarrier_expect_tx();
* ptx_ldmatirx(transposed, num, shared_addr, local_addr)
*
*/
const Op &ptx_ldmatirx();
TVM_DLL const Op &ptx_ldmatirx();
/*!
* \brief tvm intrinsics for stmatrix
......@@ -149,7 +149,7 @@ const Op &ptx_ldmatirx();
* ptx_ldmatirx(transposed, num, shared_addr, int32_values...)
*
*/
const Op &ptx_stmatirx();
TVM_DLL const Op &ptx_stmatrix();
/*!
* \brief Pack two b16 value into a b32 value
......@@ -157,7 +157,7 @@ const Op &ptx_stmatirx();
* int32 pack_b16(b16_value, b16_value)
*
*/
const Op &pack_b16();
TVM_DLL const Op &pack_b16();
/*!
* \brief Similar to __syncthreads(), but can be used to sync partial threads
......@@ -165,7 +165,7 @@ const Op &pack_b16();
* sync_thread_partial(num_partial_threads or mbarrier)
*
*/
const Op &sync_thread_partial();
TVM_DLL const Op &sync_thread_partial();
/*!
* \brief Issue a shared memory fence for async operations
......@@ -173,7 +173,7 @@ const Op &sync_thread_partial();
* FenceProxyAsync()
*
*/
const Op &fence_proxy_async();
TVM_DLL const Op &fence_proxy_async();
/*!
* \brief Indicate arrival of warp issuing TMA_STORE
......@@ -181,7 +181,7 @@ const Op &fence_proxy_async();
* tma_store_arrive()
*
*/
const Op &tma_store_arrive();
TVM_DLL const Op &tma_store_arrive();
/*!
* \brief Wait for TMA_STORE to finish
......@@ -189,7 +189,7 @@ const Op &tma_store_arrive();
* tma_store_wait()
*
*/
const Op &tma_store_wait();
TVM_DLL const Op &tma_store_wait();
/*!
* \brief Set reg hint for warp-specialized branched
......@@ -197,7 +197,7 @@ const Op &tma_store_wait();
* SetMaxNRegInc(num_reg, is_inc)
*
*/
const Op &set_max_nreg();
TVM_DLL const Op &set_max_nreg();
/*!
* \brief No set reg hint for warp-specialized branched
......@@ -205,7 +205,7 @@ const Op &set_max_nreg();
* no_set_max_nreg()
*
*/
const Op &no_set_max_nreg();
TVM_DLL const Op &no_set_max_nreg();
/*!
* \brief Wait the previous wgmma to finish
......@@ -213,7 +213,7 @@ const Op &no_set_max_nreg();
* wait_wgmma(num_mma)
*
*/
const Op &wait_wgmma();
TVM_DLL const Op &wait_wgmma();
/*!
* \brief Synchronize all threads in a grid
......@@ -221,7 +221,7 @@ const Op &wait_wgmma();
* sync_grid()
*
*/
const Op &sync_grid();
TVM_DLL const Op &sync_grid();
/*!
* \brief tvm intrinsic for loop continue
......@@ -229,7 +229,7 @@ const Op &sync_grid();
* loop_break()
*
*/
const Op &loop_break();
TVM_DLL const Op &loop_break();
/*!
* \brief tvm intrinsic for amd matrix core mfma instructions.
......
......@@ -302,7 +302,7 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
num = 2;
Array<PrimExpr> args;
const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatirx();
const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatrix();
args.push_back(static_cast<int>(is_transposed));
args.push_back(num);
......
......@@ -695,7 +695,7 @@ void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope,
ICHECK_NE(scope, "global")
<< "Cannot allocate global memory when targeting CUDA. You must pass "
"all global arrays as input instead";
if (scope == "shared") {
if (scope == "shared" || scope == "shared.barrier") {
os << "__shared__ ";
} else if (scope == "shared.dyn") {
os << "extern __shared__ __align__(1024) ";
......@@ -943,6 +943,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
this->stream << ss.str();
this->stream << ");\n";
};
auto print_mbarrier_obj = [&](PrimExpr barrier_id) {
std::ostringstream ss;
if (barrier_id.as<IntImmNode>()) {
// in case the barrier_id is an integer, print the barrier_id as
// an integer index
ss << mbarrier_name_ << "[" << barrier_id << "]";
} else {
// otherwise may be a T.get_mbarrier() call or BufferLoad Node
// we need to print the barrier_id as a string
ss << this->PrintExpr(barrier_id);
}
return ss.str();
};
if (op->op.same_as(builtin::ptx_cp_async())) {
std::string dst = this->PrintExpr(op->args[0]);
std::string dst_offset = this->PrintExpr(op->args[1]);
......@@ -971,25 +984,73 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(builtin::create_barriers())) {
this->PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value;
std::string barrier_name = "_mbarrier";
this->stream << "__shared__ uint64_t " << barrier_name << "["
auto mbarrier_storage_name = mbarrier_name_ + "_mem";
this->stream << "__shared__ uint64_t " << mbarrier_storage_name << "["
<< barrier_count << "];\n";
this->PrintIndent();
this->stream << "auto " << mbarrier_name_ << " = reinterpret_cast<"
<< mbarrier_dtype_ << "*>(" << mbarrier_storage_name << ");\n";
} else if (op->op.same_as(tl::get_mbarrier())) {
std::string barrier_name = "_mbarrier";
ICHECK_EQ(op->args.size(), 1);
std::string barrier_id = this->PrintExpr(op->args[0]);
os << barrier_name + "[" + barrier_id + "]";
os << mbarrier_name_ + "[" + barrier_id + "]";
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
print_extern_call_stmt("tl::mbarrier_arrive");
if (op->args.size() == 1) {
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
this->stream << mbarrier_obj << ".arrive();\n";
} else if (op->args.size() == 3) {
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto cta_id = this->PrintExpr(op->args[1]);
auto pred = this->PrintExpr(op->args[2]);
this->stream << mbarrier_obj << ".arrive(" << cta_id << ", " << pred
<< ");\n";
} else {
LOG(FATAL) << "Invalid parameter for tl::arrive_barrier "
<< op->args.size();
}
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
print_extern_call_stmt("tl::mbarrier_init");
ICHECK_EQ(op->args.size(), 2);
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto arrive_count = this->PrintExpr(op->args[1]);
this->stream << mbarrier_obj << ".init(" << arrive_count << ");\n";
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
if (op->args.size() == 2) {
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto transaction_bytes = this->PrintExpr(op->args[1]);
this->stream << mbarrier_obj << ".arrive_and_expect_tx("
<< transaction_bytes << ");\n";
} else if (op->args.size() == 4) {
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto transaction_bytes = this->PrintExpr(op->args[1]);
auto cta_id = this->PrintExpr(op->args[2]);
auto pred = this->PrintExpr(op->args[3]);
this->stream << mbarrier_obj << ".arrive_and_expect_tx("
<< transaction_bytes << ", " << cta_id << ", " << pred
<< ");\n";
} else {
LOG(FATAL) << "Invalid parameter for tl::arrive_barrier_expect_tx "
<< op->args.size();
}
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
} else if (op->op.same_as(tl::mbarrier_expect_tx())) {
print_extern_call_stmt("tl::mbarrier_expect_tx");
ICHECK_EQ(op->args.size(), 2);
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto transaction_bytes = this->PrintExpr(op->args[1]);
this->stream << mbarrier_obj << ".expect_transaction(" << transaction_bytes
<< ");\n";
} else if (op->op.same_as(tl::mbarrier_wait_parity())) {
print_extern_call_stmt("tl::mbarrier_wait");
ICHECK_EQ(op->args.size(), 2);
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto phase = this->PrintExpr(op->args[1]);
this->stream << mbarrier_obj << ".wait(" << phase << ");\n";
} else if (op->op.same_as(tl::sync_thread_partial())) {
print_extern_call_stmt("cutlass::arch::NamedBarrier::sync");
} else if (op->op.same_as(tl::no_set_max_nreg())) {
......@@ -1008,11 +1069,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
}
auto desc = op->args[0];
ss << this->PrintExpr(desc) << ", ";
if (const IntImmNode *imm = op->args[1].as<IntImmNode>()) {
ss << "_mbarrier[" << imm->value << "], ";
} else {
ss << this->PrintExpr(op->args[1]) << ", ";
}
ss << print_mbarrier_obj(op->args[1]) << ", ";
for (size_t i = 2; i < op->args.size() - 1; i++) {
if (i > 2)
ss << ", ";
......@@ -1050,7 +1107,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::ptx_stmatirx())) {
} else if (op->op.same_as(tl::ptx_stmatrix())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
......@@ -1370,13 +1427,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
int n = Downcast<IntImm>(op->args[0])->value;
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n
<< ";\");\n\n";
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintCpAsyncBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
......@@ -1407,22 +1457,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintWaitBarrierAsm(barrier);
} else if (op->op.same_as(builtin::create_barriers())) {
CHECK_EQ(barrier_count_, -1);
int barrier_count = Downcast<IntImm>(op->args[0])->value;
// pad barrier alignment to avoid runtime alignment errors
CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
if (barrier_count % barrier_alignment_count != 0) {
barrier_count = ((barrier_count / barrier_alignment_count) + 1) *
barrier_alignment_count;
}
barrier_count_ = barrier_count;
this->stream << "__shared__ __align__(" << barrier_alignment_bytes_
<< ") uint64_t " << barrier_name_ << "[" << barrier_count
<< "];\n";
this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { "
<< barrier_name_ << "[i] = 0; }\n";
} else if (op->op.same_as(builtin::ptx_ldg32())) {
/*
asm volatile (
......@@ -1654,6 +1688,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
}
if (scope == "shared") {
stream << ' ' << vid << '[' << constant_size << "];\n";
} else if (scope == "shared.barrier") {
auto v_id_mem = vid + "_mem";
stream << ' ' << v_id_mem << "[" << constant_size << "];\n";
PrintIndent();
stream << "auto " << vid << " = reinterpret_cast<" << mbarrier_dtype_
<< "*>(" << v_id_mem << ");\n";
} else if (scope == "local") {
stream << ' ' << vid << '[' << constant_size << "];\n";
} else if (scope == "local.var") {
......
......@@ -114,6 +114,10 @@ private:
const std::string barrier_name_ = "barrier";
// The size of the barrier array in shared memory
int barrier_count_ = -1;
// The name of the mbarrier array in shared memory
const std::string mbarrier_name_ = "mbarrier";
// The type name of the mbarrier array
const std::string mbarrier_dtype_ = "Barrier";
// The alignment of the barrier array in shared memory
// Set to 16 to maintain minimum alignment requirements for async bulk copy
const int barrier_alignment_bytes_ = 16;
......
......@@ -785,31 +785,9 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
int n = Downcast<IntImm>(op->args[0])->value;
std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
print_extern_call_stmt(func_name, 1);
} else if (op->op.same_as(builtin::create_barriers())) {
this->PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value;
std::string barrier_name = "_mbarrier";
this->stream << "__shared__ uint64_t " << barrier_name << "["
<< barrier_count << "];\n";
} else if (op->op.same_as(tl::get_mbarrier())) {
std::string barrier_name = "_mbarrier";
std::string barrier_id = this->PrintExpr(op->args[0]);
os << barrier_name + "[" + barrier_id + "]";
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
print_extern_call_stmt("tl::mbarrier_arrive");
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
print_extern_call_stmt("tl::mbarrier_init");
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
} else if (op->op.same_as(tl::mbarrier_expect_tx())) {
print_extern_call_stmt("tl::mbarrier_expect_tx");
} else if (op->op.same_as(tl::mbarrier_wait_parity())) {
print_extern_call_stmt("tl::mbarrier_wait");
} else if (op->op.same_as(tl::sync_thread_partial())) {
print_extern_call_stmt("tl::syncthreads_partial");
} else if (op->op.same_as(tl::ptx_stmatirx())) {
} else if (op->op.same_as(tl::ptx_stmatrix())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
......
#pragma once
#include "common.h"
#include <cutlass/arch/barrier.h>
// Reuse cutlass advanced barrier abstraction
using Barrier = cutlass::arch::ClusterTransactionBarrier;
namespace tl {
TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.init.shared.b64 [%1], %0;"
:
: "r"(arrive_count), "r"(smem_int_ptr));
}
TL_DEVICE uint32_t mbarrier_try_wait(uint64_t &smem_barrier, int phase_bit) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint32_t waitComplete;
asm volatile("{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(waitComplete)
: "r"(smem_int_ptr), "r"(phase_bit));
return waitComplete;
}
TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) {
if (mbarrier_try_wait(smem_barrier, phase_bit) == 0) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
// Arbitrarily large timer value after which try-wait expires and re-tries.
uint32_t ticks = 0x989680;
asm volatile("{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra DONE; \n\t"
"bra LAB_WAIT; \n\t"
"DONE: \n\t"
"}"
:
: "r"(smem_int_ptr), "r"(phase_bit), "r"(ticks));
}
}
TL_DEVICE void mbarrier_test_wait(uint64_t &smem_barrier, int phase_bit) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures
// to save instruction issue slots
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}\n" ::"r"(smem_int_ptr),
"r"(phase_bit));
}
TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id,
uint32_t pred) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
if (pred) {
asm volatile("{\n\t"
".reg .b32 remAddr32;\n\t"
"mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t"
"mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t"
"}"
:
: "r"(smem_int_ptr), "r"(cta_id));
}
}
TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier,
uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier,
uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
template <typename BarrierType = uint64_t>
TL_DEVICE void mbarrier_cp_async_arrive(BarrierType &smem_mbar) {
uint32_t smem_int_mbar;
if constexpr (std::is_pointer_v<BarrierType>) {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
} else {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
}
asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];"
:
: "r"(smem_int_mbar));
}
TL_DEVICE void fence_proxy_async() {
asm volatile("fence.proxy.async.shared::cta;" : :);
}
// Indicate arrival of warp issuing TMA_STORE
TL_DEVICE void tma_store_arrive() {
asm volatile("cp.async.bulk.commit_group;");
}
template <int Count> TL_DEVICE void tma_store_wait() {
asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(Count) : "memory");
}
TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint64_t state = 0;
asm volatile("{\n"
".reg .pred P1;\n"
"mbarrier.arrive.shared.b64 %1, [%0];\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.shared.b64 P1, [%0], %1;\n"
"@!P1 bra.uni LAB_WAIT;\n"
"}\n"
:
: "r"(smem_int_ptr), "l"(state));
}
} // namespace tl
......@@ -250,4 +250,12 @@ template <int thread_extent> TL_DEVICE bool tl_shuffle_elect() {
cute::elect_one_sync();
}
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
} // namespace tl
......@@ -4,6 +4,7 @@
#include <cuda.h>
#endif
#include "barrier.h"
#include "common.h"
namespace tl {
......@@ -13,9 +14,11 @@ enum class CacheHintSm90 : uint64_t {
EVICT_LAST = 0x14F0000000000000,
};
TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, uint64_t &smem_mbar,
template <typename BarrierType = uint64_t>
TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, BarrierType &smem_mbar,
uint32_t size) {
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_mbar =
smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::"
"bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr),
......@@ -35,11 +38,17 @@ TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr,
:);
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
typename BarrierType = uint64_t>
TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar,
void const *const smem_ptr, int32_t const &crd0) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_mbar;
if constexpr (std::is_pointer_v<BarrierType>) {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
} else {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
}
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
"complete_tx::bytes.L2::cache_hint"
......@@ -50,12 +59,18 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
: "memory");
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
typename BarrierType = uint64_t>
TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_mbar;
if constexpr (std::is_pointer_v<BarrierType>) {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
} else {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
}
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::"
"complete_tx::bytes.L2::cache_hint"
......@@ -66,12 +81,18 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
: "memory");
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
typename BarrierType = uint64_t>
TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1, int32_t const &crd2) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_mbar;
if constexpr (std::is_pointer_v<BarrierType>) {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
} else {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
}
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
"complete_tx::bytes.L2::cache_hint"
......@@ -81,13 +102,19 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
"r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint)
: "memory");
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
typename BarrierType = uint64_t>
TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1, int32_t const &crd2,
int32_t const &crd3) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_mbar;
if constexpr (std::is_pointer_v<BarrierType>) {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
} else {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
}
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::"
"complete_tx::bytes.L2::cache_hint"
......@@ -98,13 +125,19 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
: "memory");
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
typename BarrierType = uint64_t>
TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1, int32_t const &crd2,
int32_t const &crd3, int32_t const &crd4) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_mbar;
if constexpr (std::is_pointer_v<BarrierType>) {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
} else {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
}
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::"
"complete_tx::bytes.L2::cache_hint"
......@@ -116,15 +149,17 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
: "memory");
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
uint64_t &smem_mbar, void const *const smem_ptr,
int32_t const &coord_c, int32_t const &coord_w,
int32_t const &coord_h, int32_t const &coord_n,
uint16_t const &offset_w,
uint16_t const &offset_h) {
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
typename BarrierType = uint64_t>
TL_DEVICE void
tma_load_im2col(const CUtensorMap &descriptor, BarrierType &smem_mbar,
void const *const smem_ptr, int32_t const &coord_c,
int32_t const &coord_w, int32_t const &coord_h,
int32_t const &coord_n, uint16_t const &offset_w,
uint16_t const &offset_h) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_mbar =
smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
":complete_tx::bytes.L2::cache_hint"
......@@ -212,138 +247,4 @@ TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
}
TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.init.shared.b64 [%1], %0;"
:
: "r"(arrive_count), "r"(smem_int_ptr));
}
TL_DEVICE uint32_t mbarrier_try_wait(uint64_t &smem_barrier, int phase_bit) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint32_t waitComplete;
asm volatile("{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(waitComplete)
: "r"(smem_int_ptr), "r"(phase_bit));
return waitComplete;
}
TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) {
if (mbarrier_try_wait(smem_barrier, phase_bit) == 0) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
// Arbitrarily large timer value after which try-wait expires and re-tries.
uint32_t ticks = 0x989680;
asm volatile("{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra DONE; \n\t"
"bra LAB_WAIT; \n\t"
"DONE: \n\t"
"}"
:
: "r"(smem_int_ptr), "r"(phase_bit), "r"(ticks));
}
}
TL_DEVICE void mbarrier_test_wait(uint64_t &smem_barrier, int phase_bit) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures
// to save instruction issue slots
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}\n" ::"r"(smem_int_ptr),
"r"(phase_bit));
}
TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id,
uint32_t pred) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
if (pred) {
asm volatile("{\n\t"
".reg .b32 remAddr32;\n\t"
"mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t"
"mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t"
"}"
:
: "r"(smem_int_ptr), "r"(cta_id));
}
}
TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier,
uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier,
uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_cp_async_arrive(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];"
:
: "r"(smem_int_ptr));
}
TL_DEVICE void fence_proxy_async() {
asm volatile("fence.proxy.async.shared::cta;" : :);
}
// Indicate arrival of warp issuing TMA_STORE
TL_DEVICE void tma_store_arrive() {
asm volatile("cp.async.bulk.commit_group;");
}
template <int Count> TL_DEVICE void tma_store_wait() {
asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(Count) : "memory");
}
TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint64_t state = 0;
asm volatile("{\n"
".reg .pred P1;\n"
"mbarrier.arrive.shared.b64 %1, [%0];\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.shared.b64 P1, [%0], %1;\n"
"@!P1 bra.uni LAB_WAIT;\n"
"}\n"
:
: "r"(smem_int_ptr), "l"(state));
}
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
} // namespace tl
\ No newline at end of file
} // namespace tl
......@@ -58,7 +58,7 @@ public:
Proxy proxy = Proxy::kAsync;
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(ptx_ldmatirx()) ||
call->op.same_as(ptx_stmatirx())) {
call->op.same_as(ptx_stmatrix())) {
proxy = Proxy::kGeneric;
}
}
......
......@@ -44,7 +44,8 @@ class StorageAccessInfoLower : public StmtExprMutator {
public:
Stmt VisitStmt_(const AllocateNode *op) final {
auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var));
if (scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".var") {
if (scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".var" &&
scope.tag != ".barrier") {
auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var));
ICHECK(info.defined())
<< "Cannot find memory info of " << scope.to_string();
......
......@@ -2,6 +2,7 @@
* \file lower_shared_barrier.cc
* \brief Convert shared.barrier buffers to plain shared + ptx init.
*/
#include "../op/builtin.h"
#include "tvm/ir/type.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
......@@ -19,12 +20,15 @@ using namespace tir;
class SharedBarrierRewriter : public StmtExprMutator {
public:
static Stmt Rewrite(Stmt body) {
SharedBarrierRewriter rewriter;
static Stmt Rewrite(Stmt body, bool disable_shuffle_elect = false) {
SharedBarrierRewriter rewriter(disable_shuffle_elect);
return rewriter(body);
}
private:
SharedBarrierRewriter(bool disable_shuffle_elect)
: disable_shuffle_elect_(disable_shuffle_elect) {}
Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
......@@ -74,25 +78,12 @@ private:
T.ptx_init_barrier_thread_count(data_is_ready[0], 128)
T.ptx_init_barrier_thread_count(compute_is_done[0], 128)
*/
// 1. create new data vars
Array<Var> new_data_vars;
for (auto buffer : barrier_buffers) {
auto data = buffer->data;
auto ptr_type = data->type_annotation.as<PointerTypeNode>();
auto new_data =
Var(data->name_hint, PointerType(ptr_type->element_type, "shared"));
var_remap_.Set(data, new_data);
new_data_vars.push_back(new_data);
}
// 2. create new buffers
Array<Buffer> new_buffers;
for (auto buffer : barrier_buffers) {
auto data = buffer->data;
ICHECK(var_remap_.find(data) != var_remap_.end())
<< "data not found in var_remap_";
auto new_data = var_remap_.at(data);
auto new_buffer = Buffer(new_data, buffer->dtype, Array<PrimExpr>({1}),
auto new_buffer = Buffer(data, buffer->dtype, Array<PrimExpr>({1}),
Array<PrimExpr>({1}), PrimExpr(0), buffer->name,
buffer->data_alignment, buffer->offset_factor,
buffer->buffer_type);
......@@ -128,8 +119,14 @@ private:
}
Array<Stmt> new_body;
new_body.push_back(IfThenElse(EQ(thread_var_->var, 0),
SeqStmt(init_mbarrier_calls_), Stmt()));
PrimExpr condition;
if (!disable_shuffle_elect_) {
condition = Call(DataType::Bool(), tl_shuffle_elect(), {0});
} else {
condition = EQ(thread_var_->var, 0);
}
new_body.push_back(
IfThenElse(condition, SeqStmt(init_mbarrier_calls_), Stmt()));
new_body.push_back(
Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
{StringImm("shared")})));
......@@ -146,12 +143,6 @@ private:
if (buffer_remap_.count(buffer)) {
auto new_buffer = buffer_remap_[load->buffer];
return BufferLoad(new_buffer, load->indices);
} else if (var_remap_.count(buffer->data)) {
auto new_buffer = Buffer(
var_remap_[buffer->data], buffer->dtype, buffer->shape,
buffer->strides, buffer->elem_offset, buffer->name,
buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
return BufferLoad(new_buffer, load->indices);
}
return load;
}
......@@ -162,12 +153,6 @@ private:
if (buffer_remap_.count(buffer)) {
auto new_buffer = buffer_remap_[store->buffer];
return BufferStore(new_buffer, store->value, store->indices);
} else if (var_remap_.count(buffer->data)) {
auto new_buffer = Buffer(
var_remap_[buffer->data], buffer->dtype, buffer->shape,
buffer->strides, buffer->elem_offset, buffer->name,
buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
return BufferStore(new_buffer, store->value, store->indices);
}
return store;
}
......@@ -186,16 +171,17 @@ private:
// This is a workaround for cpu backend,
// we need to define a thread_var for the serial loop.
IterVar thread_var_;
Map<Var, Var> var_remap_;
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, Buffer> buffer_remap_;
// Mapping from data Var of a Buffer to Buffer, for lookup
std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
// Disable shuffle elect for the warp specialized kernel
bool disable_shuffle_elect_;
};
PrimFunc LowerSharedBarrier(PrimFunc f) {
SharedBarrierRewriter rewriter;
f.CopyOnWrite()->body = rewriter.Rewrite(f->body);
PrimFunc LowerSharedBarrier(PrimFunc f, bool disable_shuffle_elect) {
f.CopyOnWrite()->body =
SharedBarrierRewriter::Rewrite(f->body, disable_shuffle_elect);
return f;
}
......@@ -204,7 +190,9 @@ using namespace tir::transform;
tvm::transform::Pass LowerSharedBarrier() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return tl::LowerSharedBarrier(std::move(f));
bool disable_shuffle_elect =
ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();
return tl::LowerSharedBarrier(std::move(f), disable_shuffle_elect);
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {});
}
......
......@@ -672,7 +672,8 @@ private:
// memory. Special memory is all combined into a single allocation.
bool IsSpecialTaggedMemory(const StorageScope &scope) {
return scope.tag.length() != 0 && scope.tag != ".dyn" &&
scope.tag != ".workspace" && scope.tag != ".vtcm";
scope.tag != ".barrier" && scope.tag != ".workspace" &&
scope.tag != ".vtcm";
}
// Allocate entry of node.
......@@ -841,7 +842,10 @@ private:
ICHECK_NE(e->scope.tag.length(), 0U);
// allocate with element type.
ICHECK_NE(e->const_nbits, 0U);
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
MemoryInfo info;
if (e->scope.tag != ".barrier" && e->scope.tag != ".var") {
info = GetMemoryInfo(e->scope.to_string());
}
uint64_t total_bits = e->const_nbits;
// By default, align to 32 bits.
size_t align = 32;
......@@ -1784,6 +1788,8 @@ public:
PrimExpr last_extent = extents[extents.size() - 1];
extents.Set(extents.size() - 1,
last_extent / make_const(last_extent.dtype(), info.factor()));
LOG(INFO) << "Allocate with " << new_buffer_var << " and "
<< info.new_element_dtype << " extents: " << extents;
return Allocate(new_buffer_var, info.new_element_dtype, extents,
op->condition, op->body);
}
......
......@@ -14,11 +14,14 @@
#include "../op/builtin.h"
#include "./common/collector.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
using namespace runtime;
using arith::IRVisitorWithAnalyzer;
enum class Role { kConsumer, kProducer, kBoth };
......@@ -149,8 +152,8 @@ public:
}
void VisitStmt_(const BufferStoreNode *op) final {
bool is_shared_store =
op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer->data));
bool is_shared_store = scope.rank == StorageRank::kShared;
if (producer_buffers_.count(op->buffer.get())) {
SetRole(op, Role::kBoth);
return;
......@@ -570,29 +573,35 @@ public:
class WSCodeEmitter : public StmtMutator {
public:
/**
* @brief Construct a warp-specialized code emitter configured for producer or consumer emission.
*
* Initializes a WSCodeEmitter that will emit barrier-aware, role-filtered code for a single
* warp-specialized block. The emitter is configured with the loop/thread iteration variable,
* buffer mapping, role marker used to classify statements, and two flags that control emission
* behavior:
*
* - `mbarrier_only`: when true, emission is restricted to barrier-related operations only.
* - `only_has_wgmma`: when true, the emitter will account for the presence of WgMMA
* (workgroup MMA) operations when computing barrier/thread gating behavior.
*
* @param is_emitting_producer True to emit producer-side groups; false to emit consumer-side groups.
* @param thread_iv IterVar representing the thread iteration variable (threadIdx.*) whose Var is used
* for thread-index rewrites and gating.
* @param buffer_data_to_buffer Map from buffer data Var to the corresponding Buffer (used to resolve
* buffer references during emission).
* @param marker Role marker that classifies statements as producer/consumer/both; used to filter
* which statements are emitted on this path.
* @param mbarrier_only If true, restrict emission to mbarrier-related statements and helpers.
* @param only_has_wgmma If true, adjust emission and barrier-thread-count logic for blocks that
* contain WgMMA operations.
*/
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
* @brief Construct a warp-specialized code emitter configured for producer or
* consumer emission.
*
* Initializes a WSCodeEmitter that will emit barrier-aware, role-filtered
* code for a single warp-specialized block. The emitter is configured with
* the loop/thread iteration variable, buffer mapping, role marker used to
* classify statements, and two flags that control emission behavior:
*
* - `mbarrier_only`: when true, emission is restricted to barrier-related
* operations only.
* - `only_has_wgmma`: when true, the emitter will account for the presence of
* WgMMA (workgroup MMA) operations when computing barrier/thread gating
* behavior.
*
* @param is_emitting_producer True to emit producer-side groups; false to
* emit consumer-side groups.
* @param thread_iv IterVar representing the thread iteration variable
* (threadIdx.*) whose Var is used for thread-index rewrites and gating.
* @param buffer_data_to_buffer Map from buffer data Var to the corresponding
* Buffer (used to resolve buffer references during emission).
* @param marker Role marker that classifies statements as
* producer/consumer/both; used to filter which statements are emitted on this
* path.
* @param mbarrier_only If true, restrict emission to mbarrier-related
* statements and helpers.
* @param only_has_wgmma If true, adjust emission and barrier-thread-count
* logic for blocks that contain WgMMA operations.
*/
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
Map<Var, Buffer> buffer_data_to_buffer,
const WarpSpecializedRoleMarker &marker,
bool mbarrier_only = false, bool only_has_wgmma = false)
......@@ -602,14 +611,15 @@ public:
only_has_wgmma_(only_has_wgmma) {}
/**
* @brief Whether a SIMT-style bulk copy was detected.
*
* Returns true when a simulated SIMT (thread-parallel) copy pattern was observed
* during analysis/emission, which can affect barrier insertion and copy emission.
*
* @return true if a SIMT copy was detected; false otherwise.
*/
bool hasSimtCopy() const { return has_simt_copy_; }
* @brief Whether a SIMT-style bulk copy was detected.
*
* Returns true when a simulated SIMT (thread-parallel) copy pattern was
* observed during analysis/emission, which can affect barrier insertion and
* copy emission.
*
* @return true if a SIMT copy was detected; false otherwise.
*/
bool hasSimtCopy() const { return has_simt_copy_; }
private:
template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
......@@ -628,18 +638,18 @@ private:
}
/**
* @brief Visit and transform a SeqStmt node, emitting grouped blocks with barrier
* synchronization according to producer/consumer roles.
* @brief Visit and transform a SeqStmt node, emitting grouped blocks with
* barrier synchronization according to producer/consumer roles.
*
* This method examines the sequence to determine whether producer-side
* synchronization is required (based on marker_ roles). If no producer sync is
* needed it delegates to FilterByRole. Otherwise it:
* synchronization is required (based on marker_ roles). If no producer sync
* is needed it delegates to FilterByRole. Otherwise it:
* - Recursively visits and transforms each child statement.
* - Extracts an acquire/release sync pattern for the sequence via
* ExtractSyncPattern.
* - For producer emission (is_emitting_producer_ == true):
* - Skips consumer-only statements unless marker_ marks a statement as Both,
* in which case the statement is emitted as its own group.
* - Skips consumer-only statements unless marker_ marks a statement as
* Both, in which case the statement is emitted as its own group.
* - For each statement, inserts parity waits for acquire patterns, rewrites
* release statements with MbarrierRewriter using a computed barrier id,
* collects SimT-copy presence (setting has_simt_copy_ and inserting
......@@ -1248,21 +1258,21 @@ private:
}
/**
* @brief Rewrite a BlockRealize for warp specialization, inserting barriers and
* emitting producer/consumer bodies.
* @brief Rewrite a BlockRealize for warp specialization, inserting barriers
* and emitting producer/consumer bodies.
*
* This visitor handles BlockRealize nodes when a thread IterVar (thread_iv_)
* is defined and warp-specialization is applicable. It:
* - Determines producer/consumer roles via WarpSpecializedRoleMarker and
* returns the original block if no producer is detected.
* - If warp specialization is disabled, emits only mbarrier initialization and
* the mbarrier-only transformed body.
* - If warp specialization is disabled, emits only mbarrier initialization
* and the mbarrier-only transformed body.
* - Otherwise, detects WgMMA usage for the block body and constructs separate
* WSCodeEmitter instances for producer and consumer paths (propagating the
* WgMMA flag to the consumer emitter).
* - Generates producer/consumer code, applies register hint calls (set_max_nreg)
* when available, and rewrites thread indices with ThreadIdxRewriter to
* partition threads between producer and consumer roles.
* - Generates producer/consumer code, applies register hint calls
* (set_max_nreg) when available, and rewrites thread indices with
* ThreadIdxRewriter to partition threads between producer and consumer roles.
* - Computes and initializes a list of mbarrier handles with per-barrier
* arrive thread counts (taking SIMT-copy and WgMMA cases into account).
* - Wraps the transformed body in an IfThenElse that dispatches producer vs
......
......@@ -42,6 +42,7 @@ def get_configs():
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1])
def matmul(M,
N,
......@@ -51,7 +52,7 @@ def matmul(M,
block_K=32,
num_stages=0,
thread_num=128,
enable_rasteration=False):
enable_rasterization=False):
dtype = "float16"
accum_dtype = "float"
......@@ -84,7 +85,7 @@ def matmul(M,
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.use_swizzle(panel_size=10, enable=enable_rasterization)
# Clear out the accumulation buffer
T.clear(C_local)
......
from tilelang import tvm as tvm
import tilelang.testing
from tilelang import cached
from tilelang.cache import cached
import tilelang.language as T
......