Commit 5e259239 authored by Lei Wang, committed by GitHub

[Doc] Add debug-related testing and documentation (#58)

* Implement JIT test case

* [Dev] Implement auto-tune test case for matrix multiplication

* Implement tests for legalize memory access and vectorized loop

* Lint fix

* Introduce run_once

* Refactor callback function names for consistency and improve code readability

* Enhance documentation

* Lint fix

* Lint fix

* Lint fix

* Lint fix

* Fix formatting issues in rt_mod_hip.cc

* Add random seed initialization for deterministic testing
parent 0d8421f1

@@ -68,9 +68,9 @@ def debug_print_register_files(M=16, N=16):
     @T.prim_func
     def program(Q: T.Buffer((M, N), dtype)):
         with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
-            shared_buf = T.alloc_fragment([M, N], dtype)
+            register_buf = T.alloc_fragment([M, N], dtype)
             for i, j in T.Parallel(M, N):
-                T.print(shared_buf[i, j])
+                T.print(register_buf[i, j])

     jit_kernel = tilelang.JITKernel(program, target="cuda")
     profiler = jit_kernel.get_profiler()
......
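As a quick orientation for the documentation hunk above, the sketch below shows how that example might be exercised end to end: build the debug-print kernel with `tilelang.JITKernel`, obtain its profiler, and launch it once so the device-side `T.print` output is emitted. The shapes and dtype are illustrative assumptions, not values taken from the docs.

```python
# Minimal sketch of running the debug-print example above; shapes/dtype are
# illustrative assumptions, and a CUDA device is required to see T.print output.
import tilelang
import tilelang.language as T

M, N, dtype = 16, 16, "float16"


@T.prim_func
def program(Q: T.Buffer((M, N), dtype)):
    with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
        register_buf = T.alloc_fragment([M, N], dtype)
        for i, j in T.Parallel(M, N):
            T.print(register_buf[i, j])


jit_kernel = tilelang.JITKernel(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()  # launch the kernel once so the prints are flushed
```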
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from tilelang import tvm as tvm
import tilelang.testing
import tilelang
import torch

def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main

def run_gemm(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    stramp = "&*(XS)"

    # Register a postprocessing callback that prepends a marker comment to the
    # generated CUDA source, so we can verify the callback is actually invoked.
    @tvm.register_func
    def tilelang_callback_cuda_postproc(code, _):
        code = f"// {stramp}\n" + code
        return code

    matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")

    kernel_source = matmul_kernel.get_kernel_source()
    assert stramp in kernel_source, f"Expected {stramp} in the kernel source"

def test_gemm_f16f16f16_nn():
    run_gemm(
        512,
        1024,
        768,
        False,
        False,
        "float16",
        "float16",
        "float16",
        128,
        256,
        32,
        2,
    )

def matmu_jit_kernel(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main

def run_gemm_jit_kernel(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = matmu_jit_kernel(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")

    A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
    B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()

    if trans_A:
        A = A.T
    if trans_B:
        B = B.T

    # Compare the compiled kernel against a float32 torch.matmul reference.
    def ref_program(A, B):
        import torch
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    ref_C = ref_program(A, B)
    C = matmul_kernel(A, B)

    tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)

def test_gemm_jit_kernel():
    run_gemm_jit_kernel(
        512,
        1024,
        768,
        False,
        False,
        "float16",
        "float16",
        "float16",
        128,
        256,
        32,
        2,
    )


if __name__ == "__main__":
    tilelang.testing.main()
@@ -6,6 +6,8 @@ import tilelang.testing
 import tilelang as tl
 import tilelang.language as T

+tilelang.testing.set_random_seed(0)
+

 def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M, block_N,
                    block_K, block_Dstate, num_stages, threads):
......
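The `tilelang.testing.set_random_seed(0)` call added above pins the RNG state so that the randomly generated inputs in this test are reproducible across runs. The helper below is only a plausible sketch of what such a function typically does (seeding Python, NumPy, and PyTorch); the actual implementation inside `tilelang.testing` may differ.

```python
# Hypothetical sketch of a set_random_seed helper; the real
# tilelang.testing.set_random_seed may seed more or fewer RNG sources.
import random

import numpy as np
import torch


def set_random_seed(seed: int = 0) -> None:
    random.seed(seed)          # Python's built-in RNG
    np.random.seed(seed)       # NumPy RNG used by some reference checks
    torch.manual_seed(seed)    # CPU RNG behind torch.randn test inputs
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # seed every CUDA device as well
```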
@@ -98,8 +98,8 @@ def compile_hip(code,
     return data


-@tvm._ffi.register_func("tvm_callback_hip_compile", override=True)
-def tvm_callback_hip_compile(code, target):
+@tvm._ffi.register_func("tilelang_callback_hip_compile", override=True)
+def tilelang_callback_hip_compile(code, target):
     """use hipcc to generate fatbin code for better optimization"""
     hsaco = compile_hip(code, target_format="hsaco")
     return hsaco
@@ -188,14 +188,14 @@ def get_cuda_version(cuda_path=None):
     raise RuntimeError("Cannot read cuda version file")


-@tvm._ffi.register_func("tvm_callback_cuda_compile", override=True)
-def tvm_callback_cuda_compile(code, target):  # pylint: disable=unused-argument
+@tvm._ffi.register_func("tilelang_callback_cuda_compile", override=True)
+def tilelang_callback_cuda_compile(code, target):  # pylint: disable=unused-argument
     """use nvcc to generate fatbin code for better optimization"""
     ptx = compile_cuda(code, target_format="fatbin")
     return ptx


-@tvm._ffi.register_func("tvm_callback_libdevice_path", override=True)
+@tvm._ffi.register_func("tilelang_callback_libdevice_path", override=True)
 def find_libdevice_path(arch):
     """Utility function to find libdevice
......
@@ -31,8 +31,8 @@ def is_host_call(func: tir.PrimFunc):
     return not is_device_call(func)


-@tvm.register_func("tvm_callback_cuda_compile", override=True)
-def tvm_callback_cuda_compile(code, target):
+@tvm.register_func("tilelang_callback_cuda_compile", override=True)
+def tilelang_callback_cuda_compile(code, target):
     project_root = osp.join(osp.dirname(__file__), "../..")
     if "TL_TEMPLATE_PATH" in os.environ:
         tl_template_path = os.environ["TL_TEMPLATE_PATH"]
@@ -73,8 +73,8 @@ def tvm_callback_cuda_compile(code, target):
     return ptx


-@tvm.register_func("tvm_callback_hip_compile", override=True)
-def tvm_callback_hip_compile(code, target):
+@tvm.register_func("tilelang_callback_hip_compile", override=True)
+def tilelang_callback_hip_compile(code, target):
     project_root = osp.join(osp.dirname(__file__), "../..")
     tl_template_path = osp.abspath(osp.join(project_root, "src"))
......
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from typing import List, Union, Any, Callable, Literal
+from typing import List, Union, Any, Callable, Literal, Optional
 from tvm.target import Target
 import tilelang
 from tilelang import tvm as tvm
@@ -194,3 +194,6 @@ class JITKernel(object):
             The source code of the compiled kernel function.
         """
         return self.rt_module.imported_modules[0].get_source()
+
+    def run_once(self, func: Optional[Callable] = None) -> None:
+        return self.get_profiler().run_once(func)
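The `JITKernel.run_once` method added above is a thin wrapper that forwards to the profiler's `run_once`, so a compiled kernel can be launched a single time (for example to flush `T.print` output) without constructing the profiler by hand. A hedged usage sketch, assuming `program` is a TileLang prim_func such as the matmul defined in the test file earlier:

```python
# Sketch only: `program` is assumed to be a TileLang prim_func (e.g. the matmul
# above), and a CUDA device is assumed to be available.
import tilelang

jit_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")

# Equivalent to jit_kernel.get_profiler().run_once(); the profiler supplies the
# inputs, and an optional callable can be passed to run in place of the kernel.
jit_kernel.run_once()
```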
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 """The profiler and convert to torch utils"""
-from typing import List, Literal
+from typing import List, Literal, Optional, Callable
 from functools import partial
 import torch
 from contextlib import suppress
@@ -41,7 +41,7 @@ class Profiler(TorchDLPackKernelAdapter):

     def assert_allclose(
         self,
-        reference_program: callable,
+        reference_program: Callable,
         atol: float = 1e-2,
         rtol: float = 1e-2,
         max_mismatched_ratio=0.01,
@@ -87,11 +87,7 @@ class Profiler(TorchDLPackKernelAdapter):
             rhs,
         ]

-    def run_once(self, func=None):
-        import ctypes
-
-        libcuda = ctypes.CDLL("libcuda.so")  # noqa: F841
-
+    def run_once(self, func: Optional[Callable] = None):
         ins = self._get_inputs()
         if not func:
             func = self.__call__
@@ -99,11 +95,11 @@ class Profiler(TorchDLPackKernelAdapter):

     def do_bench(
         self,
-        func: callable = None,
-        warmup=25,
-        rep=100,
-        n_warmup=1,
-        n_repeat=1,
+        func: Optional[Callable] = None,
+        warmup: int = 25,
+        rep: int = 100,
+        n_warmup: int = 1,
+        n_repeat: int = 1,
         profiler: Literal["torch", "tvm", "auto"] = "auto",
         input_tensors: List[torch.Tensor] = None,
     ):
......
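With the annotations above, `Profiler.assert_allclose` and `do_bench` now spell out their callable and integer parameters. A hedged usage sketch follows, assuming a `jit_kernel` and a torch reference `ref_program` like those in the test file earlier, and that `do_bench` returns the measured latency:

```python
# Sketch only: jit_kernel and ref_program are assumed to exist as in the test
# file above; do_bench is assumed to return the measured latency as a float.
profiler = jit_kernel.get_profiler()

# Numerical check against the reference, using the signature shown in the diff.
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.01)

# Time the compiled kernel with the default warmup/rep budget, then the reference.
kernel_latency = profiler.do_bench(warmup=25, rep=100)
ref_latency = profiler.do_bench(func=ref_program, warmup=25, rep=100)
print(f"kernel: {kernel_latency:.3f}, reference: {ref_latency:.3f}")
```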