Commit fce16b00 authored by Lei Wang, committed by LeiWang1999 (parent e46653ac)

[Refactor] Separate warp specialize rewriter and tma barrier injector pass (#447)

* [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.

* [Refactor] Rename operations for consistency in lower_hopper_intrin and related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.

* [Refactor] Rename operations to snake_case for consistency

* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.

* [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang (a condensed sketch follows this list).
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.
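
For quick reference, here is a condensed version of the new warp-specialized GEMM pattern (essentially the single-stage example included further down in this commit): one warp group (`T.ws(1)`) produces shared-memory tiles, the other (`T.ws(0)`) consumes them, with two mbarriers handing the buffer back and forth.

```python
import tilelang
import tilelang.language as T


def matmul_ws(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            # barrier 0: tiles ready for the consumer, barrier 1: buffer released by the consumer
            T.create_list_of_mbarrier(128, 128)
            with T.ws(1):  # producer warp group
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(1, ko ^ 1)
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                    T.mbarrier_arrive(0)
            with T.ws(0):  # consumer warp group
                T.clear(C_local)
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(0, ko)
                    T.gemm(A_shared, B_shared, C_local)
                    T.mbarrier_arrive(1)
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```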

* [Refactor] Clean up code formatting and improve readability

* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.

* lint fix

* [Refactor] Update mbarrier functions for improved clarity and consistency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.

* Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` for improved code readability and separation of logical sections.

* [Feature] Add examples for warp specialization and TMA barrier integration

* Introduced three new example scripts: `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, and `example_warp_specialize_mla.py` demonstrating matrix multiplication with warp specialization and TMA barriers.
* Implemented kernel functions with shared memory allocation and memory barrier synchronization for improved performance.
* Enhanced the TileLang API with new methods for compiling and testing kernels in Python using PyTorch.
* Updated `phase.py` to include TMA barrier injection in the optimization process (pass ordering sketched below).
* Improved documentation and comments for better clarity on usage and functionality.
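
As a point of reference, the optimization pipeline in `OptimizeForTarget` now runs the new barrier injection between warp specialization and pipeline planning (excerpted from the `phase.py` diff below; `mod` is the module being lowered):

```python
# Excerpt of the updated OptimizeForTarget ordering (see the phase.py diff below).
mod = tilelang.transform.WarpSpecialized()(mod)
mod = tilelang.transform.InjectTmaBarrier()(mod)   # new: pair TMA loads with mbarriers
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
```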

* [Feature] Add example for warp specialization in GEMM with TMA barriers

* Introduced a new example script `example_warp_specialize_gemm_stage2.py` demonstrating matrix multiplication using warp specialization and TMA barriers.
* Implemented a kernel function with shared memory allocation and memory barrier synchronization for enhanced performance.
* Included functionality to compile the kernel into a PyTorch-compatible function and validate its correctness against PyTorch's reference implementation.
* Enhanced documentation and comments for clarity on usage and functionality.

* lint fix

* [Feature] Implement WarpSpecializedDetector for TMA and MBarrier Detection

* Added the `WarpSpecializedDetector` class to identify the presence of TMA operations and memory barrier operations within a given TIR statement.
* Enhanced the `WarpSpecialized` pass to utilize the detector, allowing for conditional substitution based on the detection results.
* Improved code organization by including necessary headers and utilizing the `IRVisitorWithAnalyzer` for analysis.
* This ensures the automatic rewriter skips functions that already use TMA and mbarrier operations explicitly, improving performance and correctness.

* lint fix
import tilelang
import tilelang.language as T
tilelang.disable_cache()
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # create mbarrier for tma
            T.create_list_of_mbarrier(128, 128)

            # with T.ws(1):
            #     for ko in range(T.ceildiv(K, block_K)):
            #         T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
            #         T.copy(A[by * block_M, ko * block_K], A_shared)
            #         T.copy(B[ko * block_K, bx * block_N], B_shared)
            #         T.mbarrier_arrive(0)

            # with T.ws(0):
            #     T.clear(C_local)
            #     for ko in range(T.ceildiv(K, block_K)):
            #         T.mbarrier_wait_parity(0, ko & 1)
            #         T.gemm(A_shared, B_shared, C_local)
            #         T.mbarrier_arrive(1)
            #     T.copy(C_local, C[by * block_M, bx * block_N])

            with T.ws(0):
                T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                with T.ws(1):
                    T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                    T.mbarrier_arrive(0)
                with T.ws(0):
                    T.mbarrier_wait_parity(0, ko & 1)
                    T.gemm(A_shared, B_shared, C_local)
                    T.mbarrier_arrive(1)

            with T.ws(0):
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main
K = 64
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(128, 128, K, 128, 128, 32)
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
tilelang.disable_cache()
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    execution_backend="cython",
    pass_configs={
        "tl.disable_warp_specialized": True,
        # "tl.disable_tma_lower": True,
    })
tilelang.enable_cache()
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(128, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, 128, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)
# Reference multiplication using PyTorch
ref_c = a @ b
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# # 4. Retrieve and inspect the generated CUDA source (optional)
# # cuda_source = jit_kernel.get_kernel_source()
# # print("Generated CUDA kernel:\n", cuda_source)
# # 5.Profile latency with kernel
# profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
# latency = profiler.do_bench()
# print(f"Latency: {latency} ms")
import tilelang
import tilelang.language as T
tilelang.disable_cache()
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    num_stages = 2
    mbarrier_list = [128, 128] * num_stages

    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((num_stages, block_M, block_K), dtype, "shared")
            B_shared = T.alloc_shared((num_stages, block_K, block_N), dtype, "shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # create mbarrier for tma
            T.create_list_of_mbarrier(mbarrier_list)

            with T.ws(1):
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(ko % num_stages + num_stages,
                                           ((ko // num_stages) % num_stages) ^ 1)
                    T.copy(A[by * block_M:(by + 1) * block_M, ko * block_K:(ko + 1) * block_K],
                           A_shared[ko % num_stages, :, :])
                    T.copy(B[ko * block_K:(ko + 1) * block_K, bx * block_N:(bx + 1) * block_N],
                           B_shared[ko % num_stages, :, :])
                    T.mbarrier_arrive(ko % num_stages)

            with T.ws(0):
                T.clear(C_local)
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(ko % num_stages, (ko // num_stages) % num_stages)
                    T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :],
                           C_local)
                    T.mbarrier_arrive(ko % num_stages + num_stages)
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main
K = 64
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(128, 128, K, 128, 128, 32)
print(func.script())
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
tilelang.disable_cache()
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    execution_backend="cython",
    pass_configs={
        "tl.disable_warp_specialized": True,
        # "tl.disable_tma_lower": True,
    })
tilelang.enable_cache()
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(128, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, 128, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)
# Reference multiplication using PyTorch
ref_c = a @ b
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# # 4. Retrieve and inspect the generated CUDA source (optional)
# # cuda_source = jit_kernel.get_kernel_source()
# # print("Generated CUDA kernel:\n", cuda_source)
# # 5.Profile latency with kernel
# profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
# latency = profiler.do_bench()
# print(f"Latency: {latency} ms")
import tilelang
import tilelang.language as T
tilelang.disable_cache()
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # create mbarrier for tma
            T.create_list_of_mbarrier(128, 128)

            # with T.ws(1):
            #     for ko in range(T.ceildiv(K, block_K)):
            #         T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
            #         T.copy(A[by * block_M, ko * block_K], A_shared)
            #         T.copy(B[ko * block_K, bx * block_N], B_shared)
            #         T.mbarrier_arrive(0)

            # with T.ws(0):
            #     T.clear(C_local)
            #     for ko in range(T.ceildiv(K, block_K)):
            #         T.mbarrier_wait_parity(0, ko & 1)
            #         T.gemm(A_shared, B_shared, C_local)
            #         T.mbarrier_arrive(1)
            #     T.copy(C_local, C[by * block_M, bx * block_N])

            with T.ws(0):
                T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                with T.ws(1):
                    T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                    T.mbarrier_arrive(0)
                with T.ws(0):
                    T.mbarrier_wait_parity(0, ko & 1)
                    T.gemm(A_shared, B_shared, C_local)
                    T.mbarrier_arrive(1)

            with T.ws(0):
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main
K = 64
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(128, 128, K, 128, 128, 32)
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
tilelang.disable_cache()
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    execution_backend="cython",
    pass_configs={
        "tl.disable_warp_specialized": True,
        # "tl.disable_tma_lower": True,
    })
tilelang.enable_cache()
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(128, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, 128, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)
# Reference multiplication using PyTorch
ref_c = a @ b
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# # 4. Retrieve and inspect the generated CUDA source (optional)
# # cuda_source = jit_kernel.get_kernel_source()
# # print("Generated CUDA kernel:\n", cuda_source)
# # 5.Profile latency with kernel
# profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
# latency = profiler.do_bench()
# print(f"Latency: {latency} ms")
import tilelang
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # create mbarrier for tma
            T.create_list_of_mbarrier(128, 128)

            with T.ws(1):
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(1, ko ^ 1)
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                    T.mbarrier_arrive(0)

            with T.ws(0):
                T.clear(C_local)
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(0, ko)
                    T.gemm(A_shared, B_shared, C_local)
                    T.mbarrier_arrive(1)
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(128, 128, 64, 128, 128, 32)
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    execution_backend="cython",
    pass_configs={
        "tl.disable_warp_specialized": True,
        "tl.disable_tma_lower": True,
    })
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 128, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)
# Reference multiplication using PyTorch
ref_c = a @ b
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5.Profile latency with kernel
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
......@@ -317,8 +317,9 @@ Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size) {
Var i = InputPlaceholder(0);
Var j = InputPlaceholder(1);
int vector_size = 128 / element_size;
ICHECK(stride % 8 == 0);
ICHECK(continuous % (vector_size * 8) == 0);
ICHECK(stride % 8 == 0) << "stride=" << stride;
ICHECK(continuous % (vector_size * 8) == 0)
<< "continuous=" << continuous << ", vector_size=" << vector_size;
PrimExpr ts = FloorDiv(i, 8);
PrimExpr s = FloorMod(i, 8);
PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 8);
......
......@@ -104,6 +104,29 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
Buffer global_tensor = is_load ? src : dst;
Buffer shared_tensor = is_load ? dst : src;
Array<Range> global_range = is_load ? src_range : dst_range;
Array<Range> shared_range = is_load ? dst_range : src_range;
Array<PrimExpr> indices;
for (auto r : shared_range)
indices.push_back(r->min);
std::vector<PrimExpr> strides;
PrimExpr stride = 1;
for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
strides.insert(strides.begin(), stride);
stride *= s;
}
ICHECK(strides.size() == indices.size())
<< "strides.size() != indices.size()" << strides.size() << " "
<< indices.size();
PrimExpr offset = 0;
for (size_t i = 0; i < indices.size(); i++) {
offset += indices[i] * strides[i];
}
Layout shared_layout;
if (T.layout_map.count(shared_tensor)) {
shared_layout = T.layout_map[shared_tensor];
......@@ -129,7 +152,6 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
// Global Tensor Shape and Stride
auto global_range = is_load ? src_range : dst_range;
desc.global_addr = global_tensor->data;
desc.global_shape = ReverseArray(global_tensor->shape);
Array<PrimExpr> global_coords =
......@@ -217,16 +239,16 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto op = is_load ? tma_load() : tma_store();
Stmt tma_copy;
if ((*inner_box_dim) != instruction_dim) {
Var loop_var("i");
int loop_extent = (*inner_box_dim) / instruction_dim;
PrimExpr total_elements = 1;
for (auto e : desc.smem_box)
total_elements *= e;
PrimExpr shared_addr =
shared_tensor.access_ptr(is_load ? 2 : 1, DataType::Handle(), 1,
total_elements * loop_var, total_elements);
if ((*inner_box_dim) != instruction_dim) {
Var loop_var("i");
int loop_extent = (*inner_box_dim) / instruction_dim;
PrimExpr shared_addr = shared_tensor.access_ptr(
is_load ? 2 : 1, DataType::Handle(), 1,
offset + total_elements * loop_var, total_elements);
args.push_back(shared_addr);
global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
for (auto coord : global_coords)
......@@ -234,13 +256,14 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
Evaluate(Call(DataType::Handle(), op, args)));
} else {
PrimExpr shared_addr = shared_tensor.access_ptr(is_load ? 2 : 1);
PrimExpr shared_addr = shared_tensor.access_ptr(
is_load ? 2 : 1, DataType::Handle(), 1, offset, total_elements);
args.push_back(shared_addr);
for (auto coord : global_coords)
args.push_back(coord);
tma_copy = Evaluate(Call(DataType::Handle(), op, args));
}
tma_copy = IfThenElse(EQ(T.thread_var, 0), tma_copy);
tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
return tma_copy;
}
......@@ -393,7 +416,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
args.push_back(offset);
Stmt tma_copy =
IfThenElse(EQ(T.thread_var, 0),
IfThenElse(EQ(T.thread_var, T.thread_bounds->min),
Evaluate(Call(DataType::Handle(), tma_load_im2col(), args)));
return tma_copy;
}
......
......@@ -34,9 +34,12 @@ static std::vector<int> toPrimeFactors(int x) {
}
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
A = vmap[GetVarFromAccessPtr(args[0])];
B = vmap[GetVarFromAccessPtr(args[1])];
C = vmap[GetVarFromAccessPtr(args[2])];
Aptr = args[0];
Bptr = args[1];
Cptr = args[2];
A = vmap[GetVarFromAccessPtr(Aptr)];
B = vmap[GetVarFromAccessPtr(Bptr)];
C = vmap[GetVarFromAccessPtr(Cptr)];
trans_A = args[3].as<Bool>().value();
trans_B = args[4].as<Bool>().value();
M = args[5].as<IntImm>().value()->value;
......@@ -149,9 +152,9 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Array<PrimExpr> new_args;
new_args.push_back(StringImm(ss.str()));
new_args.push_back(A_buffer.access_ptr(1));
new_args.push_back(B_buffer.access_ptr(1));
new_args.push_back(C_buffer.access_ptr(3));
new_args.push_back(Aptr);
new_args.push_back(Bptr);
new_args.push_back(Cptr);
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
return Evaluate(new_call);
}
......@@ -170,9 +173,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[0]),
*as_const_int(A->shape[1]), true,
trans_A ? 1 : 2));
int dim_A = A->shape.size();
results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
*as_const_int(A->shape[dim_A - 1]),
true, trans_A ? 1 : 2));
} else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false);
results.Set(A, makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n));
......@@ -181,9 +185,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
}
ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[0]),
*as_const_int(B->shape[1]), false,
trans_B ? 2 : 1));
int dim_B = B->shape.size();
results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
*as_const_int(B->shape[dim_B - 1]),
false, trans_B ? 2 : 1));
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
const int warp_size = 32;
auto [warp_m, warp_n] =
......@@ -193,8 +198,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
const int64_t mat_stride = *as_const_int(A->shape[0]);
const int64_t mat_continuous = *as_const_int(A->shape[1]);
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A,
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), trans_A ? 1 : 2));
......@@ -206,8 +212,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(0);
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
const int64_t mat_stride = *as_const_int(B->shape[0]);
const int64_t mat_continuous = *as_const_int(B->shape[1]);
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
results.Set(B,
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B ? 2 : 1));
......@@ -230,8 +237,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
const int64_t mat_stride = *as_const_int(A->shape[0]);
const int64_t mat_continuous = *as_const_int(A->shape[1]);
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
const int64_t continuity =
trans_A ? mat_continuous / (warp_m / 4) : mat_continuous;
results.Set(A, makeGemmABLayout(mat_stride, mat_continuous, continuity,
......@@ -242,8 +250,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
A->dtype.bits()));
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
const int64_t mat_stride = *as_const_int(B->shape[0]);
const int64_t mat_continuous = *as_const_int(B->shape[1]);
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n;
results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, continuity,
......@@ -262,16 +271,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
// Make Linear Memory Access Layout
// auto shared_layout =
// makeGemmLayoutLinear(*as_const_int(A->shape[0]),
// *as_const_int(A->shape[1]));
// Make Swizzle or Pad Layout
auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(A->shape[0]),
*as_const_int(A->shape[1]),
A->dtype.bits(), kPack);
int dim_A = A->shape.size();
auto shared_layout = makeGemmABLayoutCDNA(
*as_const_int(A->shape[dim_A - 2]),
*as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
results.Set(A, shared_layout);
} else if (A.scope() == "local.fragment") {
results.Set(A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
......@@ -280,15 +283,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
ICHECK(0);
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
// Make Linear Memory Access Layout
// auto shared_layout =
// makeGemmLayoutLinear(*as_const_int(B->shape[0]),
// *as_const_int(B->shape[1]));
// Make Swizzle or Pad Layout
auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(B->shape[0]),
*as_const_int(B->shape[1]),
B->dtype.bits(), kPack);
int dim_B = B->shape.size();
auto shared_layout = makeGemmABLayoutCDNA(
*as_const_int(B->shape[dim_B - 2]),
*as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
results.Set(B, shared_layout);
} else if (B.scope() == "local.fragment") {
......
......@@ -33,6 +33,8 @@ private:
Array<PrimExpr> call_args;
tir::Buffer A, B, C;
// pointer to the A, B, C
PrimExpr Aptr, Bptr, Cptr;
bool trans_A, trans_B;
int M, N, K;
bool clear_accum = false;
......
......@@ -826,7 +826,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(tl::sync_thread_partial())) {
print_extern_call_stmt("tl::syncthreads_partial");
} else if (op->op.same_as(tl::tma_load())) {
print_extern_call_stmt("tl::tma_load");
this->PrintIndent();
ICHECK_GE(op->args.size(), 2);
this->stream << "tl::tma_load(";
auto desc = op->args[0];
this->stream << this->PrintExpr(desc) << ", ";
if (const IntImmNode *imm = op->args[1].as<IntImmNode>()) {
this->stream << "_mbarrier[" << imm->value << "], ";
} else {
this->stream << this->PrintExpr(op->args[1]) << ", ";
}
for (size_t i = 2; i < op->args.size(); i++) {
if (i > 2)
this->stream << ", ";
this->stream << this->PrintExpr(op->args[i]);
}
this->stream << ");\n";
} else if (op->op.same_as(tl::tma_load_im2col())) {
print_extern_call_stmt("tl::tma_load_im2col");
} else if (op->op.same_as(tl::tma_store())) {
......
......@@ -783,19 +783,6 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
print_extern_call_stmt("tl::mbarrier_wait");
} else if (op->op.same_as(tl::sync_thread_partial())) {
print_extern_call_stmt("tl::syncthreads_partial");
} else if (op->op.same_as(tl::tma_load())) {
print_extern_call_stmt("tl::tma_load");
} else if (op->op.same_as(tl::tma_load_im2col())) {
print_extern_call_stmt("tl::tma_load_im2col");
} else if (op->op.same_as(tl::tma_store())) {
print_extern_call_stmt("tl::tma_store");
} else if (op->op.same_as(tl::ptx_ldmatirx())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::ptx_stmatirx())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
......@@ -803,15 +790,6 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::fence_proxy_async())) {
print_extern_call_stmt("tl::fence_proxy_async");
} else if (op->op.same_as(tl::set_max_nreg())) {
this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
int is_inc = Downcast<IntImm>(op->args[1])->value;
std::string func_name =
is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
} else if (op->op.same_as(tl::wait_wgmma())) {
this->PrintIndent();
int num_mma = Downcast<IntImm>(op->args[0])->value;
......
......@@ -37,9 +37,32 @@ static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout) {
} else {
new_var = Var(buffer->data->name_hint, new_type);
}
return Buffer(new_var, buffer->dtype, layout->OutputShape(), {},
buffer->elem_offset, buffer->name, buffer->data_alignment,
buffer->offset_factor, buffer->buffer_type);
Array<PrimExpr> layout_shape = layout->OutputShape();
Array<PrimExpr> output_shape = layout_shape;
if (ptr_type->storage_scope == "shared" ||
ptr_type->storage_scope == "shared.dyn") {
int replicate_extent = 1;
Array<PrimExpr> buffer_shape = buffer->shape;
int buffer_extent = 1;
int layout_extent = 1;
for (size_t i = 0; i < buffer_shape.size(); i++) {
auto shape = buffer_shape[i].as<IntImmNode>();
buffer_extent *= shape->value;
}
for (size_t i = 0; i < layout_shape.size(); i++) {
auto shape = layout_shape[i].as<IntImmNode>();
layout_extent *= shape->value;
}
replicate_extent = buffer_extent / layout_extent;
if (replicate_extent > 1) {
output_shape.insert(output_shape.begin(), replicate_extent);
}
}
return Buffer(new_var, buffer->dtype, output_shape, {}, buffer->elem_offset,
buffer->name, buffer->data_alignment, buffer->offset_factor,
buffer->buffer_type);
}
class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
......
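
To make the shared-memory replicate-extent computation in `makeBufferWithLayout` above concrete, here is a small plain-Python illustration (the shapes are hypothetical and not taken from this diff):

```python
from math import prod

# Illustrative only: mirror the shared-memory replicate computation above.
buffer_shape = (2, 128, 32)   # hypothetical multi-stage shared buffer
layout_shape = (128, 32)      # hypothetical layout output shape
replicate_extent = prod(buffer_shape) // prod(layout_shape)   # -> 2
# When the buffer holds more elements than one layout instance, the quotient
# is prepended as an extra leading dimension.
output_shape = ((replicate_extent,) if replicate_extent > 1 else ()) + layout_shape
assert output_shape == (2, 128, 32)
```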
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tma_barrier_rewriter.cc
* \brief Rewrite TMA barriers for cuda GPU (sm90+)
*/
#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
namespace tvm {
namespace tl {
using namespace tir;
using namespace tir::transform;
using arith::IRMutatorWithAnalyzer;
class TmaTraitsCollector : public StmtExprVisitor {
public:
TmaTraitsCollector() { Initialize(); }
void Initialize() {
bulk_copy_bytes = 0;
loop_extents = 1;
}
void Collect(Stmt stmt) { VisitStmt(stmt); }
PrimExpr BulkCopyBytes() { return bulk_copy_bytes; }
private:
void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
int type_bytes = access_ptr->args[0]->dtype.bytes();
bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
}
StmtExprVisitor::VisitExpr_(call);
}
void VisitStmt_(const ForNode *op) final {
PrimExpr old_loop_evtents = loop_extents;
loop_extents *= op->extent;
StmtExprVisitor::VisitStmt_(op);
loop_extents = old_loop_evtents;
}
PrimExpr bulk_copy_bytes = 0;
PrimExpr loop_extents = 1;
};
class TmaExpectTxRewriter : public IRMutatorWithAnalyzer {
public:
static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) {
TmaExpectTxRewriter rewriter(analyzer);
f.CopyOnWrite()->body = rewriter(f->body);
return f;
}
private:
bool inside_tma_block_{false};
bool visited_tma_load_{false};
IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
IterVarType::kDataPar);
PrimExpr GetBarrierId(const PrimExpr &ko) {
// FloorMod(ko, 1)
return FloorMod(ko, IntImm(DataType::Int(32), 1));
}
PrimExpr GetBarrierParity(const PrimExpr &ko) {
// FloorDiv(ko, 1) % 2
return FloorMod(FloorDiv(ko, IntImm(DataType::Int(32), 1)),
IntImm(DataType::Int(32), 2));
}
PrimExpr makeGetBarrier(PrimExpr barrier_id) {
return Call(DataType::Handle(), get_mbarrier(), {barrier_id});
}
Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
auto call = Call(DataType::Handle(), mbarrier_expect_tx(),
{makeGetBarrier(barrier_id), bytes});
return Evaluate(call);
}
TmaExpectTxRewriter(arith::Analyzer *analyzer)
: IRMutatorWithAnalyzer(analyzer) {}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_var_ = iv;
}
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
Stmt VisitStmt_(const IfThenElseNode *op) {
// Check if this is the TMA block
const EQNode *eq = op->condition.as<EQNode>();
if (eq != nullptr) {
Stmt ret = IRMutatorWithAnalyzer::VisitStmt_(op);
if (visited_tma_load_) {
auto then_case = op->then_case;
TmaTraitsCollector collector;
collector.Collect(then_case);
Array<Stmt> stmts;
if (!is_zero(collector.BulkCopyBytes())) {
auto expect_tx = makeExpectTX(0, collector.BulkCopyBytes());
stmts.push_back(expect_tx);
}
stmts.push_back(then_case);
if (stmts.size() == 1) {
return IfThenElse(op->condition, stmts[0], op->else_case);
} else {
auto seq_stmt = SeqStmt(stmts);
return IfThenElse(op->condition, seq_stmt, op->else_case);
}
}
visited_tma_load_ = false;
return ret;
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load())) {
visited_tma_load_ = true;
Array<PrimExpr> new_args = op->args;
new_args.Set(1, Call(DataType::Handle(), get_mbarrier(),
{IntImm(DataType::Int(32), 0)}));
return Call(op->dtype, op->op, new_args);
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
};
class TmaBarrierCollector : public StmtExprVisitor {
public:
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id() {
return tma_op_to_barrier_id_;
}
private:
void VisitStmt_(const EvaluateNode *op) final {
if (const auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(tma_load())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
} else if (call->op.same_as(mbarrier_expect_tx())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
} else if (call->op.same_as(builtin::ptx_arrive_barrier())) {
PrimExpr barrier_id = call->args[0];
for (auto tma_call : pending_tma_ops_) {
tma_op_to_barrier_id_.Set(tma_call, barrier_id);
}
pending_tma_ops_.clear();
}
}
StmtExprVisitor::VisitStmt_(op);
}
std::vector<Call> pending_tma_ops_;
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
};
// we trust mbarrier_wait_parity to be correct
class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
public:
TmaBarrierRewriter(arith::Analyzer *analyzer,
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id)
: IRMutatorWithAnalyzer(analyzer),
tma_op_to_barrier_id_(tma_op_to_barrier_id) {}
static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) {
f = TmaExpectTxRewriter::Rewrite(f, analyzer);
TmaBarrierCollector collector;
collector(f->body);
TmaBarrierRewriter rewriter(analyzer, collector.tma_op_to_barrier_id());
f.CopyOnWrite()->body = rewriter(f->body);
return f;
}
private:
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load())) {
// check this must be in the tma_op_to_barrier_id_
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
<< "tma_load must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
auto new_args = op->args;
new_args.Set(1, barrier_id);
return Call(op->dtype, op->op, new_args);
} else if (op->op.same_as(mbarrier_expect_tx())) {
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
<< "mbarrier_expect_tx must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
auto new_args = op->args;
new_args.Set(0, barrier_id);
return Call(op->dtype, op->op, new_args);
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
};
tvm::transform::Pass InjectTmaBarrier() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
arith::Analyzer analyzer;
return TmaBarrierRewriter::Rewrite(f, &analyzer);
};
return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {});
}
TVM_REGISTER_GLOBAL("tl.transform.InjectTmaBarrier")
.set_body_typed(InjectTmaBarrier);
} // namespace tl
} // namespace tvm
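
The pass registered above is exposed to Python as `tilelang.transform.InjectTmaBarrier`; a minimal usage sketch, assuming `mod` is an IRModule produced by the earlier lowering stages:

```python
import tilelang

# Apply the new TMA barrier injection pass to an existing IRModule `mod`
# (assumed here; in practice this runs inside OptimizeForTarget).
mod = tilelang.transform.InjectTmaBarrier()(mod)
```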
......@@ -22,6 +22,7 @@
* \brief Warp specialized Pipeline for cuda GPU (sm90+)
*/
#include "arith/ir_visitor_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
......@@ -35,6 +36,7 @@ namespace tvm {
namespace tl {
using namespace tir;
using arith::IRVisitorWithAnalyzer;
enum class Role { kConsumer, kProducer, kBoth };
......@@ -209,12 +211,6 @@ static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
return Call(DataType::Handle(), get_mbarrier(), {barrier_id});
}
static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
auto call = Call(DataType::Handle(), mbarrier_expect_tx(),
{makeGetBarrier(barrier_id), bytes});
return Evaluate(call);
}
static Stmt makeArriveBarrier(PrimExpr barrier_id) {
auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(),
{makeGetBarrier(barrier_id)});
......@@ -233,95 +229,17 @@ static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
return Evaluate(call);
}
// static bool isGemm(Stmt stmt) {
// bool is_gemm = false;
// if (stmt.as<EvaluateNode>()) {
// auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
// if (call && call->op.same_as(Op::Get("tir.call_extern"))) {
// if (call->args[0].as<StringImmNode>()) {
// std::string name = Downcast<StringImm>(call->args[0])->value;
// if (name.find("gemm") != std::string::npos) {
// is_gemm = true;
// }
// }
// }
// }
// return is_gemm;
// }
class TMAExpectTxRewriter : public StmtExprMutator {
public:
TMAExpectTxRewriter(Stmt expect_tx) : expect_tx_(expect_tx) {}
static Stmt Rewrite(Stmt stmt, Stmt expect_tx) {
TMAExpectTxRewriter rewriter(expect_tx);
return rewriter(stmt);
}
private:
Stmt VisitStmt_(const ForNode *op) final {
insert_in_evaluate_ = false;
StmtExprMutator::VisitStmt_(op);
insert_in_evaluate_ = true;
if (contain_tma_load_) {
Array<Stmt> new_seq = {expect_tx_, GetRef<For>(op)};
contain_tma_load_ = false;
return SeqStmt(std::move(new_seq));
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
contain_tma_load_ = true;
if (insert_in_evaluate_) {
Array<Stmt> new_seq = {expect_tx_, GetRef<Evaluate>(op)};
return SeqStmt(std::move(new_seq));
}
}
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt expect_tx_;
bool contain_tma_load_;
bool insert_in_evaluate_ = true;
};
class ProducerTraitsCollector : public StmtExprVisitor {
public:
ProducerTraitsCollector() { Clear(); }
void Clear() {
bulk_copy_bytes = 0;
loop_extents = 1;
has_simt_copy = false;
}
void Clear() { has_simt_copy = false; }
void Collect(Stmt stmt) { VisitStmt(stmt); }
bool HasSimtCopy() { return has_simt_copy; }
PrimExpr BulkCopyBytes() { return bulk_copy_bytes; }
private:
void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
int type_bytes = access_ptr->args[0]->dtype.bytes();
bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
}
StmtExprVisitor::VisitExpr_(call);
}
void VisitStmt_(const ForNode *op) final {
PrimExpr old_loop_evtents = loop_extents;
loop_extents *= op->extent;
StmtExprVisitor::VisitStmt_(op);
loop_extents = old_loop_evtents;
}
void VisitStmt_(const IfThenElseNode *op) final {
bool old_in_if_cond = in_if_cond_;
in_if_cond_ = true;
......@@ -342,8 +260,6 @@ private:
}
bool has_simt_copy;
PrimExpr bulk_copy_bytes;
PrimExpr loop_extents;
bool in_if_cond_ = false;
};
......@@ -646,14 +562,7 @@ private:
auto stmt =
MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
collector.Collect(stmt);
if (!is_zero(collector.BulkCopyBytes())) {
auto expect_tx = IfThenElse(
EQ(thread_var_, 0),
makeExpectTX(release_barrier_id, collector.BulkCopyBytes()));
block_stmt.push_back(TMAExpectTxRewriter::Rewrite(stmt, expect_tx));
} else {
block_stmt.push_back(stmt);
}
if (collector.HasSimtCopy() > 0) {
block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id));
}
......@@ -1276,13 +1185,56 @@ private:
Array<IntImm> nreg_;
};
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
public:
static bool Detect(Stmt stmt, bool skip_thread_partition = false) {
WarpSpecializedDetector detector;
detector.VisitStmt(stmt);
return detector.has_tma_op_ && detector.has_mbarrier_op_;
}
WarpSpecializedDetector() {
has_tma_op_ = false;
has_mbarrier_op_ = false;
}
private:
void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) ||
call->op.same_as(builtin::ptx_arrive_barrier()) ||
call->op.same_as(builtin::ptx_cp_async_barrier())) {
has_mbarrier_op_ = true;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
op->op.same_as(set_max_nreg())) {
has_tma_op_ = true;
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
bool has_tma_op_{false};
bool has_mbarrier_op_{false};
};
using namespace tir::transform;
tvm::transform::Pass WarpSpecialized() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
bool disable_warp_specialized =
ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value();
bool warp_specialized = WarpSpecializedDetector::Detect(f->body);
if (!warp_specialized) {
return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized);
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}
......
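
For completeness, a sketch of turning the automatic rewriter off at the PassContext level; this assumes the registered config key matches the `tl.disable_warp_specialized` option that the example scripts pass via `pass_configs`, and that `mod` is an existing IRModule:

```python
from tvm import transform
import tilelang

# Assumption: "tl.disable_warp_specialized" is the registered pass-config key
# behind kDisableWarpSpecialized used above.
with transform.PassContext(config={"tl.disable_warp_specialized": True}):
    mod = tilelang.transform.WarpSpecialized()(mod)
```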
......@@ -62,6 +62,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    mod = tilelang.transform.WarpSpecialized()(mod)
    # if tma is not enabled, we can also do pipeline planning
    # to get better performance with async copy
    mod = tilelang.transform.InjectTmaBarrier()(mod)
    mod = tilelang.transform.PipelinePlanning()(mod)
    mod = tilelang.transform.InjectSoftwarePipeline()(mod)
    # warp_specialized pass will pack the if stmt into the block
......
......@@ -3,20 +3,40 @@
from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier
from tvm import tir
from typing import Union
from tvm.tir import PrimExpr, Var
from typing import Union, Any
from tvm.tir import PrimExpr, Var, Call


def create_list_of_mbarrier(*args):
    """Create a list of memory barrier operations.

    Args:
        *args: Variable arguments passed to the memory barrier creation operation

    Returns:
        tir.Call: A handle to the created list of memory barriers
def create_list_of_mbarrier(*args: Any) -> Call:
    """
    Create a list of memory barrier handles.

    Parameters
    ----------
    *args : list or Any
        Either a single list of arguments, or multiple arguments directly.

    Returns
    -------
    tvm.tir.Call
        Handle to the created list of memory barriers.

    Raises
    ------
    TypeError
        If the input is not a list or variadic arguments.

    Examples
    --------
    >>> create_list_of_mbarrier([128, 128])
    >>> create_list_of_mbarrier(128, 128)
    """
    if len(args) == 1 and isinstance(args[0], list):
        return tir.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), *args[0])
    elif len(args) >= 1:
        return tir.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), *args)
    else:
        raise TypeError("create_list_of_mbarrier expects a list or one or more arguments.")
def get_mbarrier(*args):
......@@ -115,7 +135,7 @@ def no_set_max_nreg(*args):
    return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"), *args)


def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr], parity: Union[int, Var]):
def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]):
    """Wait for memory barrier parity condition.

    Args:
......@@ -154,20 +174,28 @@ def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr], parity: Union[int, Var]
    Returns:
        tir.Call: A handle to the barrier wait operation
    """
    if isinstance(mbarrier, int):
    if isinstance(mbarrier, tir.Call):
        mbarrier = mbarrier
    elif isinstance(mbarrier, (tir.PrimExpr, int)):
        mbarrier = get_mbarrier(mbarrier)
    else:
        raise TypeError("mbarrier must be an integer or a tir.Call")
    return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity)


def mbarrier_arrive(mbarrier: Union[int, PrimExpr]):
def mbarrier_arrive(mbarrier: Union[int, PrimExpr, tir.Call]):
    """Arrive at memory barrier.

    Args:
        mbarrier: Optional[int, PrimExpr]
            The memory barrier to arrive at
    """
    if isinstance(mbarrier, int):
    if isinstance(mbarrier, tir.Call):
        mbarrier = mbarrier
    elif isinstance(mbarrier, (tir.PrimExpr, int)):
        mbarrier = get_mbarrier(mbarrier)
    else:
        raise TypeError("mbarrier must be an integer or a tir.Call")
    return ptx_arrive_barrier(mbarrier)
......
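
A schematic fragment (valid only inside a `T.prim_func` body) showing both call styles the updated functions accept; it assumes `get_mbarrier` is re-exported under `T` alongside the other barrier intrinsics:

```python
import tilelang.language as T

# Schematic fragment: `ko` is a kernel loop variable, not defined here.
T.mbarrier_wait_parity(0, ko & 1)                  # index form, wrapped via get_mbarrier
T.mbarrier_wait_parity(T.get_mbarrier(0), ko & 1)  # handle form, passed through unchanged
T.mbarrier_arrive(1)
```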
......@@ -3,7 +3,7 @@
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from typing import Union
from typing import Union, List
def gemm(
......@@ -56,14 +56,64 @@ def gemm(
    A = legalize_arguments(A)
    B = legalize_arguments(B)
    C = legalize_arguments(C)
    M = C.shape[0]
    N = C.shape[1]
    K = A.shape[0] if transpose_A else A.shape[1]
    K_B = B.shape[1] if transpose_B else B.shape[0]
    assert K == K_B, "gemm K shape check failed"
    Aptr = A.access_ptr("r")
    Bptr = B.access_ptr("r")
    Cptr = C.access_ptr("rw")

    def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
        if isinstance(object, tir.Buffer):
            return object.shape
        elif isinstance(object, tir.BufferRegion):
            region = object.region
            shape = []
            for r in region:
                shape.append(r.extent)
            return shape
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    A_shape = retrieve_shape(A)
    B_shape = retrieve_shape(B)
    C_shape = retrieve_shape(C)

    assert len(C_shape) == 2, "current only support C as a 2D tensor"
    assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor"
    assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor"
    if len(A_shape) > 2:
        for i in range(len(A_shape) - 2):
            assert A_shape[i] == 1, \
                "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
    if len(B_shape) > 2:
        for i in range(len(B_shape) - 2):
            assert B_shape[i] == 1, \
                "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"

    M, N = C_shape
    K = A_shape[-2] if transpose_A else A_shape[-1]
    K_B = B_shape[-1] if transpose_B else B_shape[-2]
    assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"

    def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
                     access_type: str = "r") -> tir.PrimExpr:
        if isinstance(object, tir.Buffer):
            return object.access_ptr(access_type)
        elif isinstance(object, tir.BufferRegion):
            buffer, region = object.buffer, object.region
            indices = []
            for r in region:
                indices.append(r.min)
            strides = []
            stride = 1
            for s in reversed(buffer.shape):
                strides.insert(0, stride)
                stride *= s
            offset = 0
            for i in range(len(indices)):
                offset += indices[i] * strides[i]
            return buffer.access_ptr(access_mask=access_type, offset=offset)
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    Aptr = retrieve_ptr(A, "r")
    Bptr = retrieve_ptr(B, "r")
    Cptr = retrieve_ptr(C, "rw")
    return tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.gemm"),
......
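
This `BufferRegion` handling is what lets the multi-stage example earlier in this commit pass sliced shared buffers straight to `T.gemm`, e.g. (fragment from a `T.prim_func` body):

```python
# From the multi-stage example above: A_shared/B_shared carry a leading stage
# dimension, and T.gemm now accepts the sliced regions directly.
T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local)
```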
......@@ -189,6 +189,17 @@ def WarpSpecialized():
    return _ffi_api.WarpSpecialized() # type: ignore


def InjectTmaBarrier():
    """InjectTmaBarrier

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.InjectTmaBarrier() # type: ignore


def InjectFenceProxy():
    """InjectFenceProxy
......