"...git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "5b7a18c50601583b28e54905b56e1ac7342b22c3"
Commit fce16b00 authored by Lei Wang, committed by LeiWang1999

[Refactor] Separate warp specialize rewriter and tma barrier injector pass (#447)

* [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.

* [Refactor] Rename operations for consistency in lower_hopper_intrin and related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.

* [Refactor] Rename operations to snake_case for consistency

* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.

* [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.
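
A condensed sketch of how these pieces compose, based on the `example_warp_specialize_gemm*` scripts added in this PR (shapes, barrier indices, and the two 128-thread warp groups are taken from those scripts; this is illustrative, not an additional file in the commit):

import tilelang
import tilelang.language as T

@T.prim_func
def main(A: T.Tensor((128, 64), "float16"), B: T.Tensor((64, 128), "float16"),
         C: T.Tensor((128, 128), "float16")):
    with T.Kernel(1, 1, threads=256) as (bx, by):
        A_shared = T.alloc_shared((128, 32), "float16", "shared")
        B_shared = T.alloc_shared((32, 128), "float16", "shared")
        C_local = T.alloc_fragment((128, 128), "float")
        # two mbarriers, matching the two 128-thread warp groups
        T.create_list_of_mbarrier(128, 128)
        with T.ws(1):  # producer warp group: global -> shared copies
            for ko in range(2):
                T.mbarrier_wait_parity(1, ko ^ 1)
                T.copy(A[by * 128, ko * 32], A_shared)
                T.copy(B[ko * 32, bx * 128], B_shared)
                T.mbarrier_arrive(0)
        with T.ws(0):  # consumer warp group: gemm + epilogue
            T.clear(C_local)
            for ko in range(2):
                T.mbarrier_wait_parity(0, ko)
                T.gemm(A_shared, B_shared, C_local)
                T.mbarrier_arrive(1)
            T.copy(C_local, C[by * 128, bx * 128])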

* [Refactor] Clean up code formatting and improve readability

* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.

* lint fix

* [Refactor] Update mbarrier functions for improved clarity and consistency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.

* Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` for improved code readability and separation of logical sections.
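
For reference, the explicit call forms now look like this (a minimal excerpt in the style of the examples below; the barrier indices and parity expression are illustrative):

import tilelang.language as T

ko = 0                             # iteration index (illustrative)
T.mbarrier_wait_parity(0, ko & 1)  # wait until barrier 0 reaches the expected phase
T.mbarrier_arrive(1)               # signal arrival on barrier 1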

* [Feature] Add examples for warp specialization and TMA barrier integration

* Introduced three new example scripts: `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, and `example_warp_specialize_mla.py` demonstrating matrix multiplication with warp specialization and TMA barriers.
* Implemented kernel functions with shared memory allocation and memory barrier synchronization for improved performance.
* Enhanced the TileLang API with new methods for compiling and testing kernels in Python using PyTorch.
* Updated the `phase.py` to include TMA barrier injection in the optimization process.
* Improved documentation and comments for better clarity on usage and functionality.

* [Feature] Add example for warp specialization in GEMM with TMA barriers

* Introduced a new example script `example_warp_specialize_gemm_stage2.py` demonstrating matrix multiplication using warp specialization and TMA barriers.
* Implemented a kernel function with shared memory allocation and memory barrier synchronization for enhanced performance.
* Included functionality to compile the kernel into a PyTorch-compatible function and validate its correctness against PyTorch's reference implementation.
* Enhanced documentation and comments for clarity on usage and functionality.

* lint fix

* [Feature] Implement WarpSpecializedDetector for TMA and MBarrier Detection

* Added the `WarpSpecializedDetector` class to identify the presence of TMA operations and memory barrier operations within a given TIR statement.
* Enhanced the `WarpSpecialized` pass to utilize the detector, allowing for conditional substitution based on the detection results.
* Improved code organization by including necessary headers and utilizing the `IRVisitorWithAnalyzer` for analysis.
* This addition aims to optimize warp specialization by ensuring that only relevant functions are transformed, enhancing performance and correctness.

* lint fix
parent e46653ac
import tilelang
import tilelang.language as T

tilelang.disable_cache()


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # create mbarrier for tma
            T.create_list_of_mbarrier(128, 128)

            # with T.ws(1):
            #     for ko in range(T.ceildiv(K, block_K)):
            #         T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
            #         T.copy(A[by * block_M, ko * block_K], A_shared)
            #         T.copy(B[ko * block_K, bx * block_N], B_shared)
            #         T.mbarrier_arrive(0)
            # with T.ws(0):
            #     T.clear(C_local)
            #     for ko in range(T.ceildiv(K, block_K)):
            #         T.mbarrier_wait_parity(0, ko & 1)
            #         T.gemm(A_shared, B_shared, C_local)
            #         T.mbarrier_arrive(1)
            #     T.copy(C_local, C[by * block_M, bx * block_N])

            with T.ws(0):
                T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                with T.ws(1):
                    T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                    T.mbarrier_arrive(0)
                with T.ws(0):
                    T.mbarrier_wait_parity(0, ko & 1)
                    T.gemm(A_shared, B_shared, C_local)
                    T.mbarrier_arrive(1)
            with T.ws(0):
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main


K = 64

# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(128, 128, K, 128, 128, 32)

# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
tilelang.disable_cache()
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    execution_backend="cython",
    pass_configs={
        "tl.disable_warp_specialized": True,
        # "tl.disable_tma_lower": True,
    })
tilelang.enable_cache()
print(jit_kernel.get_kernel_source())

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(128, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, 128, device="cuda", dtype=torch.float16)

# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)

# Reference multiplication using PyTorch
ref_c = a @ b

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# # 4. Retrieve and inspect the generated CUDA source (optional)
# # cuda_source = jit_kernel.get_kernel_source()
# # print("Generated CUDA kernel:\n", cuda_source)

# # 5. Profile latency with kernel
# profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
# latency = profiler.do_bench()
# print(f"Latency: {latency} ms")
import tilelang
import tilelang.language as T

tilelang.disable_cache()


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    num_stages = 2
    mbarrier_list = [128, 128] * num_stages

    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((num_stages, block_M, block_K), dtype, "shared")
            B_shared = T.alloc_shared((num_stages, block_K, block_N), dtype, "shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # create mbarrier for tma
            T.create_list_of_mbarrier(mbarrier_list)

            with T.ws(1):
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(ko % num_stages + num_stages,
                                           ((ko // num_stages) % num_stages) ^ 1)
                    T.copy(A[by * block_M:(by + 1) * block_M, ko * block_K:(ko + 1) * block_K],
                           A_shared[ko % num_stages, :, :])
                    T.copy(B[ko * block_K:(ko + 1) * block_K, bx * block_N:(bx + 1) * block_N],
                           B_shared[ko % num_stages, :, :])
                    T.mbarrier_arrive(ko % num_stages)

            with T.ws(0):
                T.clear(C_local)
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(ko % num_stages, (ko // num_stages) % num_stages)
                    T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :],
                           C_local)
                    T.mbarrier_arrive(ko % num_stages + num_stages)
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main


K = 64

# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(128, 128, K, 128, 128, 32)
print(func.script())

# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
tilelang.disable_cache()
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    execution_backend="cython",
    pass_configs={
        "tl.disable_warp_specialized": True,
        # "tl.disable_tma_lower": True,
    })
tilelang.enable_cache()
print(jit_kernel.get_kernel_source())

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(128, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, 128, device="cuda", dtype=torch.float16)

# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)

# Reference multiplication using PyTorch
ref_c = a @ b

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# # 4. Retrieve and inspect the generated CUDA source (optional)
# # cuda_source = jit_kernel.get_kernel_source()
# # print("Generated CUDA kernel:\n", cuda_source)

# # 5. Profile latency with kernel
# profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
# latency = profiler.do_bench()
# print(f"Latency: {latency} ms")
import tilelang
import tilelang.language as T

tilelang.disable_cache()


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # create mbarrier for tma
            T.create_list_of_mbarrier(128, 128)

            # with T.ws(1):
            #     for ko in range(T.ceildiv(K, block_K)):
            #         T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
            #         T.copy(A[by * block_M, ko * block_K], A_shared)
            #         T.copy(B[ko * block_K, bx * block_N], B_shared)
            #         T.mbarrier_arrive(0)
            # with T.ws(0):
            #     T.clear(C_local)
            #     for ko in range(T.ceildiv(K, block_K)):
            #         T.mbarrier_wait_parity(0, ko & 1)
            #         T.gemm(A_shared, B_shared, C_local)
            #         T.mbarrier_arrive(1)
            #     T.copy(C_local, C[by * block_M, bx * block_N])

            with T.ws(0):
                T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                with T.ws(1):
                    T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                    T.mbarrier_arrive(0)
                with T.ws(0):
                    T.mbarrier_wait_parity(0, ko & 1)
                    T.gemm(A_shared, B_shared, C_local)
                    T.mbarrier_arrive(1)
            with T.ws(0):
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main


K = 64

# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(128, 128, K, 128, 128, 32)

# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
tilelang.disable_cache()
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    execution_backend="cython",
    pass_configs={
        "tl.disable_warp_specialized": True,
        # "tl.disable_tma_lower": True,
    })
tilelang.enable_cache()
print(jit_kernel.get_kernel_source())

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(128, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, 128, device="cuda", dtype=torch.float16)

# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)

# Reference multiplication using PyTorch
ref_c = a @ b

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# # 4. Retrieve and inspect the generated CUDA source (optional)
# # cuda_source = jit_kernel.get_kernel_source()
# # print("Generated CUDA kernel:\n", cuda_source)

# # 5. Profile latency with kernel
# profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
# latency = profiler.do_bench()
# print(f"Latency: {latency} ms")
import tilelang
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # create mbarrier for tma
            T.create_list_of_mbarrier(128, 128)

            with T.ws(1):
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(1, ko ^ 1)
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                    T.mbarrier_arrive(0)

            with T.ws(0):
                T.clear(C_local)
                for ko in range(T.ceildiv(K, block_K)):
                    T.mbarrier_wait_parity(0, ko)
                    T.gemm(A_shared, B_shared, C_local)
                    T.mbarrier_arrive(1)
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(128, 128, 64, 128, 128, 32)

# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    execution_backend="cython",
    pass_configs={
        "tl.disable_warp_specialized": True,
        "tl.disable_tma_lower": True,
    })
print(jit_kernel.get_kernel_source())

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 128, device="cuda", dtype=torch.float16)

# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)

# Reference multiplication using PyTorch
ref_c = a @ b

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 5. Profile latency with kernel
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
@@ -317,8 +317,9 @@ Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size) {
   Var i = InputPlaceholder(0);
   Var j = InputPlaceholder(1);
   int vector_size = 128 / element_size;
-  ICHECK(stride % 8 == 0);
-  ICHECK(continuous % (vector_size * 8) == 0);
+  ICHECK(stride % 8 == 0) << "stride=" << stride;
+  ICHECK(continuous % (vector_size * 8) == 0)
+      << "continuous=" << continuous << ", vector_size=" << vector_size;
   PrimExpr ts = FloorDiv(i, 8);
   PrimExpr s = FloorMod(i, 8);
   PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 8);
...
@@ -104,6 +104,29 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
   }
   Buffer global_tensor = is_load ? src : dst;
   Buffer shared_tensor = is_load ? dst : src;
+  Array<Range> global_range = is_load ? src_range : dst_range;
+  Array<Range> shared_range = is_load ? dst_range : src_range;
+
+  Array<PrimExpr> indices;
+  for (auto r : shared_range)
+    indices.push_back(r->min);
+
+  std::vector<PrimExpr> strides;
+  PrimExpr stride = 1;
+  for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
+    auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
+    strides.insert(strides.begin(), stride);
+    stride *= s;
+  }
+  ICHECK(strides.size() == indices.size())
+      << "strides.size() != indices.size()" << strides.size() << " "
+      << indices.size();
+  PrimExpr offset = 0;
+  for (size_t i = 0; i < indices.size(); i++) {
+    offset += indices[i] * strides[i];
+  }
+
   Layout shared_layout;
   if (T.layout_map.count(shared_tensor)) {
     shared_layout = T.layout_map[shared_tensor];
@@ -129,7 +152,6 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
   desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);

   // Global Tensor Shape and Stride
-  auto global_range = is_load ? src_range : dst_range;
   desc.global_addr = global_tensor->data;
   desc.global_shape = ReverseArray(global_tensor->shape);
   Array<PrimExpr> global_coords =
@@ -217,16 +239,16 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
   auto op = is_load ? tma_load() : tma_store();
   Stmt tma_copy;
+  PrimExpr total_elements = 1;
+  for (auto e : desc.smem_box)
+    total_elements *= e;
   if ((*inner_box_dim) != instruction_dim) {
     Var loop_var("i");
     int loop_extent = (*inner_box_dim) / instruction_dim;
-    PrimExpr total_elements = 1;
-    for (auto e : desc.smem_box)
-      total_elements *= e;
-    PrimExpr shared_addr =
-        shared_tensor.access_ptr(is_load ? 2 : 1, DataType::Handle(), 1,
-                                 total_elements * loop_var, total_elements);
+    PrimExpr shared_addr = shared_tensor.access_ptr(
+        is_load ? 2 : 1, DataType::Handle(), 1,
+        offset + total_elements * loop_var, total_elements);
     args.push_back(shared_addr);
     global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
     for (auto coord : global_coords)
@@ -234,13 +256,14 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
     tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
                    Evaluate(Call(DataType::Handle(), op, args)));
   } else {
-    PrimExpr shared_addr = shared_tensor.access_ptr(is_load ? 2 : 1);
+    PrimExpr shared_addr = shared_tensor.access_ptr(
+        is_load ? 2 : 1, DataType::Handle(), 1, offset, total_elements);
     args.push_back(shared_addr);
     for (auto coord : global_coords)
       args.push_back(coord);
     tma_copy = Evaluate(Call(DataType::Handle(), op, args));
   }
-  tma_copy = IfThenElse(EQ(T.thread_var, 0), tma_copy);
+  tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
   return tma_copy;
 }
@@ -393,7 +416,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
   args.push_back(offset);

   Stmt tma_copy =
-      IfThenElse(EQ(T.thread_var, 0),
+      IfThenElse(EQ(T.thread_var, T.thread_bounds->min),
                  Evaluate(Call(DataType::Handle(), tma_load_im2col(), args)));
   return tma_copy;
 }
...
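
The block added at the top of `LowerBulkCopy` linearizes the starting indices of the shared-memory slice against row-major strides derived from the shared buffer's shape, so a TMA instruction can address one stage of a multi-stage shared buffer. A small illustrative check of the same arithmetic (plain Python, not part of the patch; the (2, 128, 32) shape follows the two-stage example above):

shape = (2, 128, 32)          # (num_stages, block_M, block_K)
strides = []
stride = 1
for s in reversed(shape):     # row-major strides: (4096, 32, 1)
    strides.insert(0, stride)
    stride *= s
indices = (1, 0, 0)           # slice [1, :, :] selects the second stage
offset = sum(i * s for i, s in zip(indices, strides))
assert offset == 4096         # element offset handed to access_ptr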
@@ -34,9 +34,12 @@ static std::vector<int> toPrimeFactors(int x) {
 }

 Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
-  A = vmap[GetVarFromAccessPtr(args[0])];
-  B = vmap[GetVarFromAccessPtr(args[1])];
-  C = vmap[GetVarFromAccessPtr(args[2])];
+  Aptr = args[0];
+  Bptr = args[1];
+  Cptr = args[2];
+  A = vmap[GetVarFromAccessPtr(Aptr)];
+  B = vmap[GetVarFromAccessPtr(Bptr)];
+  C = vmap[GetVarFromAccessPtr(Cptr)];
   trans_A = args[3].as<Bool>().value();
   trans_B = args[4].as<Bool>().value();
   M = args[5].as<IntImm>().value()->value;
@@ -149,9 +152,9 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   Array<PrimExpr> new_args;
   new_args.push_back(StringImm(ss.str()));
-  new_args.push_back(A_buffer.access_ptr(1));
-  new_args.push_back(B_buffer.access_ptr(1));
-  new_args.push_back(C_buffer.access_ptr(3));
+  new_args.push_back(Aptr);
+  new_args.push_back(Bptr);
+  new_args.push_back(Cptr);
   auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
   return Evaluate(new_call);
 }
@@ -170,9 +173,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
         makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
     results.Set(C, fragment);
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[0]),
-                                           *as_const_int(A->shape[1]), true,
-                                           trans_A ? 1 : 2));
+      int dim_A = A->shape.size();
+      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
+                                           *as_const_int(A->shape[dim_A - 1]),
+                                           true, trans_A ? 1 : 2));
     } else if (A.scope() == "local.fragment") {
       ICHECK(trans_A == false);
       results.Set(A, makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n));
@@ -181,9 +185,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     }
     ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
-    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[0]),
-                                         *as_const_int(B->shape[1]), false,
-                                         trans_B ? 2 : 1));
+    int dim_B = B->shape.size();
+    results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
+                                         *as_const_int(B->shape[dim_B - 1]),
+                                         false, trans_B ? 2 : 1));
   } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
     const int warp_size = 32;
     auto [warp_m, warp_n] =
@@ -193,8 +198,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     results.Set(C, fragment);
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      const int64_t mat_stride = *as_const_int(A->shape[0]);
-      const int64_t mat_continuous = *as_const_int(A->shape[1]);
+      int dim_A = A->shape.size();
+      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
+      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
       results.Set(A,
                   makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                    A->dtype.bits(), trans_A ? 1 : 2));
@@ -206,8 +212,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       ICHECK(0);
     }
     if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      const int64_t mat_stride = *as_const_int(B->shape[0]);
-      const int64_t mat_continuous = *as_const_int(B->shape[1]);
+      int dim_B = B->shape.size();
+      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
+      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
       results.Set(B,
                   makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                    B->dtype.bits(), trans_B ? 2 : 1));
@@ -230,8 +237,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
         : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
     results.Set(C, fragment);
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      const int64_t mat_stride = *as_const_int(A->shape[0]);
-      const int64_t mat_continuous = *as_const_int(A->shape[1]);
+      int dim_A = A->shape.size();
+      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
+      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
       const int64_t continuity =
           trans_A ? mat_continuous / (warp_m / 4) : mat_continuous;
       results.Set(A, makeGemmABLayout(mat_stride, mat_continuous, continuity,
@@ -242,8 +250,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
                                       A->dtype.bits()));
     }
     if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      const int64_t mat_stride = *as_const_int(B->shape[0]);
-      const int64_t mat_continuous = *as_const_int(B->shape[1]);
+      int dim_B = B->shape.size();
+      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
+      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
       const int64_t continuity =
           trans_B ? mat_continuous : mat_continuous / warp_n;
       results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, continuity,
@@ -262,16 +271,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     results.Set(C, fragment);
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      // Make Linear Memory Access Layout
-      // auto shared_layout =
-      //     makeGemmLayoutLinear(*as_const_int(A->shape[0]),
-      //     *as_const_int(A->shape[1]));
-      // Make Swizzle or Pad Layout
-      auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(A->shape[0]),
-                                                *as_const_int(A->shape[1]),
-                                                A->dtype.bits(), kPack);
+      int dim_A = A->shape.size();
+      auto shared_layout = makeGemmABLayoutCDNA(
+          *as_const_int(A->shape[dim_A - 2]),
+          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
       results.Set(A, shared_layout);
     } else if (A.scope() == "local.fragment") {
       results.Set(A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
@@ -280,15 +283,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       ICHECK(0);
     }
     if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      // Make Linear Memory Access Layout
-      // auto shared_layout =
-      //     makeGemmLayoutLinear(*as_const_int(B->shape[0]),
-      //     *as_const_int(B->shape[1]));
-      // Make Swizzle or Pad Layout
-      auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(B->shape[0]),
-                                                *as_const_int(B->shape[1]),
-                                                B->dtype.bits(), kPack);
+      int dim_B = B->shape.size();
+      auto shared_layout = makeGemmABLayoutCDNA(
+          *as_const_int(B->shape[dim_B - 2]),
+          *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
       results.Set(B, shared_layout);
     } else if (B.scope() == "local.fragment") {
...
@@ -33,6 +33,8 @@ private:
   Array<PrimExpr> call_args;
   tir::Buffer A, B, C;
+  // pointer to the A, B, C
+  PrimExpr Aptr, Bptr, Cptr;
   bool trans_A, trans_B;
   int M, N, K;
   bool clear_accum = false;
...
@@ -826,7 +826,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
   } else if (op->op.same_as(tl::sync_thread_partial())) {
     print_extern_call_stmt("tl::syncthreads_partial");
   } else if (op->op.same_as(tl::tma_load())) {
-    print_extern_call_stmt("tl::tma_load");
+    this->PrintIndent();
+    ICHECK_GE(op->args.size(), 2);
+    this->stream << "tl::tma_load(";
+    auto desc = op->args[0];
+    this->stream << this->PrintExpr(desc) << ", ";
+    if (const IntImmNode *imm = op->args[1].as<IntImmNode>()) {
+      this->stream << "_mbarrier[" << imm->value << "], ";
+    } else {
+      this->stream << this->PrintExpr(op->args[1]) << ", ";
+    }
+    for (size_t i = 2; i < op->args.size(); i++) {
+      if (i > 2)
+        this->stream << ", ";
+      this->stream << this->PrintExpr(op->args[i]);
+    }
+    this->stream << ");\n";
   } else if (op->op.same_as(tl::tma_load_im2col())) {
     print_extern_call_stmt("tl::tma_load_im2col");
   } else if (op->op.same_as(tl::tma_store())) {
...
@@ -783,19 +783,6 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
     print_extern_call_stmt("tl::mbarrier_wait");
   } else if (op->op.same_as(tl::sync_thread_partial())) {
     print_extern_call_stmt("tl::syncthreads_partial");
-  } else if (op->op.same_as(tl::tma_load())) {
-    print_extern_call_stmt("tl::tma_load");
-  } else if (op->op.same_as(tl::tma_load_im2col())) {
-    print_extern_call_stmt("tl::tma_load_im2col");
-  } else if (op->op.same_as(tl::tma_store())) {
-    print_extern_call_stmt("tl::tma_store");
-  } else if (op->op.same_as(tl::ptx_ldmatirx())) {
-    int trans = Downcast<IntImm>(op->args[0])->value;
-    int num = Downcast<IntImm>(op->args[1])->value;
-    std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
-    if (trans == 1)
-      func_name += "_trans";
-    print_extern_call_stmt(func_name, 2);
   } else if (op->op.same_as(tl::ptx_stmatirx())) {
     int trans = Downcast<IntImm>(op->args[0])->value;
     int num = Downcast<IntImm>(op->args[1])->value;
@@ -803,15 +790,6 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
     if (trans == 1)
       func_name += "_trans";
     print_extern_call_stmt(func_name, 2);
-  } else if (op->op.same_as(tl::fence_proxy_async())) {
-    print_extern_call_stmt("tl::fence_proxy_async");
-  } else if (op->op.same_as(tl::set_max_nreg())) {
-    this->PrintIndent();
-    int nreg = Downcast<IntImm>(op->args[0])->value;
-    int is_inc = Downcast<IntImm>(op->args[1])->value;
-    std::string func_name =
-        is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
-    this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
   } else if (op->op.same_as(tl::wait_wgmma())) {
     this->PrintIndent();
     int num_mma = Downcast<IntImm>(op->args[0])->value;
...
@@ -37,9 +37,32 @@ static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout) {
   } else {
     new_var = Var(buffer->data->name_hint, new_type);
   }
-  return Buffer(new_var, buffer->dtype, layout->OutputShape(), {},
-                buffer->elem_offset, buffer->name, buffer->data_alignment,
-                buffer->offset_factor, buffer->buffer_type);
+  Array<PrimExpr> layout_shape = layout->OutputShape();
+  Array<PrimExpr> output_shape = layout_shape;
+
+  if (ptr_type->storage_scope == "shared" ||
+      ptr_type->storage_scope == "shared.dyn") {
+    int replicate_extent = 1;
+    Array<PrimExpr> buffer_shape = buffer->shape;
+    int buffer_extent = 1;
+    int layout_extent = 1;
+    for (size_t i = 0; i < buffer_shape.size(); i++) {
+      auto shape = buffer_shape[i].as<IntImmNode>();
+      buffer_extent *= shape->value;
+    }
+    for (size_t i = 0; i < layout_shape.size(); i++) {
+      auto shape = layout_shape[i].as<IntImmNode>();
+      layout_extent *= shape->value;
+    }
+    replicate_extent = buffer_extent / layout_extent;
+    if (replicate_extent > 1) {
+      output_shape.insert(output_shape.begin(), replicate_extent);
+    }
+  }
+
+  return Buffer(new_var, buffer->dtype, output_shape, {}, buffer->elem_offset,
+                buffer->name, buffer->data_alignment, buffer->offset_factor,
+                buffer->buffer_type);
 }

 class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tma_barrier_rewriter.cc
* \brief Rewrite TMA barriers for cuda GPU (sm90+)
*/
#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
namespace tvm {
namespace tl {
using namespace tir;
using namespace tir::transform;
using arith::IRMutatorWithAnalyzer;
class TmaTraitsCollector : public StmtExprVisitor {
public:
TmaTraitsCollector() { Initialize(); }
void Initialize() {
bulk_copy_bytes = 0;
loop_extents = 1;
}
void Collect(Stmt stmt) { VisitStmt(stmt); }
PrimExpr BulkCopyBytes() { return bulk_copy_bytes; }
private:
void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
int type_bytes = access_ptr->args[0]->dtype.bytes();
bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
}
StmtExprVisitor::VisitExpr_(call);
}
void VisitStmt_(const ForNode *op) final {
PrimExpr old_loop_evtents = loop_extents;
loop_extents *= op->extent;
StmtExprVisitor::VisitStmt_(op);
loop_extents = old_loop_evtents;
}
PrimExpr bulk_copy_bytes = 0;
PrimExpr loop_extents = 1;
};
class TmaExpectTxRewriter : public IRMutatorWithAnalyzer {
public:
static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) {
TmaExpectTxRewriter rewriter(analyzer);
f.CopyOnWrite()->body = rewriter(f->body);
return f;
}
private:
bool inside_tma_block_{false};
bool visited_tma_load_{false};
IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
IterVarType::kDataPar);
PrimExpr GetBarrierId(const PrimExpr &ko) {
// FloorMod(ko, 1)
return FloorMod(ko, IntImm(DataType::Int(32), 1));
}
PrimExpr GetBarrierParity(const PrimExpr &ko) {
// FloorDiv(ko, 1) % 2
return FloorMod(FloorDiv(ko, IntImm(DataType::Int(32), 1)),
IntImm(DataType::Int(32), 2));
}
PrimExpr makeGetBarrier(PrimExpr barrier_id) {
return Call(DataType::Handle(), get_mbarrier(), {barrier_id});
}
Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
auto call = Call(DataType::Handle(), mbarrier_expect_tx(),
{makeGetBarrier(barrier_id), bytes});
return Evaluate(call);
}
TmaExpectTxRewriter(arith::Analyzer *analyzer)
: IRMutatorWithAnalyzer(analyzer) {}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_var_ = iv;
}
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
Stmt VisitStmt_(const IfThenElseNode *op) {
// Check if this is the TMA block
const EQNode *eq = op->condition.as<EQNode>();
if (eq != nullptr) {
Stmt ret = IRMutatorWithAnalyzer::VisitStmt_(op);
if (visited_tma_load_) {
auto then_case = op->then_case;
TmaTraitsCollector collector;
collector.Collect(then_case);
Array<Stmt> stmts;
if (!is_zero(collector.BulkCopyBytes())) {
auto expect_tx = makeExpectTX(0, collector.BulkCopyBytes());
stmts.push_back(expect_tx);
}
stmts.push_back(then_case);
if (stmts.size() == 1) {
return IfThenElse(op->condition, stmts[0], op->else_case);
} else {
auto seq_stmt = SeqStmt(stmts);
return IfThenElse(op->condition, seq_stmt, op->else_case);
}
}
visited_tma_load_ = false;
return ret;
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load())) {
visited_tma_load_ = true;
Array<PrimExpr> new_args = op->args;
new_args.Set(1, Call(DataType::Handle(), get_mbarrier(),
{IntImm(DataType::Int(32), 0)}));
return Call(op->dtype, op->op, new_args);
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
};
class TmaBarrierCollector : public StmtExprVisitor {
public:
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id() {
return tma_op_to_barrier_id_;
}
private:
void VisitStmt_(const EvaluateNode *op) final {
if (const auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(tma_load())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
} else if (call->op.same_as(mbarrier_expect_tx())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
} else if (call->op.same_as(builtin::ptx_arrive_barrier())) {
PrimExpr barrier_id = call->args[0];
for (auto tma_call : pending_tma_ops_) {
tma_op_to_barrier_id_.Set(tma_call, barrier_id);
}
pending_tma_ops_.clear();
}
}
StmtExprVisitor::VisitStmt_(op);
}
std::vector<Call> pending_tma_ops_;
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
};
// we trust mbarrier_wait_parity to be correct
class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
public:
TmaBarrierRewriter(arith::Analyzer *analyzer,
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id)
: IRMutatorWithAnalyzer(analyzer),
tma_op_to_barrier_id_(tma_op_to_barrier_id) {}
static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) {
f = TmaExpectTxRewriter::Rewrite(f, analyzer);
TmaBarrierCollector collector;
collector(f->body);
TmaBarrierRewriter rewriter(analyzer, collector.tma_op_to_barrier_id());
f.CopyOnWrite()->body = rewriter(f->body);
return f;
}
private:
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load())) {
// check this must be in the tma_op_to_barrier_id_
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
<< "tma_load must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
auto new_args = op->args;
new_args.Set(1, barrier_id);
return Call(op->dtype, op->op, new_args);
} else if (op->op.same_as(mbarrier_expect_tx())) {
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
<< "mbarrier_expect_tx must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
auto new_args = op->args;
new_args.Set(0, barrier_id);
return Call(op->dtype, op->op, new_args);
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
};
tvm::transform::Pass InjectTmaBarrier() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
arith::Analyzer analyzer;
return TmaBarrierRewriter::Rewrite(f, &analyzer);
};
return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {});
}
TVM_REGISTER_GLOBAL("tl.transform.InjectTmaBarrier")
.set_body_typed(InjectTmaBarrier);
} // namespace tl
} // namespace tvm
@@ -22,6 +22,7 @@
  * \brief Warp specialized Pipeline for cuda GPU (sm90+)
  */

+#include "arith/ir_visitor_with_analyzer.h"
 #include "tir/analysis/var_use_def_analysis.h"
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/builtin.h>
@@ -35,6 +36,7 @@ namespace tvm {
 namespace tl {

 using namespace tir;
+using arith::IRVisitorWithAnalyzer;

 enum class Role { kConsumer, kProducer, kBoth };

@@ -209,12 +211,6 @@ static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
   return Call(DataType::Handle(), get_mbarrier(), {barrier_id});
 }

-static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
-  auto call = Call(DataType::Handle(), mbarrier_expect_tx(),
-                   {makeGetBarrier(barrier_id), bytes});
-  return Evaluate(call);
-}
-
 static Stmt makeArriveBarrier(PrimExpr barrier_id) {
   auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(),
                    {makeGetBarrier(barrier_id)});
@@ -233,95 +229,17 @@ static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
   return Evaluate(call);
 }

-// static bool isGemm(Stmt stmt) {
-//   bool is_gemm = false;
-//   if (stmt.as<EvaluateNode>()) {
-//     auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
-//     if (call && call->op.same_as(Op::Get("tir.call_extern"))) {
-//       if (call->args[0].as<StringImmNode>()) {
-//         std::string name = Downcast<StringImm>(call->args[0])->value;
-//         if (name.find("gemm") != std::string::npos) {
-//           is_gemm = true;
-//         }
-//       }
-//     }
-//   }
-//   return is_gemm;
-// }
-
-class TMAExpectTxRewriter : public StmtExprMutator {
-public:
-  TMAExpectTxRewriter(Stmt expect_tx) : expect_tx_(expect_tx) {}
-  static Stmt Rewrite(Stmt stmt, Stmt expect_tx) {
-    TMAExpectTxRewriter rewriter(expect_tx);
-    return rewriter(stmt);
-  }
-
-private:
-  Stmt VisitStmt_(const ForNode *op) final {
-    insert_in_evaluate_ = false;
-    StmtExprMutator::VisitStmt_(op);
-    insert_in_evaluate_ = true;
-    if (contain_tma_load_) {
-      Array<Stmt> new_seq = {expect_tx_, GetRef<For>(op)};
-      contain_tma_load_ = false;
-      return SeqStmt(std::move(new_seq));
-    }
-    return StmtExprMutator::VisitStmt_(op);
-  }
-
-  Stmt VisitStmt_(const EvaluateNode *op) final {
-    if (const CallNode *call = op->value.as<CallNode>()) {
-      if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
-        contain_tma_load_ = true;
-        if (insert_in_evaluate_) {
-          Array<Stmt> new_seq = {expect_tx_, GetRef<Evaluate>(op)};
-          return SeqStmt(std::move(new_seq));
-        }
-      }
-    }
-    return StmtExprMutator::VisitStmt_(op);
-  }
-
-  Stmt expect_tx_;
-  bool contain_tma_load_;
-  bool insert_in_evaluate_ = true;
-};
-
 class ProducerTraitsCollector : public StmtExprVisitor {
 public:
   ProducerTraitsCollector() { Clear(); }

-  void Clear() {
-    bulk_copy_bytes = 0;
-    loop_extents = 1;
-    has_simt_copy = false;
-  }
+  void Clear() { has_simt_copy = false; }

   void Collect(Stmt stmt) { VisitStmt(stmt); }

   bool HasSimtCopy() { return has_simt_copy; }

-  PrimExpr BulkCopyBytes() { return bulk_copy_bytes; }
-
 private:
-  void VisitExpr_(const CallNode *call) final {
-    if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
-      Call access_ptr = Downcast<Call>(call->args[2]);
-      ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
-      int type_bytes = access_ptr->args[0]->dtype.bytes();
-      bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
-    }
-    StmtExprVisitor::VisitExpr_(call);
-  }
-
-  void VisitStmt_(const ForNode *op) final {
-    PrimExpr old_loop_evtents = loop_extents;
-    loop_extents *= op->extent;
-    StmtExprVisitor::VisitStmt_(op);
-    loop_extents = old_loop_evtents;
-  }
-
   void VisitStmt_(const IfThenElseNode *op) final {
     bool old_in_if_cond = in_if_cond_;
     in_if_cond_ = true;
@@ -342,8 +260,6 @@ private:
   }

   bool has_simt_copy;
-  PrimExpr bulk_copy_bytes;
-  PrimExpr loop_extents;
   bool in_if_cond_ = false;
 };

@@ -646,14 +562,7 @@ private:
       auto stmt =
           MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
       collector.Collect(stmt);
-      if (!is_zero(collector.BulkCopyBytes())) {
-        auto expect_tx = IfThenElse(
-            EQ(thread_var_, 0),
-            makeExpectTX(release_barrier_id, collector.BulkCopyBytes()));
-        block_stmt.push_back(TMAExpectTxRewriter::Rewrite(stmt, expect_tx));
-      } else {
-        block_stmt.push_back(stmt);
-      }
+      block_stmt.push_back(stmt);
       if (collector.HasSimtCopy() > 0) {
         block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id));
       }
@@ -1276,13 +1185,56 @@ private:
   Array<IntImm> nreg_;
 };

+class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
+public:
+  static bool Detect(Stmt stmt, bool skip_thread_partition = false) {
+    WarpSpecializedDetector detector;
+    detector.VisitStmt(stmt);
+    return detector.has_tma_op_ && detector.has_mbarrier_op_;
+  }
+
+  WarpSpecializedDetector() {
+    has_tma_op_ = false;
+    has_mbarrier_op_ = false;
+  }
+
+private:
+  void VisitStmt_(const EvaluateNode *op) final {
+    if (const CallNode *call = op->value.as<CallNode>()) {
+      if (call->op.same_as(create_list_of_mbarrier()) ||
+          call->op.same_as(mbarrier_wait_parity()) ||
+          call->op.same_as(builtin::ptx_arrive_barrier()) ||
+          call->op.same_as(builtin::ptx_cp_async_barrier())) {
+        has_mbarrier_op_ = true;
+      }
+    }
+    IRVisitorWithAnalyzer::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const CallNode *op) final {
+    if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col()) ||
+        op->op.same_as(set_max_nreg())) {
+      has_tma_op_ = true;
+    }
+    IRVisitorWithAnalyzer::VisitExpr_(op);
+  }
+
+  bool has_tma_op_{false};
+  bool has_mbarrier_op_{false};
+};
+
 using namespace tir::transform;

 tvm::transform::Pass WarpSpecialized() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     bool disable_warp_specialized =
         ctx->GetConfig<Bool>(kDisableWarpSpecialized, Bool(false)).value();
-    return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized);
+    bool warp_specialized = WarpSpecializedDetector::Detect(f->body);
+    if (!warp_specialized) {
+      return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized);
+    }
+    return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
 }
...
@@ -62,6 +62,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.WarpSpecialized()(mod)
     # if tma is not enabled, we can also do pipeline planning
     # to get better performance with async copy
+    mod = tilelang.transform.InjectTmaBarrier()(mod)
     mod = tilelang.transform.PipelinePlanning()(mod)
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
     # warp_specialized pass will pack the if stmt into the block
...
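
`InjectTmaBarrier` now runs in `OptimizeForTarget` between `WarpSpecialized` and `PipelinePlanning`. As in the example scripts above, the surrounding lowering behavior can still be toggled per compilation through `pass_configs` (a usage sketch mirroring those scripts; `func` is a kernel such as the `matmul(...)` result above):

import tilelang

jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    execution_backend="cython",
    pass_configs={
        "tl.disable_warp_specialized": True,  # disable the automatic rewriter (the examples above specialize manually with T.ws)
        # "tl.disable_tma_lower": True,       # optionally fall back from TMA lowering
    })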
@@ -3,20 +3,40 @@
 from tilelang import tvm as tvm
 from tilelang.language import ptx_arrive_barrier
 from tvm import tir
-from typing import Union
-from tvm.tir import PrimExpr, Var
+from typing import Union, Any
+from tvm.tir import PrimExpr, Var, Call


-def create_list_of_mbarrier(*args):
-    """Create a list of memory barrier operations.
-
-    Args:
-        *args: Variable arguments passed to the memory barrier creation operation
-
-    Returns:
-        tir.Call: A handle to the created list of memory barriers
+def create_list_of_mbarrier(*args: Any) -> Call:
+    """
+    Create a list of memory barrier handles.
+
+    Parameters
+    ----------
+    *args : list or Any
+        Either a single list of arguments, or multiple arguments directly.
+
+    Returns
+    -------
+    tvm.tir.Call
+        Handle to the created list of memory barriers.
+
+    Raises
+    ------
+    TypeError
+        If the input is not a list or variadic arguments.
+
+    Examples
+    --------
+    >>> create_list_of_mbarrier([128, 128])
+    >>> create_list_of_mbarrier(128, 128)
     """
-    return tir.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), *args)
+    if len(args) == 1 and isinstance(args[0], list):
+        return tir.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), *args[0])
+    elif len(args) >= 1:
+        return tir.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), *args)
+    else:
+        raise TypeError("create_list_of_mbarrier expects a list or one or more arguments.")


 def get_mbarrier(*args):
@@ -115,7 +135,7 @@ def no_set_max_nreg(*args):
     return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"), *args)


-def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr], parity: Union[int, Var]):
+def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]):
     """Wait for memory barrier parity condition.

     Args:
@@ -154,20 +174,28 @@ def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr], parity: Union[int, Var]
     Returns:
         tir.Call: A handle to the barrier wait operation
     """
-    if isinstance(mbarrier, int):
+    if isinstance(mbarrier, tir.Call):
+        mbarrier = mbarrier
+    elif isinstance(mbarrier, (tir.PrimExpr, int)):
         mbarrier = get_mbarrier(mbarrier)
+    else:
+        raise TypeError("mbarrier must be an integer or a tir.Call")
     return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity)


-def mbarrier_arrive(mbarrier: Union[int, PrimExpr]):
+def mbarrier_arrive(mbarrier: Union[int, PrimExpr, tir.Call]):
     """Arrive at memory barrier.

     Args:
         mbarrier: Optional[int, PrimExpr]
             The memory barrier to arrive at
     """
-    if isinstance(mbarrier, int):
+    if isinstance(mbarrier, tir.Call):
+        mbarrier = mbarrier
+    elif isinstance(mbarrier, (tir.PrimExpr, int)):
         mbarrier = get_mbarrier(mbarrier)
+    else:
+        raise TypeError("mbarrier must be an integer or a tir.Call")
     return ptx_arrive_barrier(mbarrier)
...
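
A short usage sketch of the widened argument handling (the call forms come from the docstring examples above; `get_mbarrier` is assumed to be exported alongside the other barrier helpers, since it is defined next to them in this file):

import tilelang.language as T

T.create_list_of_mbarrier(128, 128)    # variadic form
T.create_list_of_mbarrier([128, 128])  # single-list form, now accepted
handle = T.get_mbarrier(0)             # an existing barrier handle (tir.Call) ...
T.mbarrier_wait_parity(handle, 1)      # ... is passed through unchanged
T.mbarrier_arrive(1)                   # an int index is wrapped via get_mbarrier()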
@@ -3,7 +3,7 @@
 from tilelang.primitives.gemm.base import GemmWarpPolicy
 import tilelang.language as T
 from tvm import tir
-from typing import Union
+from typing import Union, List


 def gemm(
@@ -56,14 +56,64 @@ def gemm(
     A = legalize_arguments(A)
     B = legalize_arguments(B)
     C = legalize_arguments(C)
-    M = C.shape[0]
-    N = C.shape[1]
-    K = A.shape[0] if transpose_A else A.shape[1]
-    K_B = B.shape[1] if transpose_B else B.shape[0]
-    assert K == K_B, "gemm K shape check failed"
-    Aptr = A.access_ptr("r")
-    Bptr = B.access_ptr("r")
-    Cptr = C.access_ptr("rw")
+
+    def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
+        if isinstance(object, tir.Buffer):
+            return object.shape
+        elif isinstance(object, tir.BufferRegion):
+            region = object.region
+            shape = []
+            for r in region:
+                shape.append(r.extent)
+            return shape
+        else:
+            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
+
+    A_shape = retrieve_shape(A)
+    B_shape = retrieve_shape(B)
+    C_shape = retrieve_shape(C)
+
+    assert len(C_shape) == 2, "current only support C as a 2D tensor"
+    assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor"
+    assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor"
+    if len(A_shape) > 2:
+        for i in range(len(A_shape) - 2):
+            assert A_shape[i] == 1, \
+                "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
+    if len(B_shape) > 2:
+        for i in range(len(B_shape) - 2):
+            assert B_shape[i] == 1, \
+                "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
+
+    M, N = C_shape
+    K = A_shape[-2] if transpose_A else A_shape[-1]
+    K_B = B_shape[-1] if transpose_B else B_shape[-2]
+    assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"
+
+    def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
+                     access_type: str = "r") -> tir.PrimExpr:
+        if isinstance(object, tir.Buffer):
+            return object.access_ptr(access_type)
+        elif isinstance(object, tir.BufferRegion):
+            buffer, region = object.buffer, object.region
+            indices = []
+            for r in region:
+                indices.append(r.min)
+            strides = []
+            stride = 1
+            for s in reversed(buffer.shape):
+                strides.insert(0, stride)
+                stride *= s
+            offset = 0
+            for i in range(len(indices)):
+                offset += indices[i] * strides[i]
+            return buffer.access_ptr(access_mask=access_type, offset=offset)
+        else:
+            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
+
+    Aptr = retrieve_ptr(A, "r")
+    Bptr = retrieve_ptr(B, "r")
+    Cptr = retrieve_ptr(C, "rw")
     return tir.call_intrin(
         "handle",
         tir.op.Op.get("tl.gemm"),
...
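
With `tir.BufferRegion` arguments accepted, `T.gemm` can consume a slice of a multi-stage shared buffer directly; the excerpt below is taken from the multi-stage example earlier in this commit (not standalone code):

# consumer warp group of the two-stage kernel above
T.gemm(A_shared[ko % num_stages, :, :],
       B_shared[ko % num_stages, :, :],
       C_local)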
@@ -189,6 +189,17 @@ def WarpSpecialized():
     return _ffi_api.WarpSpecialized()  # type: ignore


+def InjectTmaBarrier():
+    """InjectTmaBarrier
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InjectTmaBarrier()  # type: ignore
+
+
 def InjectFenceProxy():
     """InjectFenceProxy
...
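
Usage mirrors the other transform wrappers: `OptimizeForTarget` in `phase.py` applies it as shown above, and it can also be applied to a lowered module directly (sketch; `mod` is assumed to be a tvm.IRModule produced by the earlier lowering stages):

import tilelang

mod = tilelang.transform.InjectTmaBarrier()(mod)  # mod: tvm.IRModule (assumed)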