"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "bcddacbab65a5b78719fbf73e6e7142fe855ef24"
Unverified commit cb37bfef authored by Lei Wang, committed by GitHub

[Refactor] Refactor barrier management (#744)

* Introduce Barrier

* Enhance CUDA kernel with new barrier management and post-processing support

- Added a new CUDA kernel implementation in `example_mla_decode.py` that uses shared-memory barriers for improved performance.
- Refactored barrier handling in `codegen_cuda.cc` and `codegen_hip.cc` to use a more flexible mbarrier structure (see the sketch after this list).
- Updated intrinsic definitions from `ptx_stmatirx` to `ptx_stmatrix` across multiple files for consistency.
- Introduced additional print statements for debugging in the lowering phase of the TileLang engine.
- Enhanced the overall structure and readability of the codebase.
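
The "more flexible mbarrier structure" referred to above is easiest to see in the device code the refactored backend emits. A minimal illustrative sketch, not taken verbatim from this commit; the barrier count and arrive count are placeholders, and it assumes the CUTLASS barrier header is available:

    #include <cutlass/arch/barrier.h>
    // The codegen aliases the CUTLASS cluster transaction barrier as "Barrier".
    using Barrier = cutlass::arch::ClusterTransactionBarrier;

    __global__ void demo_kernel() {
      // Raw storage stays uint64_t; it is then viewed through the Barrier type.
      __shared__ uint64_t mbarrier_mem[2];                       // 2 barriers (placeholder)
      auto mbarrier = reinterpret_cast<Barrier *>(mbarrier_mem);
      if (threadIdx.x == 0) {
        mbarrier[0].init(128);  // arrive count (placeholder)
        mbarrier[1].init(128);
      }
      __syncthreads();
      // Producer/consumer code then calls mbarrier[i].arrive(),
      // mbarrier[i].wait(phase), mbarrier[i].arrive_and_expect_tx(bytes), etc.
    }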

* Remove unused barrier handling code in CUDA and HIP code generators to streamline the implementation. This change enhances code clarity and reduces complexity in the barrier management logic.

* Enhance barrier management in TileLang

- Introduced a new intrinsic `allocate_barrier` for dynamic barrier allocation in the TileLang framework.
- Updated CUDA code generation to support the new barrier structure, allowing for improved synchronization in shared memory (illustrated in the before/after snippet below).
- Refactored existing barrier handling logic to accommodate the new intrinsic and streamline code.
- Added print statements for debugging purposes in various examples and the lowering phase of the TileLang engine.
- Removed deprecated memory scope handling code to enhance clarity and maintainability.
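
In practice, the codegen change above replaces the old free-function wrappers over raw uint64_t slots with member calls on the Barrier-typed view. A rough before/after of the emitted CUDA; the byte count, index, and phase variable are illustrative only:

    // before: free-function style on a raw uint64_t barrier word
    tl::mbarrier_arrive_expect_tx(_mbarrier[0], 4096);
    tl::mbarrier_wait(_mbarrier[0], phase);

    // after: member calls on the reinterpret_cast'ed Barrier array
    mbarrier[0].arrive_and_expect_tx(4096);
    mbarrier[0].wait(phase);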

* lint fix

* lint fix

* Remove `allocate_barrier` intrinsic and related code from TileLang to streamline barrier management. This includes updates to CUDA code generation and the removal of associated Python wrappers, enhancing code clarity and maintainability.

* Refactor logging in JITKernel to improve kernel compilation tracking

- Removed unused import of `torch.backends` in the example file.
- Introduced logging for kernel compilation in `JITKernel`, replacing print statements with structured logging for better traceability and debugging.
- Added an assertion to ensure the presence of the `global_symbol` attribute in the kernel function.

* Refactor dequantization tests and update barrier function

- Removed the test for `example_dequant_gemm_bf16_fp4_hopper_serial` to streamline the testing suite.
- Updated the `mbarrier_cp_async_arrive` function to support both pointer and non-pointer types, enhancing flexibility in barrier management.
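
A hedged usage sketch of the updated helper: because `mbarrier_cp_async_arrive` is now templated on the barrier type, it accepts either a plain `uint64_t` barrier word (the previous behavior) or a pointer-typed handle. The kernel and variable names below are illustrative only:

    __global__ void cp_async_arrive_demo() {
      __shared__ uint64_t bar;   // plain barrier word in shared memory
      uint64_t *bar_ptr = &bar;  // pointer-typed handle to the same word

      // Non-pointer path: BarrierType deduced as uint64_t, address taken internally.
      tl::mbarrier_cp_async_arrive(bar);

      // Pointer path: BarrierType deduced as uint64_t*, the pointer itself is converted.
      tl::mbarrier_cp_async_arrive(bar_ptr);
    }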

* Update CI configuration to increase pytest parallelism from 4 to 8 threads for improved test execution speed.

* Fix typos in rasterization parameters and update import path for cached module

- Corrected the spelling of `enable_rasteration` to `enable_rasterization` in the matmul function and its usage.
- Updated the import statement for the `cached` module to reflect the new path in the cache submodule.
- Added `StridedTensor` import in the language module for enhanced tensor functionality.

* Update ci.yml
parent eccdfe17
@@ -2,7 +2,6 @@ import tilelang.testing
 import example_dequant_gemv_fp16xint4
 import example_dequant_gemm_fp4_hopper
-import example_dequant_gemm_bf16_fp4_hopper_serial
 @tilelang.testing.requires_cuda
@@ -16,11 +15,5 @@ def test_example_dequant_gemm_fp4_hopper():
     example_dequant_gemm_fp4_hopper.main()
-@tilelang.testing.requires_cuda
-@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
-def test_example_dequant_gemm_bf16_fp4_hopper_serial():
-    example_dequant_gemm_bf16_fp4_hopper_serial.main()
 if __name__ == "__main__":
     tilelang.testing.main()
 import torch
-import torch.backends
 from tilelang import tvm as tvm
 import tilelang.testing
 from tvm import DataType
......
@@ -391,6 +391,7 @@ def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
     num_split = 1
     kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
+    print(kernel.get_kernel_source())
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
     profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
     latency = profiler.do_bench(warmup=500)
......
@@ -66,7 +66,6 @@ def main():
     # Run the kernel through the Profiler
     c = jit_kernel(a, b)
     # Reference multiplication using PyTorch
     ref_c = a @ b
......
@@ -83,7 +83,7 @@ TIR_DEFINE_TL_BUILTIN(ptx_ldmatirx)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_TL_BUILTIN(ptx_stmatirx)
+TIR_DEFINE_TL_BUILTIN(ptx_stmatrix)
     .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
......
@@ -62,7 +62,7 @@ static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
 * swizzle, l2_promotion, oob_fill)
 *
 */
-const Op &create_tma_descriptor();
+TVM_DLL const Op &create_tma_descriptor();
/*!
 * \brief tvm intrinsics for TMADescriptor creation for image to column load
@@ -73,7 +73,7 @@ const Op &create_tma_descriptor();
 * l2_promotion, oob_fill)
 *
 */
-const Op &create_tma_im2col_descriptor();
+TVM_DLL const Op &create_tma_im2col_descriptor();
/*!
 * \brief Create a list of mbarrier with num_threads
@@ -81,7 +81,7 @@ const Op &create_tma_im2col_descriptor();
 * create_list_of_mbarrier(num_threads0, num_threads1, ...)
 *
 */
-const Op &create_list_of_mbarrier();
+TVM_DLL const Op &create_list_of_mbarrier();
/*!
 * \brief Get the mbarrier with barrier_id
@@ -89,7 +89,7 @@ const Op &create_list_of_mbarrier();
 * int64_t* GetMBarrier(barrier_id)
 *
 */
-const Op &get_mbarrier();
+TVM_DLL const Op &get_mbarrier();
/*!
 * \brief tvm intrinsics for loading data from global tensor descriptor to
@@ -98,7 +98,7 @@ const Op &get_mbarrier();
 * tma_load(descriptor, mbarrier, smem_data, coord_0, coord_1, ...)
 *
 */
-const Op &tma_load();
+TVM_DLL const Op &tma_load();
/*!
 * \brief tvm intrinsics for loading image from global tensor to columns in
@@ -108,7 +108,7 @@ const Op &tma_load();
 * image_offset, ...)
 *
 */
-const Op &tma_load_im2col();
+TVM_DLL const Op &tma_load_im2col();
/*!
 * \brief tvm intrinsics for storing data from shared memory to global tensor
@@ -117,7 +117,7 @@ const Op &tma_load_im2col();
 * tma_store(descriptor, smem_data, coord_0, coord_1, ...)
 *
 */
-const Op &tma_store();
+TVM_DLL const Op &tma_store();
/*!
 * \brief tvm intrinsics for mbarrier wait with parity bit
@@ -125,7 +125,7 @@ const Op &tma_store();
 * mbarrier_wait_parity(mbarrier, parity)
 *
 */
-const Op &mbarrier_wait_parity();
+TVM_DLL const Op &mbarrier_wait_parity();
/*!
 * \brief tvm intrinsics for mbarrier expect tx
@@ -133,7 +133,7 @@ const Op &mbarrier_wait_parity();
 * mbarrier_expect_tx(mbarrier, transaction_bytes)
 *
 */
-const Op &mbarrier_expect_tx();
+TVM_DLL const Op &mbarrier_expect_tx();
/*!
 * \brief tvm intrinsics for ldmatrix
@@ -141,7 +141,7 @@ const Op &mbarrier_expect_tx();
 * ptx_ldmatirx(transposed, num, shared_addr, local_addr)
 *
 */
-const Op &ptx_ldmatirx();
+TVM_DLL const Op &ptx_ldmatirx();
/*!
 * \brief tvm intrinsics for stmatrix
@@ -149,7 +149,7 @@ const Op &ptx_ldmatirx();
 * ptx_ldmatirx(transposed, num, shared_addr, int32_values...)
 *
 */
-const Op &ptx_stmatirx();
+TVM_DLL const Op &ptx_stmatrix();
/*!
 * \brief Pack two b16 value into a b32 value
@@ -157,7 +157,7 @@ const Op &ptx_stmatirx();
 * int32 pack_b16(b16_value, b16_value)
 *
 */
-const Op &pack_b16();
+TVM_DLL const Op &pack_b16();
/*!
 * \brief Similar to __syncthreads(), but can be used to sync partial threads
@@ -165,7 +165,7 @@ const Op &pack_b16();
 * sync_thread_partial(num_partial_threads or mbarrier)
 *
 */
-const Op &sync_thread_partial();
+TVM_DLL const Op &sync_thread_partial();
/*!
 * \brief Issue a shared memory fence for async operations
@@ -173,7 +173,7 @@ const Op &sync_thread_partial();
 * FenceProxyAsync()
 *
 */
-const Op &fence_proxy_async();
+TVM_DLL const Op &fence_proxy_async();
/*!
 * \brief Indicate arrival of warp issuing TMA_STORE
@@ -181,7 +181,7 @@ const Op &fence_proxy_async();
 * tma_store_arrive()
 *
 */
-const Op &tma_store_arrive();
+TVM_DLL const Op &tma_store_arrive();
/*!
 * \brief Wait for TMA_STORE to finish
@@ -189,7 +189,7 @@ const Op &tma_store_arrive();
 * tma_store_wait()
 *
 */
-const Op &tma_store_wait();
+TVM_DLL const Op &tma_store_wait();
/*!
 * \brief Set reg hint for warp-specialized branched
@@ -197,7 +197,7 @@ const Op &tma_store_wait();
 * SetMaxNRegInc(num_reg, is_inc)
 *
 */
-const Op &set_max_nreg();
+TVM_DLL const Op &set_max_nreg();
/*!
 * \brief No set reg hint for warp-specialized branched
@@ -205,7 +205,7 @@ const Op &set_max_nreg();
 * no_set_max_nreg()
 *
 */
-const Op &no_set_max_nreg();
+TVM_DLL const Op &no_set_max_nreg();
/*!
 * \brief Wait the previous wgmma to finish
@@ -213,7 +213,7 @@ const Op &no_set_max_nreg();
 * wait_wgmma(num_mma)
 *
 */
-const Op &wait_wgmma();
+TVM_DLL const Op &wait_wgmma();
/*!
 * \brief Synchronize all threads in a grid
@@ -221,7 +221,7 @@ const Op &wait_wgmma();
 * sync_grid()
 *
 */
-const Op &sync_grid();
+TVM_DLL const Op &sync_grid();
/*!
 * \brief tvm intrinsic for loop continue
@@ -229,7 +229,7 @@ const Op &sync_grid();
 * loop_break()
 *
 */
-const Op &loop_break();
+TVM_DLL const Op &loop_break();
/*!
 * \brief tvm intrinsic for amd matrix core mfma instructions.
......
@@ -302,7 +302,7 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
     num = 2;
   Array<PrimExpr> args;
-  const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatirx();
+  const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatrix();
   args.push_back(static_cast<int>(is_transposed));
   args.push_back(num);
......
@@ -695,7 +695,7 @@ void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope,
   ICHECK_NE(scope, "global")
       << "Cannot allocate global memory when targeting CUDA. You must pass "
          "all global arrays as input instead";
-  if (scope == "shared") {
+  if (scope == "shared" || scope == "shared.barrier") {
     os << "__shared__ ";
   } else if (scope == "shared.dyn") {
     os << "extern __shared__ __align__(1024) ";
@@ -943,6 +943,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     this->stream << ss.str();
     this->stream << ");\n";
   };
+  auto print_mbarrier_obj = [&](PrimExpr barrier_id) {
+    std::ostringstream ss;
+    if (barrier_id.as<IntImmNode>()) {
+      // incase the barrier_id is an integer, we need to print the barrier_id as
+      // an integer
+      ss << mbarrier_name_ << "[" << barrier_id << "]";
+    } else {
+      // otherwise may be a T.get_mbarrier() call or BufferLoad Node
+      // we need to print the barrier_id as a string
+      ss << this->PrintExpr(barrier_id);
+    }
+    return ss.str();
+  };
   if (op->op.same_as(builtin::ptx_cp_async())) {
     std::string dst = this->PrintExpr(op->args[0]);
     std::string dst_offset = this->PrintExpr(op->args[1]);
@@ -971,25 +984,73 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
   } else if (op->op.same_as(builtin::create_barriers())) {
     this->PrintIndent();
     int barrier_count = Downcast<IntImm>(op->args[0])->value;
-    std::string barrier_name = "_mbarrier";
-    this->stream << "__shared__ uint64_t " << barrier_name << "["
+    auto mbarrier_storage_name = mbarrier_name_ + "_mem";
+    this->stream << "__shared__ uint64_t " << mbarrier_storage_name << "["
                  << barrier_count << "];\n";
+    this->PrintIndent();
+    this->stream << "auto " << mbarrier_name_ << " = reinterpret_cast<"
+                 << mbarrier_dtype_ << "*>(" << mbarrier_storage_name << ");\n";
   } else if (op->op.same_as(tl::get_mbarrier())) {
-    std::string barrier_name = "_mbarrier";
+    ICHECK_EQ(op->args.size(), 1);
     std::string barrier_id = this->PrintExpr(op->args[0]);
-    os << barrier_name + "[" + barrier_id + "]";
+    os << mbarrier_name_ + "[" + barrier_id + "]";
   } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
-    print_extern_call_stmt("tl::mbarrier_arrive");
+    if (op->args.size() == 1) {
+      this->PrintIndent();
+      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
+      this->stream << mbarrier_obj << ".arrive();\n";
+    } else if (op->args.size() == 3) {
+      this->PrintIndent();
+      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
+      auto cta_id = this->PrintExpr(op->args[1]);
+      auto pred = this->PrintExpr(op->args[2]);
+      this->stream << mbarrier_obj << ".arrive(" << cta_id << ", " << pred
+                   << ");\n";
+    } else {
+      LOG(FATAL) << "Invalid parameter for tl::arrive_barrier "
+                 << op->args.size();
+    }
   } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
-    print_extern_call_stmt("tl::mbarrier_init");
+    ICHECK_EQ(op->args.size(), 2);
+    this->PrintIndent();
+    auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
+    auto arrive_count = this->PrintExpr(op->args[1]);
+    this->stream << mbarrier_obj << ".init(" << arrive_count << ");\n";
   } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
-    print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
+    if (op->args.size() == 2) {
+      this->PrintIndent();
+      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
+      auto transaction_bytes = this->PrintExpr(op->args[1]);
+      this->stream << mbarrier_obj << ".arrive_and_expect_tx("
+                   << transaction_bytes << ");\n";
+    } else if (op->args.size() == 4) {
+      this->PrintIndent();
+      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
+      auto transaction_bytes = this->PrintExpr(op->args[1]);
+      auto cta_id = this->PrintExpr(op->args[2]);
+      auto pred = this->PrintExpr(op->args[3]);
+      this->stream << mbarrier_obj << ".arrive_and_expect_tx("
+                   << transaction_bytes << ", " << cta_id << ", " << pred
+                   << ");\n";
+    } else {
+      LOG(FATAL) << "Invalid parameter for tl::arrive_barrier_expect_tx "
+                 << op->args.size();
+    }
   } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
     print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
   } else if (op->op.same_as(tl::mbarrier_expect_tx())) {
-    print_extern_call_stmt("tl::mbarrier_expect_tx");
+    ICHECK_EQ(op->args.size(), 2);
+    this->PrintIndent();
+    auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
+    auto transaction_bytes = this->PrintExpr(op->args[1]);
+    this->stream << mbarrier_obj << ".expect_transaction(" << transaction_bytes
                  << ");\n";
   } else if (op->op.same_as(tl::mbarrier_wait_parity())) {
-    print_extern_call_stmt("tl::mbarrier_wait");
+    ICHECK_EQ(op->args.size(), 2);
+    this->PrintIndent();
+    auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
+    auto phase = this->PrintExpr(op->args[1]);
+    this->stream << mbarrier_obj << ".wait(" << phase << ");\n";
   } else if (op->op.same_as(tl::sync_thread_partial())) {
     print_extern_call_stmt("cutlass::arch::NamedBarrier::sync");
   } else if (op->op.same_as(tl::no_set_max_nreg())) {
@@ -1008,11 +1069,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     }
     auto desc = op->args[0];
     ss << this->PrintExpr(desc) << ", ";
-    if (const IntImmNode *imm = op->args[1].as<IntImmNode>()) {
-      ss << "_mbarrier[" << imm->value << "], ";
-    } else {
-      ss << this->PrintExpr(op->args[1]) << ", ";
-    }
+    ss << print_mbarrier_obj(op->args[1]) << ", ";
     for (size_t i = 2; i < op->args.size() - 1; i++) {
       if (i > 2)
         ss << ", ";
@@ -1050,7 +1107,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     if (trans == 1)
       func_name += "_trans";
     print_extern_call_stmt(func_name, 2);
-  } else if (op->op.same_as(tl::ptx_stmatirx())) {
+  } else if (op->op.same_as(tl::ptx_stmatrix())) {
     int trans = Downcast<IntImm>(op->args[0])->value;
     int num = Downcast<IntImm>(op->args[1])->value;
     std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
@@ -1370,13 +1427,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     int n = Downcast<IntImm>(op->args[0])->value;
     this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n
                  << ";\");\n\n";
-  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
-    need_cast_smem_ptr_to_int_ = true;
-    int barrier_id = Downcast<IntImm>(op->args[0])->value;
-    CHECK(barrier_id < barrier_count_);
-    std::string barrier =
-        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
-    this->stream << PrintCpAsyncBarrierAsm(barrier);
   } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
     need_cast_smem_ptr_to_int_ = true;
     int barrier_id = Downcast<IntImm>(op->args[0])->value;
@@ -1407,22 +1457,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     std::string barrier =
         barrier_name_ + "[" + std::to_string(barrier_id) + "]";
     this->stream << PrintWaitBarrierAsm(barrier);
-  } else if (op->op.same_as(builtin::create_barriers())) {
-    CHECK_EQ(barrier_count_, -1);
-    int barrier_count = Downcast<IntImm>(op->args[0])->value;
-    // pad barrier alignment to avoid runtime alignment errors
-    CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
-    int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
-    if (barrier_count % barrier_alignment_count != 0) {
-      barrier_count = ((barrier_count / barrier_alignment_count) + 1) *
-                      barrier_alignment_count;
-    }
-    barrier_count_ = barrier_count;
-    this->stream << "__shared__ __align__(" << barrier_alignment_bytes_
-                 << ") uint64_t " << barrier_name_ << "[" << barrier_count
-                 << "];\n";
-    this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { "
-                 << barrier_name_ << "[i] = 0; }\n";
   } else if (op->op.same_as(builtin::ptx_ldg32())) {
     /*
       asm volatile (
@@ -1654,6 +1688,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
   }
   if (scope == "shared") {
     stream << ' ' << vid << '[' << constant_size << "];\n";
+  } else if (scope == "shared.barrier") {
+    auto v_id_mem = vid + "_mem";
+    stream << ' ' << v_id_mem << "[" << constant_size << "];\n";
+    PrintIndent();
+    stream << "auto " << vid << " = reinterpret_cast<" << mbarrier_dtype_
+           << "*>(" << v_id_mem << ");\n";
   } else if (scope == "local") {
     stream << ' ' << vid << '[' << constant_size << "];\n";
   } else if (scope == "local.var") {
......
@@ -114,6 +114,10 @@ private:
   const std::string barrier_name_ = "barrier";
   // The size of the barrier array in shared memory
   int barrier_count_ = -1;
+  // The name of the mbarrier array in shared memory
+  const std::string mbarrier_name_ = "mbarrier";
+  // The type name of the mbarrier array
+  const std::string mbarrier_dtype_ = "Barrier";
   // The alignment of the barrier array in shared memory
   // Set to 16 to maintain minimum alignment requirements for async bulk copy
   const int barrier_alignment_bytes_ = 16;
......
@@ -785,31 +785,9 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
     int n = Downcast<IntImm>(op->args[0])->value;
     std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
     print_extern_call_stmt(func_name, 1);
-  } else if (op->op.same_as(builtin::create_barriers())) {
-    this->PrintIndent();
-    int barrier_count = Downcast<IntImm>(op->args[0])->value;
-    std::string barrier_name = "_mbarrier";
-    this->stream << "__shared__ uint64_t " << barrier_name << "["
-                 << barrier_count << "];\n";
-  } else if (op->op.same_as(tl::get_mbarrier())) {
-    std::string barrier_name = "_mbarrier";
-    std::string barrier_id = this->PrintExpr(op->args[0]);
-    os << barrier_name + "[" + barrier_id + "]";
-  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
-    print_extern_call_stmt("tl::mbarrier_arrive");
-  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
-    print_extern_call_stmt("tl::mbarrier_init");
-  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
-    print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
-  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
-    print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
-  } else if (op->op.same_as(tl::mbarrier_expect_tx())) {
-    print_extern_call_stmt("tl::mbarrier_expect_tx");
-  } else if (op->op.same_as(tl::mbarrier_wait_parity())) {
-    print_extern_call_stmt("tl::mbarrier_wait");
   } else if (op->op.same_as(tl::sync_thread_partial())) {
     print_extern_call_stmt("tl::syncthreads_partial");
-  } else if (op->op.same_as(tl::ptx_stmatirx())) {
+  } else if (op->op.same_as(tl::ptx_stmatrix())) {
     int trans = Downcast<IntImm>(op->args[0])->value;
     int num = Downcast<IntImm>(op->args[1])->value;
     std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
......
#pragma once
#include "common.h"
#include <cutlass/arch/barrier.h>
// Reuse cutlass advanced barrier abstraction
using Barrier = cutlass::arch::ClusterTransactionBarrier;
namespace tl {
TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.init.shared.b64 [%1], %0;"
:
: "r"(arrive_count), "r"(smem_int_ptr));
}
TL_DEVICE uint32_t mbarrier_try_wait(uint64_t &smem_barrier, int phase_bit) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint32_t waitComplete;
asm volatile("{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(waitComplete)
: "r"(smem_int_ptr), "r"(phase_bit));
return waitComplete;
}
TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) {
if (mbarrier_try_wait(smem_barrier, phase_bit) == 0) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
// Arbitrarily large timer value after which try-wait expires and re-tries.
uint32_t ticks = 0x989680;
asm volatile("{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra DONE; \n\t"
"bra LAB_WAIT; \n\t"
"DONE: \n\t"
"}"
:
: "r"(smem_int_ptr), "r"(phase_bit), "r"(ticks));
}
}
TL_DEVICE void mbarrier_test_wait(uint64_t &smem_barrier, int phase_bit) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures
// to save instruction issue slots
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}\n" ::"r"(smem_int_ptr),
"r"(phase_bit));
}
TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id,
uint32_t pred) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
if (pred) {
asm volatile("{\n\t"
".reg .b32 remAddr32;\n\t"
"mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t"
"mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t"
"}"
:
: "r"(smem_int_ptr), "r"(cta_id));
}
}
TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier,
uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier,
uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
template <typename BarrierType = uint64_t>
TL_DEVICE void mbarrier_cp_async_arrive(BarrierType &smem_mbar) {
uint32_t smem_int_mbar;
if constexpr (std::is_pointer_v<BarrierType>) {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
} else {
smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
}
asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];"
:
: "r"(smem_int_mbar));
}
TL_DEVICE void fence_proxy_async() {
asm volatile("fence.proxy.async.shared::cta;" : :);
}
// Indicate arrival of warp issuing TMA_STORE
TL_DEVICE void tma_store_arrive() {
asm volatile("cp.async.bulk.commit_group;");
}
template <int Count> TL_DEVICE void tma_store_wait() {
asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(Count) : "memory");
}
TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint64_t state = 0;
asm volatile("{\n"
".reg .pred P1;\n"
"mbarrier.arrive.shared.b64 %1, [%0];\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.shared.b64 P1, [%0], %1;\n"
"@!P1 bra.uni LAB_WAIT;\n"
"}\n"
:
: "r"(smem_int_ptr), "l"(state));
}
} // namespace tl
@@ -250,4 +250,12 @@ template <int thread_extent> TL_DEVICE bool tl_shuffle_elect() {
   cute::elect_one_sync();
 }
+template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
+  asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
+}
+template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
+  asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
+}
 } // namespace tl
@@ -4,6 +4,7 @@
 #include <cuda.h>
 #endif
+#include "barrier.h"
 #include "common.h"
 namespace tl {
@@ -13,9 +14,11 @@ enum class CacheHintSm90 : uint64_t {
   EVICT_LAST = 0x14F0000000000000,
 };
-TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, uint64_t &smem_mbar,
+template <typename BarrierType = uint64_t>
+TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, BarrierType &smem_mbar,
                         uint32_t size) {
-  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
+  uint32_t smem_int_mbar =
+      smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::"
                "bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr),
@@ -35,11 +38,17 @@ TL_DEVICE void tma_load_multicast(void *smem_ptr, void *gmem_ptr,
                :);
 }
-template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
-TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType = uint64_t>
+TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar,
                         void const *const smem_ptr, int32_t const &crd0) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
-  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
+  uint32_t smem_int_mbar;
+  if constexpr (std::is_pointer_v<BarrierType>) {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
+  } else {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
+  }
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
                "complete_tx::bytes.L2::cache_hint"
@@ -50,12 +59,18 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                : "memory");
 }
-template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
-TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType = uint64_t>
+TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
-  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
+  uint32_t smem_int_mbar;
+  if constexpr (std::is_pointer_v<BarrierType>) {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
+  } else {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
+  }
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::"
                "complete_tx::bytes.L2::cache_hint"
@@ -66,12 +81,18 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                : "memory");
 }
-template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
-TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType = uint64_t>
+TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1, int32_t const &crd2) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
-  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
+  uint32_t smem_int_mbar;
+  if constexpr (std::is_pointer_v<BarrierType>) {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
+  } else {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
+  }
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
                "complete_tx::bytes.L2::cache_hint"
@@ -81,13 +102,19 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint)
                : "memory");
 }
-template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
-TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType = uint64_t>
+TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1, int32_t const &crd2,
                         int32_t const &crd3) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
-  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
+  uint32_t smem_int_mbar;
+  if constexpr (std::is_pointer_v<BarrierType>) {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
+  } else {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
+  }
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::"
                "complete_tx::bytes.L2::cache_hint"
@@ -98,13 +125,19 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                : "memory");
 }
-template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
-TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType = uint64_t>
+TL_DEVICE void tma_load(const CUtensorMap &descriptor, BarrierType &smem_mbar,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1, int32_t const &crd2,
                         int32_t const &crd3, int32_t const &crd4) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
-  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
+  uint32_t smem_int_mbar;
+  if constexpr (std::is_pointer_v<BarrierType>) {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
+  } else {
+    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
+  }
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::"
                "complete_tx::bytes.L2::cache_hint"
@@ -116,15 +149,17 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                : "memory");
 }
-template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
-TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
-                               uint64_t &smem_mbar, void const *const smem_ptr,
-                               int32_t const &coord_c, int32_t const &coord_w,
-                               int32_t const &coord_h, int32_t const &coord_n,
-                               uint16_t const &offset_w,
-                               uint16_t const &offset_h) {
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL,
+          typename BarrierType = uint64_t>
+TL_DEVICE void
+tma_load_im2col(const CUtensorMap &descriptor, BarrierType &smem_mbar,
+                void const *const smem_ptr, int32_t const &coord_c,
+                int32_t const &coord_w, int32_t const &coord_h,
+                int32_t const &coord_n, uint16_t const &offset_w,
+                uint16_t const &offset_h) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
-  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
+  uint32_t smem_int_mbar =
+      smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
                ":complete_tx::bytes.L2::cache_hint"
@@ -212,138 +247,4 @@ TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
   asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
 }
-TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) {
+} // namespace tl
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.init.shared.b64 [%1], %0;"
:
: "r"(arrive_count), "r"(smem_int_ptr));
}
TL_DEVICE uint32_t mbarrier_try_wait(uint64_t &smem_barrier, int phase_bit) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint32_t waitComplete;
asm volatile("{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(waitComplete)
: "r"(smem_int_ptr), "r"(phase_bit));
return waitComplete;
}
TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) {
if (mbarrier_try_wait(smem_barrier, phase_bit) == 0) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
// Arbitrarily large timer value after which try-wait expires and re-tries.
uint32_t ticks = 0x989680;
asm volatile("{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra DONE; \n\t"
"bra LAB_WAIT; \n\t"
"DONE: \n\t"
"}"
:
: "r"(smem_int_ptr), "r"(phase_bit), "r"(ticks));
}
}
TL_DEVICE void mbarrier_test_wait(uint64_t &smem_barrier, int phase_bit) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures
// to save instruction issue slots
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}\n" ::"r"(smem_int_ptr),
"r"(phase_bit));
}
TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier, int cta_id,
uint32_t pred) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
if (pred) {
asm volatile("{\n\t"
".reg .b32 remAddr32;\n\t"
"mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t"
"mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t"
"}"
:
: "r"(smem_int_ptr), "r"(cta_id));
}
}
TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier,
uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier,
uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_cp_async_arrive(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];"
:
: "r"(smem_int_ptr));
}
TL_DEVICE void fence_proxy_async() {
asm volatile("fence.proxy.async.shared::cta;" : :);
}
// Indicate arrival of warp issuing TMA_STORE
TL_DEVICE void tma_store_arrive() {
asm volatile("cp.async.bulk.commit_group;");
}
template <int Count> TL_DEVICE void tma_store_wait() {
asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(Count) : "memory");
}
TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint64_t state = 0;
asm volatile("{\n"
".reg .pred P1;\n"
"mbarrier.arrive.shared.b64 %1, [%0];\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.shared.b64 P1, [%0], %1;\n"
"@!P1 bra.uni LAB_WAIT;\n"
"}\n"
:
: "r"(smem_int_ptr), "l"(state));
}
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
} // namespace tl
\ No newline at end of file
@@ -58,7 +58,7 @@ public:
     Proxy proxy = Proxy::kAsync;
     if (auto call = op->value.as<CallNode>()) {
       if (call->op.same_as(ptx_ldmatirx()) ||
-          call->op.same_as(ptx_stmatirx())) {
+          call->op.same_as(ptx_stmatrix())) {
         proxy = Proxy::kGeneric;
       }
     }
......
@@ -44,7 +44,8 @@ class StorageAccessInfoLower : public StmtExprMutator {
 public:
   Stmt VisitStmt_(const AllocateNode *op) final {
     auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var));
-    if (scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".var") {
+    if (scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".var" &&
+        scope.tag != ".barrier") {
       auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var));
       ICHECK(info.defined())
           << "Cannot find memory info of " << scope.to_string();
......
@@ -2,6 +2,7 @@
 * \file lower_shared_barrier.cc
 * \brief Convert shared.barrier buffers to plain shared + ptx init.
 */
+#include "../op/builtin.h"
 #include "tvm/ir/type.h"
 #include "tvm/tir/expr.h"
 #include "tvm/tir/stmt.h"
@@ -19,12 +20,15 @@ using namespace tir;
 class SharedBarrierRewriter : public StmtExprMutator {
 public:
-  static Stmt Rewrite(Stmt body) {
-    SharedBarrierRewriter rewriter;
+  static Stmt Rewrite(Stmt body, bool disable_shuffle_elect = false) {
+    SharedBarrierRewriter rewriter(disable_shuffle_elect);
     return rewriter(body);
   }
 private:
+  SharedBarrierRewriter(bool disable_shuffle_elect)
+      : disable_shuffle_elect_(disable_shuffle_elect) {}
   Stmt VisitStmt_(const BlockNode *op) final {
     Block block = GetRef<Block>(op);
     Array<Buffer> alloc_buffers = op->alloc_buffers;
@@ -74,25 +78,12 @@ private:
     T.ptx_init_barrier_thread_count(data_is_ready[0], 128)
     T.ptx_init_barrier_thread_count(compute_is_done[0], 128)
     */
-    // 1. create new data vars
-    Array<Var> new_data_vars;
-    for (auto buffer : barrier_buffers) {
-      auto data = buffer->data;
-      auto ptr_type = data->type_annotation.as<PointerTypeNode>();
-      auto new_data =
-          Var(data->name_hint, PointerType(ptr_type->element_type, "shared"));
-      var_remap_.Set(data, new_data);
-      new_data_vars.push_back(new_data);
-    }
     // 2. create new buffers
     Array<Buffer> new_buffers;
     for (auto buffer : barrier_buffers) {
       auto data = buffer->data;
-      ICHECK(var_remap_.find(data) != var_remap_.end())
-          << "data not found in var_remap_";
-      auto new_data = var_remap_.at(data);
-      auto new_buffer = Buffer(new_data, buffer->dtype, Array<PrimExpr>({1}),
+      auto new_buffer = Buffer(data, buffer->dtype, Array<PrimExpr>({1}),
                                Array<PrimExpr>({1}), PrimExpr(0), buffer->name,
                                buffer->data_alignment, buffer->offset_factor,
                                buffer->buffer_type);
@@ -128,8 +119,14 @@
     }
     Array<Stmt> new_body;
-    new_body.push_back(IfThenElse(EQ(thread_var_->var, 0),
-                                  SeqStmt(init_mbarrier_calls_), Stmt()));
+    PrimExpr condition;
+    if (!disable_shuffle_elect_) {
+      condition = Call(DataType::Bool(), tl_shuffle_elect(), {0});
+    } else {
+      condition = EQ(thread_var_->var, 0);
+    }
+    new_body.push_back(
+        IfThenElse(condition, SeqStmt(init_mbarrier_calls_), Stmt()));
     new_body.push_back(
         Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
                       {StringImm("shared")})));
@@ -146,12 +143,6 @@ private:
     if (buffer_remap_.count(buffer)) {
       auto new_buffer = buffer_remap_[load->buffer];
       return BufferLoad(new_buffer, load->indices);
-    } else if (var_remap_.count(buffer->data)) {
-      auto new_buffer = Buffer(
-          var_remap_[buffer->data], buffer->dtype, buffer->shape,
-          buffer->strides, buffer->elem_offset, buffer->name,
-          buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
-      return BufferLoad(new_buffer, load->indices);
     }
     return load;
   }
@@ -162,12 +153,6 @@ private:
     if (buffer_remap_.count(buffer)) {
       auto new_buffer = buffer_remap_[store->buffer];
       return BufferStore(new_buffer, store->value, store->indices);
-    } else if (var_remap_.count(buffer->data)) {
-      auto new_buffer = Buffer(
-          var_remap_[buffer->data], buffer->dtype, buffer->shape,
-          buffer->strides, buffer->elem_offset, buffer->name,
-          buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
-      return BufferStore(new_buffer, store->value, store->indices);
     }
     return store;
   }
@@ -186,16 +171,17 @@ private:
   // This is a workaround for cpu backend,
   // we need to define a thread_var for the serial loop.
   IterVar thread_var_;
-  Map<Var, Var> var_remap_;
   Map<Var, Buffer> buffer_data_to_buffer_;
   Map<Buffer, Buffer> buffer_remap_;
   // Mapping from data Var of a Buffer to Buffer, for lookup
   std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
+  // Disable shuffle elect for the warp specialized kernel
+  bool disable_shuffle_elect_;
 };
-PrimFunc LowerSharedBarrier(PrimFunc f) {
-  SharedBarrierRewriter rewriter;
-  f.CopyOnWrite()->body = rewriter.Rewrite(f->body);
+PrimFunc LowerSharedBarrier(PrimFunc f, bool disable_shuffle_elect) {
+  f.CopyOnWrite()->body =
+      SharedBarrierRewriter::Rewrite(f->body, disable_shuffle_elect);
   return f;
 }
@@ -204,7 +190,9 @@ using namespace tir::transform;
 tvm::transform::Pass LowerSharedBarrier() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    return tl::LowerSharedBarrier(std::move(f));
+    bool disable_shuffle_elect =
+        ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();
+    return tl::LowerSharedBarrier(std::move(f), disable_shuffle_elect);
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {});
 }
......
@@ -672,7 +672,8 @@ private:
   // memory. Special memory is all combined into a single allocation.
   bool IsSpecialTaggedMemory(const StorageScope &scope) {
     return scope.tag.length() != 0 && scope.tag != ".dyn" &&
-           scope.tag != ".workspace" && scope.tag != ".vtcm";
+           scope.tag != ".barrier" && scope.tag != ".workspace" &&
+           scope.tag != ".vtcm";
   }
   // Allocate entry of node.
@@ -841,7 +842,10 @@ private:
     ICHECK_NE(e->scope.tag.length(), 0U);
     // allocate with element type.
     ICHECK_NE(e->const_nbits, 0U);
-    MemoryInfo info = GetMemoryInfo(e->scope.to_string());
+    MemoryInfo info;
+    if (e->scope.tag != ".barrier" && e->scope.tag != ".var") {
+      info = GetMemoryInfo(e->scope.to_string());
+    }
     uint64_t total_bits = e->const_nbits;
     // By default, align to 32 bits.
     size_t align = 32;
@@ -1784,6 +1788,8 @@ public:
     PrimExpr last_extent = extents[extents.size() - 1];
     extents.Set(extents.size() - 1,
                 last_extent / make_const(last_extent.dtype(), info.factor()));
+    LOG(INFO) << "Allocate with " << new_buffer_var << " and "
+              << info.new_element_dtype << " extents: " << extents;
     return Allocate(new_buffer_var, info.new_element_dtype, extents,
                     op->condition, op->body);
   }
......
@@ -14,11 +14,14 @@
 #include "../op/builtin.h"
 #include "./common/collector.h"
+#include "runtime/thread_storage_scope.h"
+#include "tir/transforms/ir_utils.h"

 namespace tvm {
 namespace tl {

 using namespace tir;
+using namespace runtime;
 using arith::IRVisitorWithAnalyzer;

 enum class Role { kConsumer, kProducer, kBoth };

@@ -149,8 +152,8 @@ public:
   }

   void VisitStmt_(const BufferStoreNode *op) final {
-    bool is_shared_store =
-        op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
+    auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer->data));
+    bool is_shared_store = scope.rank == StorageRank::kShared;
     if (producer_buffers_.count(op->buffer.get())) {
       SetRole(op, Role::kBoth);
       return;
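The old string comparison only matched the exact scopes "shared" and "shared.dyn"; parsing the scope and checking its rank also covers tagged variants. Below is a standalone sketch of that idea; it relies on TVM's internal thread_storage_scope.h header, so it only builds inside the TVM source tree, and the "shared.barrier" string mentioned in the comment is an assumption.

// Sketch: classify a scope by rank instead of comparing exact scope strings.
#include <string>
#include "runtime/thread_storage_scope.h"

bool IsSharedRank(const std::string &scope_str) {
  using tvm::runtime::StorageRank;
  using tvm::runtime::StorageScope;
  StorageScope scope = StorageScope::Create(scope_str);
  // "shared", "shared.dyn", and tagged forms such as "shared.barrier" all
  // parse to rank kShared; only the tag portion differs.
  return scope.rank == StorageRank::kShared;
}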
@@ -570,29 +573,35 @@ public:
 class WSCodeEmitter : public StmtMutator {
 public:
   /**
    * @brief Construct a warp-specialized code emitter configured for producer or
    * consumer emission.
    *
    * Initializes a WSCodeEmitter that will emit barrier-aware, role-filtered
    * code for a single warp-specialized block. The emitter is configured with
    * the loop/thread iteration variable, buffer mapping, role marker used to
    * classify statements, and two flags that control emission behavior:
    *
    * - `mbarrier_only`: when true, emission is restricted to barrier-related
    * operations only.
    * - `only_has_wgmma`: when true, the emitter will account for the presence of
    * WgMMA (workgroup MMA) operations when computing barrier/thread gating
    * behavior.
    *
    * @param is_emitting_producer True to emit producer-side groups; false to
    * emit consumer-side groups.
    * @param thread_iv IterVar representing the thread iteration variable
    * (threadIdx.*) whose Var is used for thread-index rewrites and gating.
    * @param buffer_data_to_buffer Map from buffer data Var to the corresponding
    * Buffer (used to resolve buffer references during emission).
    * @param marker Role marker that classifies statements as
    * producer/consumer/both; used to filter which statements are emitted on this
    * path.
    * @param mbarrier_only If true, restrict emission to mbarrier-related
    * statements and helpers.
    * @param only_has_wgmma If true, adjust emission and barrier-thread-count
    * logic for blocks that contain WgMMA operations.
    */
   WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
                 Map<Var, Buffer> buffer_data_to_buffer,
                 const WarpSpecializedRoleMarker &marker,
                 bool mbarrier_only = false, bool only_has_wgmma = false)

@@ -602,14 +611,15 @@ public:
         only_has_wgmma_(only_has_wgmma) {}

   /**
    * @brief Whether a SIMT-style bulk copy was detected.
    *
    * Returns true when a simulated SIMT (thread-parallel) copy pattern was
    * observed during analysis/emission, which can affect barrier insertion and
    * copy emission.
    *
    * @return true if a SIMT copy was detected; false otherwise.
    */
   bool hasSimtCopy() const { return has_simt_copy_; }

 private:
   template <typename NodeType> Stmt FilterByRole(const NodeType *op) {

@@ -628,18 +638,18 @@ private:
   }

   /**
    * @brief Visit and transform a SeqStmt node, emitting grouped blocks with
    * barrier synchronization according to producer/consumer roles.
    *
    * This method examines the sequence to determine whether producer-side
    * synchronization is required (based on marker_ roles). If no producer sync
    * is needed it delegates to FilterByRole. Otherwise it:
    * - Recursively visits and transforms each child statement.
    * - Extracts an acquire/release sync pattern for the sequence via
    * ExtractSyncPattern.
    * - For producer emission (is_emitting_producer_ == true):
    *   - Skips consumer-only statements unless marker_ marks a statement as
    * Both, in which case the statement is emitted as its own group.
    *   - For each statement, inserts parity waits for acquire patterns, rewrites
    * release statements with MbarrierRewriter using a computed barrier id,
    * collects SimT-copy presence (setting has_simt_copy_ and inserting

@@ -1248,21 +1258,21 @@ private:
   }

   /**
    * @brief Rewrite a BlockRealize for warp specialization, inserting barriers
    * and emitting producer/consumer bodies.
    *
    * This visitor handles BlockRealize nodes when a thread IterVar (thread_iv_)
    * is defined and warp-specialization is applicable. It:
    * - Determines producer/consumer roles via WarpSpecializedRoleMarker and
    * returns the original block if no producer is detected.
    * - If warp specialization is disabled, emits only mbarrier initialization
    * and the mbarrier-only transformed body.
    * - Otherwise, detects WgMMA usage for the block body and constructs separate
    * WSCodeEmitter instances for producer and consumer paths (propagating the
    * WgMMA flag to the consumer emitter).
    * - Generates producer/consumer code, applies register hint calls
    * (set_max_nreg) when available, and rewrites thread indices with
    * ThreadIdxRewriter to partition threads between producer and consumer roles.
    * - Computes and initializes a list of mbarrier handles with per-barrier
    * arrive thread counts (taking SIMT-copy and WgMMA cases into account).
    * - Wraps the transformed body in an IfThenElse that dispatches producer vs
...
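The constructor documentation and the BlockRealize notes above outline how the two emitters are combined. The sketch below is a rough illustration of that flow under stated assumptions: the helper name, the has_wgmma flag, and the thread-partition expression are hypothetical, and the real pass additionally initializes the mbarrier list with per-barrier arrive counts and rewrites thread indices with ThreadIdxRewriter.

// Illustrative sketch only, assumed to live in the same translation unit as
// WSCodeEmitter; not the actual VisitStmt_ implementation.
Stmt EmitWarpSpecializedBodySketch(Stmt body, IterVar thread_iv,
                                   Map<Var, Buffer> buffer_data_to_buffer,
                                   const WarpSpecializedRoleMarker &marker,
                                   bool has_wgmma,
                                   PrimExpr consumer_thread_count) {
  // One emitter per role; the WgMMA flag is propagated to the consumer only.
  WSCodeEmitter producer(/*is_emitting_producer=*/true, thread_iv,
                         buffer_data_to_buffer, marker);
  WSCodeEmitter consumer(/*is_emitting_producer=*/false, thread_iv,
                         buffer_data_to_buffer, marker,
                         /*mbarrier_only=*/false,
                         /*only_has_wgmma=*/has_wgmma);
  Stmt producer_code = producer(body);
  Stmt consumer_code = consumer(body);
  // Hypothetical partition: threads at or above `consumer_thread_count` take
  // the producer path; the real pass derives the split from the block itself.
  return IfThenElse(thread_iv->var >= consumer_thread_count, producer_code,
                    consumer_code);
}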
@@ -42,6 +42,7 @@ def get_configs():
     } for values in itertools.product(*iter_params.values())]


+@tilelang.autotune(configs=get_configs(),)
 @tilelang.jit(out_idx=[-1])
 def matmul(M,
            N,

@@ -51,7 +52,7 @@ def matmul(M,
            block_K=32,
            num_stages=0,
            thread_num=128,
-           enable_rasteration=False):
+           enable_rasterization=False):
     dtype = "float16"
     accum_dtype = "float"

@@ -84,7 +85,7 @@ def matmul(M,
         C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
         # Enable (or disable) swizzling optimization
-        T.use_swizzle(panel_size=10, enable=enable_rasteration)
+        T.use_swizzle(panel_size=10, enable=enable_rasterization)
         # Clear out the accumulation buffer
         T.clear(C_local)
...
 from tilelang import tvm as tvm
 import tilelang.testing
-from tilelang import cached
+from tilelang.cache import cached
 import tilelang.language as T
...