Commit 7bde63d5 authored by Lei Wang, committed by GitHub

[Feat] Introduce new caching mechanism for compiled kernels (#176)

* Add kernel caching mechanism to TileLang

- Implement a new `cached` function in `tilelang/cache/__init__.py` to cache and reuse compiled kernels (see the usage sketch after this list)
- Expose the `cached` function in the main `tilelang/__init__.py`
- Add a test case for cached matrix multiplication in `testing/python/cache/test_tilelang_cache_matmul.py`
- Provide a `clear_cache()` function to reset the kernel cache when needed
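
A minimal usage sketch of the mechanism; `matmul` stands for the kernel generator defined in the new test file below, and the argument values mirror its `test_cache_matmul_f16f16f16_nn` case:

```python
import tilelang
from tilelang import cached

# `matmul` is the kernel generator added in the new test file below.
# The first call builds the PrimFunc via matmul(*args) and compiles it;
# an identical second call is served from the in-memory cache.
kernel = cached(matmul, [2],          # out_idx=[2]: return buffer C
                512, 1024, 768,       # M, N, K
                128, 256, 32,         # block_M, block_N, block_K
                False, False,         # trans_A, trans_B
                "float16", "float16", "float16",  # in/out/accum dtypes
                2, 128)               # num_stages, threads
assert kernel is cached(matmul, [2], 512, 1024, 768, 128, 256, 32,
                        False, False, "float16", "float16", "float16", 2, 128)

tilelang.cache.clear_cache()          # reset the cache to force recompilation
```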

* Refactor kernel caching test and implementation

- Simplify the `cached` function in `tilelang/cache/__init__.py`
- Update test script `test_tilelang_cache_matmul.py` to use `tilelang.testing.main()`
- Remove unnecessary whitespace and improve code formatting

* Update import for `cached` function in MHA examples

- Modify import statement in `example_mha_bwd.py` and `test_tilelang_kernel_mha_bwd.py`
- Change `from tilelang.profiler import cached` to `from tilelang import cached`
- Align with recent refactoring of kernel caching mechanism

* Refactor `cached` function signature in kernel caching

- Update function signature to use keyword-only arguments for `target` and `target_host`
- Improve parameter order and readability of the `cached` function (call pattern sketched below)
- Maintain existing functionality while enhancing function definition
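
The refactored call pattern, as a sketch with the same illustrative arguments as above; `"auto"` is the default target:

```python
# target and target_host are keyword-only; every positional argument after
# out_idx is forwarded to the kernel generator.
kernel = cached(matmul, [2],
                512, 1024, 768, 128, 256, 32,
                False, False, "float16", "float16", "float16", 2, 128,
                target="auto", target_host=None)
```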
parent fb6b101c
example_mha_bwd.py
@@ -4,7 +4,7 @@
 import torch
 import torch.nn.functional as F
 import tilelang
-from tilelang.profiler import cached
+from tilelang import cached
 from tilelang.autotuner import *
 import tilelang.language as T
 import argparse
testing/python/cache/test_tilelang_cache_matmul.py (new file)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from tilelang import tvm as tvm
import tilelang.testing
from tilelang import cached


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_cache_matmul(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    kernel = cached(program, [2])
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_cache_matmul_f16f16f16_nn():
    run_cache_matmul(
        512,
        1024,
        768,
        False,
        False,
        "float16",
        "float16",
        "float16",
        128,
        256,
        32,
        2,
    )


if __name__ == "__main__":
    tilelang.testing.main()
test_tilelang_kernel_mha_bwd.py
@@ -4,7 +4,7 @@
 import torch
 import torch.nn.functional as F
 import tilelang
-from tilelang.profiler import cached
+from tilelang import cached
 import tilelang.language as T
 import tilelang.testing
tilelang/__init__.py
@@ -107,6 +107,7 @@ if SKIP_LOADING_TILELANG_SO == "0":
 from .jit import jit, JITKernel, compile  # noqa: F401
 from .profiler import Profiler  # noqa: F401
+from .cache import cached  # noqa: F401
 from .utils import (
     TensorSupplyType,  # noqa: F401
"""The cache utils"""
from tilelang import compile
from tilelang.jit import JITKernel
from typing import Callable, List, Union
from tvm.target import Target
from tvm.tir import PrimFunc
# Dictionary to store cached kernels
_cached = {}
def cached(
func: Callable,
out_idx: List[int] = None,
*args,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
) -> JITKernel:
"""
Cache and reuse compiled kernels to avoid redundant compilation.
Args:
func: Function to be compiled or a PrimFunc that's already prepared
out_idx: Indices specifying which outputs to return
target: Compilation target platform
target_host: Host target for compilation
*args: Arguments passed to func when calling it
Returns:
JITKernel: The compiled kernel, either freshly compiled or from cache
"""
global _cached
# Create a unique key based on the function, output indices and arguments
key = (func, tuple(out_idx), *args)
# Return cached kernel if available
if key not in _cached:
# Handle both PrimFunc objects and callable functions
program = func if isinstance(func, PrimFunc) else func(*args)
# Compile the program to a kernel
kernel = compile(program, out_idx=out_idx, target=target, target_host=target_host)
# Store in cache for future use
_cached[key] = kernel
return _cached[key]
def clear_cache():
"""
Clear the entire kernel cache.
This function resets the internal cache dictionary that stores compiled kernels.
Use this when you want to free memory or ensure fresh compilation
of kernels in a new context.
"""
global _cached
_cached = {}
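
Both call paths the implementation supports, sketched with the `matmul` generator from the new test file above:

```python
from tilelang import cached

# Path 1: pass the generator plus its positional args; cached() calls it once.
k1 = cached(matmul, [2], 512, 1024, 768, 128, 256, 32,
            False, False, "float16", "float16", "float16", 2, 128)

# Path 2: pass an already-built PrimFunc, as the new test does; *args is empty.
program = matmul(512, 1024, 768, 128, 256, 32,
                 False, False, "float16", "float16", "float16", 2, 128)
k2 = cached(program, [2])
```

Note that the cache key is built from `func`, `out_idx`, and `*args` only; `target` and `target_host` do not participate, so recompiling the same function for a different target requires `clear_cache()` first.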
tilelang/profiler/__init__.py
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 """The profiler and convert to torch utils"""
 from typing import List, Literal, Optional, Callable
@@ -9,8 +7,6 @@ from contextlib import suppress
 import tvm
 from tvm.relay import TensorType
-from tilelang.engine import lower
 from tilelang.jit.adapter import TorchDLPackKernelAdapter
 from tilelang.utils.tensor import (
     get_tensor_supply,
@@ -240,17 +236,3 @@ def do_bench(
             ret = ret[0]
         return ret
     return getattr(torch, return_mode)(times).item()
-
-
-_cached = {}
-
-
-def cached(func, result_idx: List[int], *args):
-    global _cached
-    key = (func, tuple(result_idx), *args)
-    if key not in _cached:
-        program = func(*args)
-        mod, params = lower(program)
-        mod = TorchDLPackKernelAdapter(mod, params, result_idx)
-        _cached[key] = mod
-    return _cached[key]