Unverified Commit a874e4e8 authored by Lei Wang, committed by GitHub

[Refactor] Phaseout PassConfig `kDisableDynamicTailSplit` and `kDynamicAlignment` as they are legacy (#1486)

* [Cleanup] Remove dynamic shape example and related tests

* Deleted the dynamic shape example script `example_dynamic.py` and its corresponding test file `test_example_dynamic.py` to streamline the codebase.
* Removed unused dynamic tail split and dynamic alignment configurations from `builtin.h` and `pass_config.py`.
* Cleaned up the dynamic shape testing files to eliminate redundancy and improve maintainability.

* build fix
parent 7e8d1f82
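For context, the two keys removed here were supplied through `pass_configs`, as in the deleted `example_dynamic.py` reproduced below. A minimal sketch of that legacy configuration (illustrative only; after this change neither key is registered, so callers should simply omit them):

# Sketch of the legacy pass_configs dict that used to be handed to
# tilelang.jit / tilelang.compile; both keys are removed by this commit.
legacy_pass_configs = {
    "tl.disable_dynamic_tail_split": True,  # was registered as kDisableDynamicTailSplit
    "tl.dynamic_alignment": 8,              # was registered as kDynamicAlignment
}
# Remaining keys such as "tl.disable_warp_specialized" (see PassConfigKey in
# pass_config.py) are unaffected and can still be passed as before.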
import tilelang
import tilelang.language as T
import tilelang.testing
from tilelang import tvm as tvm


@tilelang.jit(pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8})
def matmul_dynamic_mnk(
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    M = tvm.te.var("m")
    N = tvm.te.var("n")
    K = tvm.te.var("k")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def dynamic_matmul(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return dynamic_matmul


def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads):
    print(
        f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}"
    )
    kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads)

    import torch

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
    print("Kernel output matches PyTorch reference.")

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench(input_tensors=[A, B, C])
    print(f"Latency: {latency} ms")


def main(M=16384, N=16384, K=16384):
    block_M, block_N, block_K = 128, 128, 32
    trans_A, trans_B = False, False
    in_dtype, out_dtype = T.float16, T.float16
    accum_dtype = T.float32
    num_stages = 3
    threads = 128
    matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads)


if __name__ == "__main__":
    main()
import tilelang.testing

import example_dynamic


def test_example_dynamic():
    example_dynamic.main(M=1024, N=1024, K=1024)


if __name__ == "__main__":
    tilelang.testing.main()
...
@@ -22,8 +22,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableThreadStorageSync, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
-TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kForceLetInline, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
...
@@ -60,14 +60,6 @@ static constexpr const char *kLayoutVisualizationEnable =
 static constexpr const char *kLayoutVisualizationFormats =
     "tl.layout_visualization_formats";
 static constexpr const char *kDeviceCompileFlags = "tl.device_compile_flags";
-/*!
- * \brief Whether to disable dynamic tail split
- *
- * kDisableDynamicTailSplit = "tl.disable_dynamic_tail_split"
- *
- */
-static constexpr const char *kDisableDynamicTailSplit =
-    "tl.disable_dynamic_tail_split";
 /*!
  * \brief Whether to disable thread storage synchronization
@@ -91,18 +83,6 @@ static constexpr const char *kDisableThreadStorageSync =
  */
 static constexpr const char *kForceLetInline = "tl.force_let_inline";
-/*!
- * \brief The size of the vectorized dimension in buffer, designed by user
- *
- * For example, if the vectorized dimension is 128 bits and the dtype of buffer
- * A[m, k] is float16, the size of the vectorized dimension (i.e. k) in buffer A
- * should be divisible by 8 (8 = 128 / 16).
- *
- * kDynamicAlignment = "tl.dynamic_alignment"
- *
- */
-static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
 /*!
  * \brief Get the type of the CUDA tensor map
  *
...
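The removed `kDynamicAlignment` doc comment above states the divisibility rule in prose; a small worked sketch of that arithmetic, assuming TVM's `DataType` helper (which the test file below already imports):

from tvm import DataType

# 128-bit vectorized accesses over a float16 buffer A[m, k]: the vectorized
# dimension k must be divisible by 128 / 16 = 8 elements.
vector_bits = 128
lanes = vector_bits // DataType("float16").bits
assert lanes == 8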
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics.utils import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter

tilelang.testing.set_random_seed(0)


def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)


def tl_matmul_macro(
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        T.float16,
        T.int8,
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        T.float16,
        T.float32,
        T.int32,
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    if out_dtype == T.int32:
        micro_size_k = 32

    # This is a debug config
    block_row_warps = 1
    block_col_warps = 1
    warp_row_tiles = 16
    warp_col_tiles = 16
    chunk = 32 if in_dtype == T.float16 else 64
    shared_scope = "shared.dyn"

    # Pipeline Stage
    stage = 2

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    M = tvm.te.var("m")
    A_shape = (M, K)
    B_shape = (N, K)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 32
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMA Wrapper to Auto Generate Code for MMA
    mma_emitter = TensorCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=False,
        b_transposed=True,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
    )

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)

            T.annotate_layout(
                {
                    A_shared: make_swizzle_layout(A_shared),
                    B_shared: make_swizzle_layout(B_shared),
                }
            )

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):
                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mma_emitter.mma(A_local, B_local, C_local)

            # Perform STMatrix
            mma_emitter.stmatrix(
                C_local,
                C_shared,
            )

            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[
                    i // micro_size_x,
                    j // micro_size_y,
                    i % micro_size_x,
                    j % micro_size_y,
                ]

    return main


def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
    matmul = tl_matmul_macro(N, K, in_dtype, out_dtype, accum_dtype)
    kernel = tilelang.compile(matmul, out_idx=[2])
    src_code = kernel.get_kernel_source()
    # src_code is the generated cuda source
    assert src_code is not None

    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

    C = kernel(A, B)

    # Get Reference Result
    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def tl_matmul_block(
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    M = tvm.te.var("m")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_correctness(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = tl_matmul_block(
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(program, out_idx=[2])

    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    C = kernel(A, B)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def tl_matmul_block_all_dynamic(
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    M = tvm.te.var("m")
    N = tvm.te.var("n")
    K = tvm.te.var("k")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_all_dynamic_correctness(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = tl_matmul_block_all_dynamic(
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(program)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
    dynamic_alignment=8,
):
    program = tl_matmul_block_all_dynamic(
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    pass_configs = {
        tilelang.PassConfigKey.TL_DISABLE_DYNAMIC_TAIL_SPLIT: dynamic_alignment != 0,
        tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: dynamic_alignment,
    }
    if M % 64 == 0 or N % 64 == 0 or K % 64 != 0:
        # workaround for hopper tma lower pass
        pass_configs[tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER] = True
        pass_configs[tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED] = True
    kernel = tilelang.compile(program, pass_configs=pass_configs)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def test_assert_tl_matmul_macro():
    assert_tl_matmul_macro_correctness(128, 128, 128, T.float16, T.float16, T.float16)
    assert_tl_matmul_macro_correctness(66, 128, 128, T.float16, T.float16, T.float16)
    assert_tl_matmul_macro_correctness(32, 128, 128, T.float16, T.float16, T.float16)


def test_assert_tl_matmul_block():
    assert_tl_matmul_block_correctness(128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
    assert_tl_matmul_block_correctness(67, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
    assert_tl_matmul_block_correctness(36, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)


def test_assert_tl_matmul_block_all_dynamic():
    assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
    assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
    assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)


def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=8
    )
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        64, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=8
    )
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        64, 128, 60, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=4
    )
    # Tail split is enabled with dynamic alignment 0
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        64, 128, 64, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=0
    )


if __name__ == "__main__":
    tilelang.testing.main()
import torch
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T


def tl_matmul_block_static(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    num_threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_static(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages=3,
    num_threads=128,
):
    program = tl_matmul_block_static(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        accum_dtype,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(program)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)
    # print(kernel.get_kernel_source())

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench()
    print(f"Static Latency: {latency} ms")


def tl_matmul_block_dynamic_m(
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    num_threads,
):
    M = tvm.te.var("m")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_dynamic_m(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    num_stages=3,
    num_threads=128,
    pass_configs=None,
):
    program = tl_matmul_block_dynamic_m(
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(program, pass_configs=pass_configs)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench(input_tensors=[A, B, C])
    print(f"Dynamic M Latency with pass_configs: {pass_configs} is {latency} ms")


def tl_matmul_block_dynamic_mn(
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    num_threads,
):
    M = tvm.te.var("m")
    N = tvm.te.var("n")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_dynamic_mn(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    num_stages=3,
    num_threads=128,
    pass_configs=None,
):
    program = tl_matmul_block_dynamic_mn(
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(program, pass_configs=pass_configs)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench(input_tensors=[A, B, C])
    print(f"Dynamic MN Latency with pass_configs: {pass_configs} is {latency} ms")


def tl_matmul_block_dynamic_mnk(
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    num_threads,
):
    M = tvm.te.var("m")
    N = tvm.te.var("n")
    K = tvm.te.var("k")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_dynamic_mnk(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    num_stages=3,
    num_threads=128,
    pass_configs=None,
):
    program = tl_matmul_block_dynamic_mnk(
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(program, pass_configs=pass_configs)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench(input_tensors=[A, B, C])
    print(f"Dynamic MNK Latency with pass_configs: {pass_configs} is {latency} ms")


def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K):
    assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, T.float16, T.float16, T.float32)


def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
    assert_tl_matmul_block_dynamic_m(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        T.float16,
        T.float16,
        T.float32,
        pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
    )
    assert_tl_matmul_block_dynamic_m(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        T.float16,
        T.float16,
        T.float32,
        pass_configs={"tl.disable_dynamic_tail_split": False},
    )


def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
    assert_tl_matmul_block_dynamic_mn(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        T.float16,
        T.float16,
        T.float32,
        pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
    )
    assert_tl_matmul_block_dynamic_mn(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        T.float16,
        T.float16,
        T.float32,
        pass_configs={"tl.disable_dynamic_tail_split": False},
    )


def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
    assert_tl_matmul_block_dynamic_mnk(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        T.float16,
        T.float16,
        T.float32,
        pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 4},
    )
    assert_tl_matmul_block_dynamic_mnk(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        False,
        False,
        T.float16,
        T.float16,
        T.float32,
        pass_configs={"tl.disable_dynamic_tail_split": False},
    )


def test_all():
    run_assert_tl_matmul_block_static(1024, 1024, 1024, 128, 128, 32)
    run_assert_tl_matmul_block_dynamic_m(1024, 1024, 1024, 128, 128, 32)
    run_assert_tl_matmul_block_dynamic_mn(1024, 1024, 1024, 128, 128, 32)
    run_assert_tl_matmul_block_dynamic_mnk(1024, 1024, 1024, 128, 128, 32)


if __name__ == "__main__":
    tilelang.testing.main()
...
@@ -10,12 +10,6 @@ class PassConfigKey(str, Enum):
     TL_SIMPLIFY = "tl.Simplify"
     """Enable/disable TileLang simplification passes. Default: True"""
-    TL_DYNAMIC_ALIGNMENT = "tl.dynamic_alignment"
-    """Memory alignment requirement for dynamic shapes. Default: 16"""
-    TL_DISABLE_DYNAMIC_TAIL_SPLIT = "tl.disable_dynamic_tail_split"
-    """Disable dynamic tail splitting optimization. Default: False"""
     TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized"
     """Disable warp specialization optimization. Default: False"""
...