"tests/vscode:/vscode.git/clone" did not exist on "69fdb8720ffdf3b0c629d5d2372032115b23c805"
Commit 38ba083b authored by Lei Wang, committed by GitHub

[Dev] Implement test case for tilelang transformations (#53)

* implement jit test case

* [Dev] implement auto tune test case for matrix multiplication

* Implement test for legalize memory access and vectorized loop

* lint fix
parent 34de04a6
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import itertools
import logging
import tilelang as tl
import tilelang.testing
import tilelang.language as T
from tilelang.autotuner import autotune, jit
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.
Parameters
----------
A : numpy.ndarray
The matrix with shape (M, K).
B : numpy.ndarray
The matrix with shape (N, K).
Returns
-------
np.ndarray
The result of A @ B.T, shape (M, N).
"""
return A @ B.T
def get_configs(M, N, K, with_roller=False):
"""
Generate a list of configuration dictionaries that will be used for tuning.
Parameters
----------
M, N, K : int
The dimensions of the matrix multiplication.
with_roller : bool
Whether to use the bitblas roller to deduce the search space.
Returns
-------
list of dict
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
if with_roller:
from bitblas.base.utils import get_roller_hints_from_func
from bitblas.ops.general_matmul.tirscript import matmul_select_implementation
from bitblas.base.arch import CUDA
from bitblas.base.roller.rasterization import NoRasterization
arch = CUDA("cuda")
topk = 20
# Simple TIR Compute Expression
ir_module = matmul_select_implementation(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
)
roller_hints = get_roller_hints_from_func(
ir_module,
arch,
topk,
tensorcore_only=True,
allow_gemv=True,
)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = 0
config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
for config in configs:
print(config)
else:
block_M = [64]
block_N = [64]
block_K = [32]
num_stages = [0, 1]
thread_num = [128]
enable_rasterization = [False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs
]
return configs
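# Example of the generated search space when with_roller is False (values come
# straight from the lists above); each dict is one candidate configuration:
#   {"block_M": 64, "block_N": 64, "block_K": 32,
#    "num_stages": 0, "thread_num": 128, "enable_rasteration": False}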
def matmul(M, N, K, with_roller):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
- B: (N, K)
- C: (M, N)
Parameters
----------
M : int
The dimension M of the matrix multiplication.
N : int
The dimension N of the matrix multiplication.
K : int
The dimension K of the matrix multiplication.
with_roller : bool
Whether to generate the tuning configurations from bitblas roller hints.
Returns
-------
(best_latency, best_config, ref_latency)
best_latency : float
The best latency found among the tuned configurations.
best_config : dict
The parameter configuration that yielded best_latency.
ref_latency : float
The baseline latency of the reference program (for computing speedup).
"""
# Decorate the kernel with autotune & jit, specifying:
# - Tuning config list
# - Profiling keys
# - Warmup and repetition counts for better measurement
# - A reference program for correctness verification
# - The "tvm" profiler backend
# - HIP as the compilation target (modify as needed for your hardware)
if with_roller:
# Check that bitblas is installed
try:
import bitblas # noqa: F401
except ImportError as e:
raise ImportError(
"BitBlas is not installed. Please install it via 'pip install bitblas'.") from e
@autotune(
configs=get_configs(M, N, K, with_roller),
keys=[
"block_M",
"block_N",
"block_K",
"num_stages",
"thread_num",
"enable_rasteration",
],
warmup=3,
rep=5,
)
@jit(
out_idx=[2],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=True,
profiler="auto",
target="auto",
)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(
A[by * block_M, k * block_K],
A_shared,
)
# Load a sub-block of B from global memory into B_shared
T.copy(
B[bx * block_N, k * block_K],
B_shared,
)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C[by * block_M, bx * block_N])
return main
return kernel()
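# A minimal usage sketch (illustrative sizes; assumes the autotune decorator
# returns the tuple described in the docstring above):
# best_latency, best_config, ref_latency = matmul(1024, 1024, 1024, with_roller=False)
# print(best_config, ref_latency / best_latency)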
def test_autotune_get_configs():
get_configs(8192, 8192, 8192, with_roller=False)
def test_autotune_matmul():
matmul(8192, 8192, 8192, with_roller=False)
if __name__ == "__main__":
tilelang.testing.main()
@@ -127,6 +127,123 @@ def test_gemm_f16f16f16_nn():
def matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
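# The global and shared-memory tile shapes depend on whether A and B are
# passed transposed; the shared tiles mirror the corresponding global layouts.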
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_jit_kernel(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
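# Compile the prim_func with tilelang's JIT wrapper; out_idx=-1 marks the last
# buffer (C) as the output tensor that the wrapper allocates and returns.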
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
if trans_A:
A = A.T
if trans_B:
B = B.T
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
ref_C = ref_program(A, B)
C = matmul_kernel(A, B)
tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_gemm_jit_kernel():
run_gemm_jit_kernel(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
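# A hypothetical extra case (not part of this commit) that would exercise the
# transposed-B path of the kernel above, mirroring test_gemm_jit_kernel:
# def test_gemm_jit_kernel_nt():
#     run_gemm_jit_kernel(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)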
if __name__ == "__main__": if __name__ == "__main__":
# tilelang.testing.main() tilelang.testing.main()
test_gemm_f16f16f16_nn()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T
import tilelang.testing
def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
dtype = "float32"
@T.prim_func
def main(A: T.Buffer((M, N), dtype="float32"),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
for j in T.serial(N):
A_shared[tid, j] = A[tid + M_offset, j + N_offset]
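# Expected output of LegalizeSafeMemoryAccess: the offset global reads are
# wrapped in if_then_else guards that fall back to 0 whenever an index exceeds
# the buffer extent.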
@T.prim_func
def expected(A: T.Buffer((M, N), dtype="float32"),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
T.reads(A[tid + M_offset, N_offset:N + N_offset])
for j in T.serial(N):
A_shared[tid, j] = T.if_then_else(
j + N_offset < N,
T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset],
T.float32(0)), T.float32(0))
return main, expected
def assert_vectorize_access(M: int = 64, N: int = 64):
func, expected = vectorize_access_legalize(M, N)
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def test_vectorize_access():
assert_vectorize_access(64, 64)
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T
import tilelang.testing
def vectorize_access_legalize(M: int = 64, N: int = 64):
dtype = "float32"
vec_len = 8
@T.prim_func
def main(A: T.Buffer((M, N, vec_len), dtype="float32"),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
tid = T.get_thread_binding()
for j in T.serial(N):
for v in T.vectorized(vec_len):
A_shared[tid, j, v] = A[tid, j, v]
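# Expected output of LegalizeVectorizedLoop: the 8-wide vectorized loop is split
# into an outer serial factor of vec_len // 4 and an inner 4-wide vectorized loop
# (4 x float32 = 128-bit accesses).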
@T.prim_func
def expected(A: T.Buffer((M, N, vec_len), dtype="float32"),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
tid = T.get_thread_binding()
for j, v_2 in T.grid(N, vec_len // 4):
for vec in T.vectorized(4):
A_shared[tid, j, v_2 * 4 + vec] = A[tid, j, v_2 * 4 + vec]
return main, expected
def assert_vectorize_access(M: int = 64, N: int = 64):
func, expected = vectorize_access_legalize(M, N)
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
transformed = tl.transform.LegalizeVectorizedLoop()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def test_vectorize_access():
assert_vectorize_access(64, 64)
if __name__ == "__main__":
tilelang.testing.main()