Commit b5faf25a authored by Lei Wang, committed by LeiWang1999

[Enhancement] Add new examples for warp specialization and TMA integration (#448)

* [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.

* [Refactor] Rename operations for consistency in lower_hopper_intrin and related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.

* [Refactor] Rename operations to snake_case for consistency

* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.
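For orientation, the renamed intrinsics as they now appear in kernel bodies elsewhere in this PR (a fragment only, not a standalone script; the barrier arrive counts and indices are taken from the flashmla example below):

```python
import tilelang.language as T

# Inside a `with T.Kernel(...)` body (fragment):
T.create_list_of_mbarrier(128, 128, 256, 128)  # was CreateListofMBarrierOp
T.mbarrier_arrive(T.get_mbarrier(3))           # get_mbarrier was GetMBarrierOp
T.mbarrier_wait_parity(T.get_mbarrier(3), 0)
```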

* [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.
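A condensed sketch of the pattern these pieces enable, adapted from `example_warp_specialize_gemm_copy_0_gemm_1.py` added later in this PR (illustrative, not the exact `gemm_ws.py` code): warp group 0 stages tiles into shared memory and signals barrier 0, while warp group 1 waits on it, runs the GEMM, and signals barrier 1 to release the buffers.

```python
import tilelang
import tilelang.language as T


def matmul_ws(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # one mbarrier per direction of the producer/consumer handshake
            T.create_list_of_mbarrier(128, 128)
            with T.ws(1):
                T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                with T.ws(0):  # producer warp group: global -> shared copies
                    T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                    T.mbarrier_arrive(0)
                with T.ws(1):  # consumer warp group: gemm on the staged tiles
                    T.mbarrier_wait_parity(0, ko & 1)
                    T.gemm(A_shared, B_shared, C_local)
                    T.mbarrier_arrive(1)
            with T.ws(1):
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```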

* [Refactor] Clean up code formatting and improve readability

* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.

* lint fix

* [Refactor] Update mbarrier functions for improved clarity and consistency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.
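For illustration, the explicit-parameter call forms as they appear in the updated barrier-pipelined GEMM later in this diff (a kernel-body fragment, not a standalone script; `ko` iterates over K tiles and `num_stages` is the number of buffered stages):

```python
# Producer warp group: wait until the consumer has released this stage, then fill it.
T.mbarrier_wait_parity(
    mbarrier=ko % num_stages + num_stages,
    parity=((ko // num_stages) % num_stages) ^ 1)
# ... T.copy(...) into the stage ...
T.mbarrier_arrive(mbarrier=ko % num_stages)

# Consumer warp group: wait for the producer, compute, then release the stage.
T.mbarrier_wait_parity(
    mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages)
# ... T.gemm(...) on the stage ...
T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages)
```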

* Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` for improved code readability and separation of logical sections.

* [Feature] Add examples for warp specialization and TMA barrier integration

* Introduced three new example scripts: `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, and `example_warp_specialize_mla.py`, demonstrating warp specialization with TMA barriers for matrix multiplication and MLA attention.
* Implemented kernel functions with shared memory allocation and memory barrier synchronization for improved performance.
* Enhanced the TileLang API with new methods for compiling and testing kernels in Python using PyTorch.
* Updated the `phase.py` to include TMA barrier injection in the optimization process.
* Improved documentation and comments for better clarity on usage and functionality.
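The new scripts share essentially the same harness; roughly (a sketch reusing the hypothetical `matmul_ws` factory from the sketch above, CUDA GPU required):

```python
import torch
import tilelang

M = N = K = 16384
func = matmul_ws(M, N, K, 128, 128, 64)

# out_idx=[2] lets the runtime allocate C and return it as the output tensor.
jit_kernel = tilelang.compile(func, out_idx=[2])

a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = jit_kernel(a, b)

# Validate against PyTorch, then measure latency with the built-in profiler.
torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
print(f"Latency: {profiler.do_bench()} ms")
```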

* [Feature] Add example for warp specialization in GEMM with TMA barriers

* Introduced a new example script `example_warp_specialize_gemm_stage2.py` demonstrating matrix multiplication using warp specialization and TMA barriers.
* Implemented a kernel function with shared memory allocation and memory barrier synchronization for enhanced performance.
* Included functionality to compile the kernel into a PyTorch-compatible function and validate its correctness against PyTorch's reference implementation.
* Enhanced documentation and comments for clarity on usage and functionality.

* lint fix

* [Feature] Implement WarpSpecializedDetector for TMA and MBarrier Detection

* Added the `WarpSpecializedDetector` class to identify the presence of TMA operations and memory barrier operations within a given TIR statement.
* Enhanced the `WarpSpecialized` pass to utilize the detector, allowing for conditional substitution based on the detection results.
* Improved code organization by including necessary headers and utilizing the `IRVisitorWithAnalyzer` for analysis.
* This addition aims to optimize warp specialization by ensuring that only relevant functions are transformed, enhancing performance and correctness.
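On the Python side the transformation can still be switched off explicitly through `pass_configs` (as some examples in this PR do while debugging); otherwise the detector decides per function. Sketch, where `func` is any TileLang program such as the `matmul_ws` sketch above:

```python
import tilelang

jit_kernel = tilelang.compile(
    func,
    out_idx=[2],
    target="cuda",
    execution_backend="cython",
    pass_configs={
        "tl.disable_warp_specialized": True,
        "tl.disable_tma_lower": True,
    })
```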

* lint fix

* [Feature] Add new examples for warp specialization and TMA integration

* Introduced multiple new example scripts demonstrating warp specialization techniques, including `example_warp_specialize_flashmla.py`, `example_warp_specialize_gemm_barrierpipe_stage2.py`, `example_warp_specialize_gemm_copy_0_gemm_1.py`, `example_warp_specialize_gemm_copy_1_gemm_0.py`, and `example_warp_specialize_gemm_softpipe_stage2.py`.
* Each example showcases matrix multiplication with warp specialization and TMA barriers, implementing kernel functions with shared memory allocation and memory barrier synchronization for enhanced performance.
* Added a test suite in `test_example_warp_specialize.py` to validate the functionality of the new examples.
* Updated the TileLang API to support these examples and improve kernel compilation and testing processes.
* Removed outdated example scripts to streamline the codebase and enhance clarity on available functionalities.

* lint fix

* Remove outdated example scripts for warp specialization and TMA integration to streamline the codebase. This includes `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, `example_warp_specialize_gemm_stage2.py`, and `example_warp_specialize_mla.py`, which are no longer needed following recent updates and improvements in the TileLang API.
parent fce16b00
# use default stage 1 template, not the optimal
# schedule, please checkout examples/deepseek_mla
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from einops import rearrange, einsum
import argparse
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1"
@T.macro
def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=384) as (bx, by):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
O_shared = T.alloc_shared([block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
T.create_list_of_mbarrier(128, 128, 256, 128)
loop_range = T.ceildiv(seqlen_kv, block_N)
with T.ws(2):
T.dec_max_nreg(24)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.mbarrier_arrive(T.get_mbarrier(3))
for k in T.serial(loop_range):
T.mbarrier_wait_parity(
T.FloorMod(k, 1) + 2, T.bitwise_xor(T.FloorDiv(k, 1) % 2, 1))
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
T.mbarrier_arrive(T.FloorMod(k, 1))
T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.mbarrier_arrive(T.FloorMod(k, 1) + 1)
with T.ws(0, 1):
T.inc_max_nreg(240)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.mbarrier_wait_parity(T.get_mbarrier(3), 0)
for k in T.serial(loop_range):
T.clear(acc_s)
T.mbarrier_wait_parity(T.get_mbarrier(T.FloorMod(k, 1)), T.FloorDiv(k, 1) % 2)
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.mbarrier_wait_parity(
T.get_mbarrier(T.FloorMod(k, 1) + 1),
T.FloorDiv(k, 1) % 2)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
T.mbarrier_arrive(T.get_mbarrier(T.FloorMod(k, 1) + 2))
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
return main_no_split
def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
# """
# Inputs:
# - q (Tensor): [batch, heads, dim]
# - q_pe (Tensor): [batch, heads, pe_dim]
# - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
# - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
# - glse (Tensor): [batch, heads, num_split]
# - Output_partial (Tensor): [batch, heads, num_split, dim]
# Outputs:
# - output (Tensor): [batch, heads, dim]
# """
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5
q = rearrange(
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange(
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 64
BLOCK_H = 64
num_split = 1
program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
kernel = tilelang.compile(program, out_idx=[6])
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
if __name__ == "__main__":
main()
import tilelang
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@@ -11,85 +9,89 @@
    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor[(M, K), dtype],
            B: T.Tensor[(K, N), dtype],
            C: T.Tensor[(M, N), dtype],
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((num_stages, block_M, block_K), dtype)
            B_shared = T.alloc_shared((num_stages, block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # create mbarrier for tma
            T.create_list_of_mbarrier(mbarrier_list)

            with T.ws(0):
                T.clear(C_local)

            for ko in range(T.ceildiv(K, block_K)):
                with T.ws(1):
                    T.mbarrier_wait_parity(
                        mbarrier=ko % num_stages + num_stages,
                        parity=((ko // num_stages) % num_stages) ^ 1)
                    T.copy(A[by * block_M:(by + 1) * block_M, ko * block_K:(ko + 1) * block_K],
                           A_shared[ko % num_stages, :, :])
                    T.copy(B[ko * block_K:(ko + 1) * block_K, bx * block_N:(bx + 1) * block_N],
                           B_shared[ko % num_stages, :, :])
                    T.mbarrier_arrive(mbarrier=ko % num_stages)
                with T.ws(0):
                    T.mbarrier_wait_parity(
                        mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages)
                    T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :],
                           C_local)
                    T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages)

            with T.ws(0):
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def main():
    M = 16384
    N = 16384
    K = 16384
    block_M = 128
    block_N = 128
    block_K = 64

    # 1. Define the kernel (matmul) and compile/lower it into an executable module
    func = matmul(M, N, K, block_M, block_N, block_K)

    # 2. Compile the kernel into a torch function
    # out_idx specifies the index of the output buffer in the argument list
    # if out_idx is specified, the tensor will be created during runtime
    # target currently can be "cuda" or "hip" or "cpu".
    jit_kernel = tilelang.compile(func, out_idx=[2])

    # 3. Test the kernel in Python with PyTorch data
    import torch

    # Create random input tensors on the GPU
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(K, N, device="cuda", dtype=torch.float16)

    # Run the kernel through the Profiler
    c = jit_kernel(a, b)

    # Reference multiplication using PyTorch
    ref_c = a @ b

    # Validate correctness
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("Kernel output matches PyTorch reference.")

    # 4. Retrieve and inspect the generated CUDA source (optional)
    # cuda_source = jit_kernel.get_kernel_source()
    # print("Generated CUDA kernel:\n", cuda_source)

    # 5. Profile latency with kernel
    profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench()
    print(f"Latency: {latency} ms")


if __name__ == "__main__":
    main()
import tilelang
import tilelang.language as T
def matmul_warp_specialize_copy_0_gemm_1(M,
N,
K,
block_M,
block_N,
block_K,
dtype="float16",
accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# create mbarrier for tma
T.create_list_of_mbarrier(128, 128)
with T.ws(1):
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
with T.ws(0):
T.mbarrier_wait_parity(1, (ko & 1) ^ 1)
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.mbarrier_arrive(0)
with T.ws(1):
T.mbarrier_wait_parity(0, ko & 1)
T.gemm(A_shared, B_shared, C_local)
T.mbarrier_arrive(1)
with T.ws(1):
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def main():
M = 16384
N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 64
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K)
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(
func,
out_idx=[2],
)
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
# Reference multiplication using PyTorch
ref_c = a @ b
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile latency with kernel
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
if __name__ == "__main__":
main()
import tilelang
import tilelang.language as T


def matmul_warp_specialize_copy_1_gemm_0(M,
                                         N,
                                         K,
                                         block_M,
                                         block_N,
                                         block_K,
                                         dtype="float16",
                                         accum_dtype="float"):
    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
@@ -14,27 +19,13 @@
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # create mbarrier for tma
            T.create_list_of_mbarrier(128, 128)

            with T.ws(0):
                T.clear(C_local)
@@ -55,51 +46,54 @@
    return main


def main():
    M = 16384
    N = 16384
    K = 16384
    block_M = 128
    block_N = 128
    block_K = 64

    # 1. Define the kernel (matmul) and compile/lower it into an executable module
    func = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K)

    # 2. Compile the kernel into a torch function
    # out_idx specifies the index of the output buffer in the argument list
    # if out_idx is specified, the tensor will be created during runtime
    # target currently can be "cuda" or "hip" or "cpu".
    jit_kernel = tilelang.compile(
        func,
        out_idx=[2],
    )

    # 3. Test the kernel in Python with PyTorch data
    import torch

    # Create random input tensors on the GPU
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(K, N, device="cuda", dtype=torch.float16)

    # Run the kernel through the Profiler
    c = jit_kernel(a, b)

    # Reference multiplication using PyTorch
    ref_c = a @ b

    # Validate correctness
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("Kernel output matches PyTorch reference.")

    # 4. Retrieve and inspect the generated CUDA source (optional)
    # cuda_source = jit_kernel.get_kernel_source()
    # print("Generated CUDA kernel:\n", cuda_source)

    # 5. Profile latency with kernel
    profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench()
    print(f"Latency: {latency} ms")


if __name__ == "__main__":
    main()
import tilelang
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Tensor[(M, K), dtype],
            B: T.Tensor[(K, N), dtype],
            C: T.Tensor[(M, N), dtype],
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # create mbarrier for tma
            T.create_list_of_mbarrier(128, 128)

            with T.ws(0):
                T.clear(C_local)
@@ -55,51 +39,51 @@
    return main


def main():
    M = 16384
    N = 16384
    K = 16384
    block_M = 128
    block_N = 128
    block_K = 64

    # 1. Define the kernel (matmul) and compile/lower it into an executable module
    func = matmul(M, N, K, block_M, block_N, block_K)

    # 2. Compile the kernel into a torch function
    # out_idx specifies the index of the output buffer in the argument list
    # if out_idx is specified, the tensor will be created during runtime
    # target currently can be "cuda" or "hip" or "cpu".
    tilelang.disable_cache()
    jit_kernel = tilelang.compile(func, out_idx=[2])

    # 3. Test the kernel in Python with PyTorch data
    import torch

    # Create random input tensors on the GPU
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(K, N, device="cuda", dtype=torch.float16)

    # Run the kernel through the Profiler
    c = jit_kernel(a, b)

    # Reference multiplication using PyTorch
    ref_c = a @ b

    # Validate correctness
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("Kernel output matches PyTorch reference.")

    # 4. Retrieve and inspect the generated CUDA source (optional)
    # cuda_source = jit_kernel.get_kernel_source()
    # print("Generated CUDA kernel:\n", cuda_source)

    # 5. Profile latency with kernel
    profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench()
    print(f"Latency: {latency} ms")


if __name__ == "__main__":
    main()
import tilelang
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# create mbarrier for tma
T.create_list_of_mbarrier(128, 128)
with T.ws(1):
for ko in range(T.ceildiv(K, block_K)):
T.mbarrier_wait_parity(1, ko ^ 1)
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.mbarrier_arrive(0)
with T.ws(0):
T.clear(C_local)
for ko in range(T.ceildiv(K, block_K)):
T.mbarrier_wait_parity(0, ko)
T.gemm(A_shared, B_shared, C_local)
T.mbarrier_arrive(1)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(128, 128, 64, 128, 128, 32)
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(
func,
out_idx=[2],
target="cuda",
execution_backend="cython",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True,
})
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 128, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)
# Reference multiplication using PyTorch
ref_c = a @ b
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile latency with kernel
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
import tilelang
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# create mbarrier for tma
T.create_list_of_mbarrier(128, 128)
with T.ws(1):
for ko in range(T.ceildiv(K, block_K)):
T.mbarrier_wait_parity(1, ko ^ 1)
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.mbarrier_arrive(0)
with T.ws(0):
T.clear(C_local)
for ko in range(T.ceildiv(K, block_K)):
T.mbarrier_wait_parity(0, ko)
T.gemm(A_shared, B_shared, C_local)
T.mbarrier_arrive(1)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(128, 128, 64, 128, 128, 32)
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(
func,
out_idx=[2],
target="cuda",
execution_backend="cython",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True,
})
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 128, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)
# Reference multiplication using PyTorch
ref_c = a @ b
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile latency with kernel
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
import tilelang.testing
import example_warp_specialize_flashmla
import example_warp_specialize_gemm_barrierpipe_stage2
import example_warp_specialize_gemm_copy_0_gemm_1
import example_warp_specialize_gemm_copy_1_gemm_0
import example_warp_specialize_gemm_softpipe_stage2
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_flashmla():
example_warp_specialize_flashmla.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_barrierpipe_stage2():
example_warp_specialize_gemm_barrierpipe_stage2.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_copy_0_gemm_1():
example_warp_specialize_gemm_copy_0_gemm_1.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_copy_1_gemm_0():
example_warp_specialize_gemm_copy_1_gemm_0.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_softpipe_stage2():
example_warp_specialize_gemm_softpipe_stage2.main()
if __name__ == "__main__":
tilelang.testing.main()
@@ -254,13 +254,38 @@ public:
                              WarpSpecializeFrameNode);
};

WarpSpecializeFrame WarpSpecialize(Array<IntImm> warp_group_ids,
                                   PrimExpr thread_idx,
                                   int warp_group_size = 128) {
  ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>();
  PrimExpr condition;
  std::vector<int> warp_groups;
  for (int i = 0; i < warp_group_ids.size(); i++) {
    warp_groups.push_back(Downcast<IntImm>(warp_group_ids[i])->value);
  }
  std::sort(warp_groups.begin(), warp_groups.end());

  // Merge consecutive groups
  std::vector<std::pair<int, int>> merged;
  for (int group : warp_groups) {
    if (merged.empty() || group != merged.back().second) {
      merged.emplace_back(group, group + 1);
    } else {
      merged.back().second = group + 1;
    }
  }

  for (const auto &[start, end] : merged) {
    PrimExpr min_bound = IntImm(thread_idx.dtype(), start) * warp_group_size;
    PrimExpr max_bound = IntImm(thread_idx.dtype(), end) * warp_group_size;
    PrimExpr range_cond = (thread_idx >= min_bound) && (thread_idx < max_bound);
    if (condition.defined()) {
      condition = tir::Or(condition, range_cond);
    } else {
      condition = range_cond;
    }
  }

  IfFrame if_frame = If(condition);
  n->frames.push_back(if_frame);
  n->frames.push_back(Then());
......
@@ -381,6 +381,7 @@ std::string FragmentNode::DebugOutput() const {
  ss << " -> thread: " << ThreadExtent();
  ss << " -> forward_thread: " << forward_thread_;
  ss << " -> forward_index: " << GetForwardIndex();
  ss << " -> thread_range: " << thread_range_;
  return ss.str();
}
......
@@ -127,6 +127,12 @@ public:
                   Optional<Var> replicate_var);
  TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);

  Fragment SetThreadRange(Range thread_range) {
    auto node = make_object<FragmentNode>(*this->get());
    node->SetThreadRange(thread_range);
    return Fragment(node);
  }
};

Var InputPlaceholder(size_t idx);
......
@@ -17,6 +17,7 @@ namespace tvm {
namespace tl {

TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
......
@@ -14,6 +14,8 @@ namespace tvm {
namespace tl {

static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
static constexpr const char *kDisableSafeMemoryLegalize =
    "tl.disable_safe_memory_legalize";
static constexpr const char *kDisableWarpSpecialized =
    "tl.disable_warp_specialized";
static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
......
@@ -68,11 +68,11 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
  if (this->policy == GemmWarpPolicy::kFullRow ||
      this->policy == GemmWarpPolicy::kSquare) {
    m_warp = num_warps;
    ICHECK(this->M % num_warps == 0) << this->M << " % " << num_warps;
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    m_warp = 4;
    n_warp = num_warps / 4;
    ICHECK(this->N % n_warp == 0) << this->N << " % " << n_warp;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
@@ -80,10 +80,10 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
  }
  if (this->policy == GemmWarpPolicy::kFullRow) {
    m_warp = num_warps;
    ICHECK(this->M % num_warps == 0) << this->M << " % " << num_warps;
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    n_warp = num_warps;
    ICHECK(this->N % num_warps == 0) << this->N << " % " << num_warps;
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    auto factors = toPrimeFactors(num_warps);
    for (int factor : factors) {
@@ -164,14 +164,15 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  auto thread_range = T.thread_bounds;
  auto block_size = *as_const_int(thread_range->extent);
  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target);
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment.SetThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
@@ -179,7 +180,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
                                           true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment.SetThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
@@ -195,7 +197,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
        ComputeWarpPartition(block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment.SetThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
@@ -206,8 +208,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
                                      A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment =
          makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
      results.Set(A, fragment.SetThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
@@ -221,7 +224,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
    } else if (B.scope() == "local.fragment") {
      ICHECK(trans_B == false) << "B is local.fragment, trans_B must be false, "
                                  "please raise an issue if you see this";
      auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
      results.Set(B, fragment.SetThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
@@ -235,7 +239,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment.SetThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
@@ -246,8 +250,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
                                      A->dtype.bits(), trans_A ? 1 : 2));
    } else {
      ICHECK(trans_A == false);
      auto fragment =
          makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
      results.Set(A, fragment.SetThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
@@ -267,8 +272,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment.SetThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
@@ -277,8 +281,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
          *as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment.SetThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
@@ -290,7 +295,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
      results.Set(B, fragment.SetThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
......
@@ -181,7 +181,8 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
                           IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
          .SetThreadRange(T.thread_bounds);
    }
  };
  if (source_buffer.defined()) {
@@ -258,7 +259,8 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
    LayoutMap results;
    for (const auto &[buffer, _] : indice_map_) {
      if (!T.layout_map.count(buffer)) {
        results.Set(buffer, CompleteBufferFragment(buffer).SetThreadRange(
                                T.thread_bounds));
      }
    // Though they may exist some conflicts, but it's fine.
@@ -269,7 +271,8 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
      if (T.layout_map.count(buffer)) {
        const FragmentNode *src_layout =
            T.layout_map[buffer].as<Fragment>().get();
        Fragment dst_layout_fragment =
            CompleteBufferFragment(buffer).SetThreadRange(T.thread_bounds);
        const FragmentNode *dst_layout =
            dst_layout_fragment.as<Fragment>().get();
        if (src_layout && dst_layout) {
......
/*!
 * \file eliminate_storage_sync_for_mbarrier.cc
 */
#include "../op/builtin.h"
#include "./storage_access.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
@@ -27,6 +28,7 @@ public:
  }

  Eliminator(arith::Analyzer *analyzer) : IRMutatorWithAnalyzer(analyzer) {
    im_mbarrier_for_ = false;
    in_mbarrier_region_ = false;
  }
@@ -49,7 +51,8 @@ public:
    call = static_cast<const CallNode *>(op->value.get());
    if (call->op.same_as(builtin::tvm_storage_sync())) {
      // Skip storage sync if we're in a region with mbarrier operations
      // and we're not in a for loop with mbarrier operations
      if (in_mbarrier_region_ || im_mbarrier_for_) {
        return Stmt();
      }
    } else if (call->op.same_as(builtin::ptx_arrive_barrier()) ||
@@ -77,7 +80,24 @@ public:
    return ret;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    PostOrderVisit(GetRef<For>(op), [&](const ObjectRef &node) {
      if (const auto *call = node.as<CallNode>()) {
        if (call->op.same_as(create_list_of_mbarrier()) ||
            call->op.same_as(mbarrier_wait_parity()) ||
            call->op.same_as(builtin::ptx_arrive_barrier()) ||
            call->op.same_as(builtin::ptx_cp_async_barrier())) {
          im_mbarrier_for_ = true;
        }
      }
    });
    auto stmt = IRMutatorWithAnalyzer::VisitStmt_(op);
    im_mbarrier_for_ = false;
    return stmt;
  }

private:
  bool im_mbarrier_for_;
  bool in_mbarrier_region_;
  const AttrStmtNode *thread_extent_{nullptr};
};
......
@@ -11,6 +11,7 @@
#include <queue>

#include "../op/builtin.h"
#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"
@@ -285,6 +286,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
  using namespace tir::transform;
  // Define the transformation function to be applied
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    bool disable_safe_memory_legalize =
        ctx->GetConfig<Bool>(kDisableSafeMemoryLegalize, Bool(false)).value();
    if (disable_safe_memory_legalize) {
      return f;
    }
    return SafeMemoryLegalizer::Substitute(std::move(f));
  };
  // Create and return a PrimFunc pass with the transformation function
......
@@ -191,9 +191,7 @@ Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) {
  size_t num_thread = *as_const_int(thread_range->extent);
  LoopPartitioner partitioner;
  Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size);
  return fragment.SetThreadRange(thread_range);
}

For LoopPragmaUnroll(For stmt) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file lower hopper intrin.cc
 * \brief Lower Hopper intrinsics cuda GPU(sm90+)
......
@@ -32,19 +32,21 @@
#include <unordered_map>
#include <vector>

#include "arith/ir_visitor_with_analyzer.h"
#include "runtime/thread_storage_scope.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRVisitorWithAnalyzer;
using runtime::StorageRank;
using runtime::StorageScope;

/*!
 * \brief Base class of storage access analysis
 */
class TileLangStorageAccessVisitor : public IRVisitorWithAnalyzer {
public:
  /*! \brief Storage access type */
  enum AccessType {
......