Unverified Commit 0921328d authored by Kuris, committed by GitHub

[Language] Tilelang LazyJIT Experimental Version (#1337)



* initial step

* modify builder

* scratch version of new frontend

* write some tests

* add many tests

* add typing stub for tir.ir

* remove idents

* minor update

* minor update

* First version of jitv2 (renamed to LazyJIT)

* fix pre-commit error

* minor fix

* fix lint error

* fix lint error

* Fix conditional check for PrimFunc instance

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 8d019eb9
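As a quick orientation before the diff, here is a minimal usage sketch of the new decorator, modeled on the copy tests added in this PR (shapes and the function name are illustrative):

```python
import torch
import tilelang
import tilelang.language as T

# A lazy_jit function is called with real tensors; the kernel is elaborated and
# compiled on the first call, cached, and the `T.empty` result is returned.
@tilelang.lazy_jit
def copy(A: T.Tensor[[int, int], T.float32]):
    M, N = A.shape
    B = T.empty(M, N, dtype=A.dtype)
    with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
        T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128],
               B[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128])
    return B

A = torch.randn(256, 256, device='cuda')
B = copy(A)  # first call compiles; later calls with the same signature reuse the cache
```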
@@ -35,11 +35,7 @@ repos:
    rev: v21.1.6  # sync with requirements-lint.txt
    hooks:
      - id: clang-format
-       exclude: |
-         (?ix)(
-           ^.+\.(cu|cuh)$|
-           ^.+\.json$
-         )
+       types_or: [c++, c]
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.14.7  # sync with requirements-lint.txt
    hooks:
@@ -66,4 +62,4 @@ repos:
          ^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$|
          ^.+\.svg$|
          ^.*\brequirements\b.*\.txt$
-         )
\ No newline at end of file
+         )
This diff is collapsed.
This diff is collapsed.
@@ -252,9 +252,9 @@ def test_marco_return():
    c = macro_return_expr(4.0)
    d = macro_apply_func(5.0, lambda x: x * 2.0)
    check(a, (int, float, T.PrimExpr))
-   check(b, T.PrimExpr)
-   check(c, T.PrimExpr)
-   check(d, T.PrimExpr)
+   check(b, (int, float, T.PrimExpr))
+   check(c, (int, float, T.PrimExpr))
+   check(d, (int, float, T.PrimExpr))
def test_prim_func_generator():
...
from dataclasses import dataclass, field
import tilelang.testing
import tilelang
import tilelang.language as T
from typing import Any
from itertools import product
import torch
def _gemm_impl():
@T.macro
def gemm_impl(
A: T.Tensor[[int, int], Any],
B: T.Tensor[[int, int], Any],
C: T.Tensor[[int, int], Any],
out_dtype: T.dtype,
block_M: int,
block_N: int,
block_K: int,
):
dtype = A.dtype
M, K = A.shape
K, N = B.shape
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), out_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[bx * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, by * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[bx * block_M, by * block_N])
return gemm_impl
def test_jit2_gemm_annot():
@tilelang.lazy_jit
def gemm(
A: T.Tensor[[int, int], Any],
B: T.Tensor[[int, int], Any],
out_dtype: T.dtype = T.float32,
block_M: int = 64,
block_N: int = 64,
block_K: int = 32,
):
M, K = A.shape
K, N = B.shape
C = T.empty(M, N, dtype=out_dtype)
_gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K)
return C
prod = product([T.float16, T.float32], [T.float32])
gemm.par_compile([{
'A': T.Tensor((1024, 1024), dtype=in_dtype),
'B': T.Tensor((1024, 1024), dtype=in_dtype),
'out_dtype': out_dtype
} for in_dtype, out_dtype in prod])
for in_dtype, out_dtype in prod:
in_dtype = in_dtype.torch()
out_dtype = out_dtype.torch()
A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
C_ref = out_dtype(A @ B)
C = gemm(A, B)
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)
def test_jit2_gemm_ptr():
@tilelang.lazy_jit
def gemm_ptr(
A: T.ptr,
B: T.ptr,
C: T.ptr,
M: int,
N: int,
K: int,
dtype: T.dtype,
out_dtype: T.dtype,
block_M: int = 64,
block_N: int = 64,
block_K: int = 32,
):
A = T.make_tensor(A, (M, K), dtype)
B = T.make_tensor(B, (K, N), dtype)
C = T.make_tensor(C, (M, N), out_dtype)
_gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K)
prod = product([T.float16, T.float32], [T.float32])
gemm_ptr.par_compile([{
'A': T.ptr(),
'B': T.ptr(),
'C': T.ptr(),
'M': 1024,
'N': 1024,
'K': 1024,
'dtype': in_dtype,
'out_dtype': out_dtype
} for in_dtype, out_dtype in prod])
for in_dtype, out_dtype in prod:
in_dtype = in_dtype.torch()
out_dtype = out_dtype.torch()
A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
C_ref = out_dtype(A @ B)
C = torch.empty(1024, 1024, dtype=out_dtype, device='cuda')
gemm_ptr(A, B, C, 1024, 1024, 1024, in_dtype, out_dtype)
torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2)
def test_jit2_annot():
from tilelang.language.v2.annot import Annot, ArgVarTable
from tilelang.language.v2.builder import Builder
import traceback
@dataclass
class AnnotTest:
annot: Annot
promote: Any
match_ok: list[Any] = field(default_factory=list)
match_ng: list[Any] = field(default_factory=list)
tests = [
AnnotTest(
annot=T.Tensor[[int, int], T.float32],
promote=False,
match_ok=[torch.randn(1, 1, dtype=torch.float32),
T.Tensor((1, 1), dtype=T.float32)],
match_ng=[
torch.randn(1, 1, dtype=torch.float16),
T.Tensor(1, dtype=T.float32),
T.Tensor((1, 1), dtype=T.float16),
],
),
AnnotTest(
annot=T.Tensor[[int], Any],
promote=False,
match_ok=[
torch.randn(12, dtype=torch.float32),
torch.randn(12, dtype=torch.float16),
T.Tensor((1,), dtype=T.float32),
T.Tensor((1,), dtype=T.float16),
],
match_ng=[torch.randn((1, 1), dtype=torch.float32),
T.Tensor((1, 1), dtype=T.float16)]),
AnnotTest(
annot=T.Tensor[[int, 1], Any],
promote=False,
match_ok=[
torch.randn(12, 1, dtype=torch.float32),
torch.randn(12, 1, dtype=torch.float16),
T.Tensor((12, 1), T.float32),
T.Tensor((12, 1), T.float16),
],
match_ng=[torch.randn(12, 12, dtype=torch.float32),
T.Tensor((12, 12), T.float32)]),
AnnotTest(
annot=T.Tensor[[T.dyn, 1], Any],
promote=False,
match_ok=[
torch.randn(12, 1, dtype=torch.float32),
torch.randn(12, 1, dtype=torch.float16),
T.Tensor((12, 1), T.float32),
T.Tensor((12, 1), T.float16),
],
match_ng=[torch.randn(12, 12, dtype=torch.float32),
T.Tensor((12, 12), T.float32)]),
AnnotTest(
annot=T.Tensor[[1024, 1024], T.float32],
promote=True,
),
AnnotTest(annot=T.dyn[int, 'X'], promote=False, match_ok=[1, 2, 3, 4]),
AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4])
]
for test in tests:
promote = test.annot.promote()
promoted = promote is not None
if promoted != test.promote:
raise AssertionError(
f'Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}')
with Builder().prim_func('_test'):
for match_ok in test.match_ok:
try:
vt = ArgVarTable()
test.annot.create_prim_func_arg('arg', match_ok, vt)
except Exception as e:
traceback.print_exc()
raise AssertionError(
f'Match failed for {test.annot} with value {match_ok}: {e}') from e
for match_ng in test.match_ng:
try:
vt = ArgVarTable()
test.annot.create_prim_func_arg('arg', match_ng, vt)
raise AssertionError(
f'Match unexpectedly succeeded for {test.annot} with value {match_ng}')
except Exception:
pass
def test_jit2_many_annot():
@T.macro
def copy_impl(A, B):
M, N = A.shape
M_, N_ = B.shape
assert M == M_, f"M mismatch {M} {M_}"
assert N == N_, f"N mismatch {N} {N_}"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128,
by * 128:by * 128 + 128])
@tilelang.lazy_jit
def copy1(
A: T.Tensor[[int, int], T.float32],
B: T.Tensor[[int, int], T.float32],
):
copy_impl(A, B)
@tilelang.lazy_jit
def copy2(
A: T.Tensor[[128, 128], T.float32],
B: T.Tensor[[128, 128], T.float32],
):
copy_impl(A, B)
@tilelang.lazy_jit
def copy3(
A: T.Tensor[[int, 128], T.float32],
B: T.Tensor[[int, 128], T.float32],
):
copy_impl(A, B)
@tilelang.lazy_jit
def copy4(
A: T.Tensor[[T.dyn, int], T.float32],
B: T.Tensor[[T.dyn, int], T.float32],
):
copy_impl(A, B)
@tilelang.lazy_jit
def copy5(
A: T.StridedTensor[[int, int], [int, int], T.float32],
B: T.StridedTensor[[int, int], [int, int], T.float32],
):
copy_impl(A, B)
@tilelang.lazy_jit
def copy6(
A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],
B: T.StridedTensor[[T.dyn, int], [int, int], T.float32],
):
copy_impl(A, B)
for copy in [copy1, copy2, copy3, copy4]:
A = torch.randn(128, 128, device='cuda')
B = torch.empty(128, 128, device='cuda')
copy(A, B)
assert torch.equal(B, A)
for copy in [copy5, copy6]:
A = torch.randn(128, 2, 128, 2, device='cuda')
B = torch.randn(128, 2, 128, 2, device='cuda')
copy(A[:, 0, :, 0], B[:, 0, :, 0])
assert torch.equal(A[:, 0, :, 0], B[:, 0, :, 0])
def test_jit2_return():
@T.macro
def copy_impl(A):
M, N = A.shape
B = T.empty(M, N, dtype=A.dtype)
M, N = A.shape
M_, N_ = B.shape
assert M == M_, f"M mismatch {M} {M_}"
assert N == N_, f"N mismatch {N} {N_}"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128,
by * 128:by * 128 + 128])
return B
@tilelang.lazy_jit
def copy0(A: T.Tensor[[int, int], Any]):
return copy_impl(A)
@tilelang.lazy_jit
def copy1(A: T.Tensor[[int, int], T.float32],):
return copy_impl(A)
@tilelang.lazy_jit
def copy2(A: T.Tensor[[128, 128], T.float32],):
return copy_impl(A)
@tilelang.lazy_jit
def copy3(A: T.Tensor[[int, 128], T.float32],):
return copy_impl(A)
@tilelang.lazy_jit
def copy4(A: T.Tensor[[T.dyn, int], T.float32],):
return copy_impl(A)
@tilelang.lazy_jit
def copy5(A: T.StridedTensor[[int, int], [int, int], T.float32],):
return copy_impl(A)
@tilelang.lazy_jit
def copy6(A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],):
return copy_impl(A)
for copy in [copy0, copy1, copy2, copy3, copy4]:
A = torch.randn(128, 128, device='cuda')
B = copy(A)
assert torch.equal(B, A)
for copy in [copy5, copy6]:
A = torch.randn(128, 2, 128, 2, device='cuda')
B = copy(A[:, 0, :, 0])
assert torch.equal(A[:, 0, :, 0], B)
def test_jit2_deepseek_deepgemm():
@tilelang.lazy_jit
def deep_gemm(
A: T.Tensor[[int, int], T.float8_e4m3],
B: T.Tensor[[int, int], T.float8_e4m3],
scales_a: T.Tensor[[int, int], T.float32],
scales_b: T.Tensor[[int, int], T.float32],
out_dtype: T.dtype = T.bfloat16,
accum_dtype: T.dtype = T.float32,
block_N: int = 128,
block_M: int = 128,
block_K: int = 128,
):
# A: [M, K]
# B: [N, K]
# scales_a: [M, K // 128]
# scales_b: [N, K // 128]
# C: [M, N]
group_size = 128
in_dtype = A.dtype
M, K = A.shape
N, K = B.shape
C = T.empty(M, N, dtype=out_dtype)
assert out_dtype in [
T.bfloat16, T.float32
], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}"
assert scales_a.shape == [M, T.ceildiv(K, group_size)
], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}"
assert scales_b.shape == [N, T.ceildiv(K, group_size)
], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}"
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), in_dtype)
B_shared = T.alloc_shared((block_N, block_K), in_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
scale_C_shared = T.alloc_shared((block_M,), T.float32)
C_local = T.alloc_fragment((block_M, block_K), accum_dtype)
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
T.use_swizzle(panel_size=10)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=4):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
Scale_B = scales_b[bx * block_N // group_size, k]
for i in T.Parallel(block_M):
scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * scale_C_shared[i]
T.clear(C_local)
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return C
# def ceildiv(a, b):
# return (a + b - 1) // b
# def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
# # A_scale: (M, K//128) ==> (M//128, K//128, 128)
# # B_scale: (N//128, K//128) ==> (N//128, K//128, 128)
# # A_fp8: (M, K)
# # B_fp8: (N, K)
# # out_dtype: float16 or float32
# # return C: (M, N)
# M, N, K = A_fp8.shape[0], B_fp8.shape[0], A_fp8.shape[1]
# A_scales = A_scale.view(M // 128, 128, K // 128).permute(0, 2, 1)
# B_scales = B_scale.repeat_interleave(128, dim=1).view(N // 128, K // 128, 128)
# C = torch.zeros(M, N, device="cuda", dtype=out_dtype)
# c_acc = torch.zeros(128, 128, device="cuda", dtype=torch.float32)
# for i in range(ceildiv(M, 128)):
# for j in range(ceildiv(N, 128)):
# c_acc.zero_()
# for k in range(ceildiv(K, 128)):
# c = torch._scaled_mm(
# A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128],
# B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T,
# scale_a=A_scales[i, k].view(128, 1).contiguous(),
# scale_b=B_scales[j, k].view(1, 128).contiguous(),
# out_dtype=torch.bfloat16)
# c_acc += c.to(torch.float32)
# C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype)
# return C
# M, N, K = 1024, 1024, 8192
# A = torch.randn((M, K), dtype=torch.float8_e4m3fn, )
if __name__ == '__main__':
tilelang.testing.main()
@@ -120,7 +120,7 @@ def _load_tile_lang_lib():
if env.SKIP_LOADING_TILELANG_SO == "0":
    _LIB, _LIB_PATH = _load_tile_lang_lib()
-from .jit import jit, JITKernel, compile  # noqa: F401
+from .jit import jit, lazy_jit, JITKernel, compile, par_compile  # noqa: F401
from .profiler import Profiler  # noqa: F401
from .cache import clear_cache  # noqa: F401
...
@@ -141,6 +141,10 @@ def extrac_params(func: tir.PrimFunc) -> list[KernelParam]:
        if var in func.buffer_map:
            tensor_types.append(KernelParam.from_buffer(func.buffer_map[var]))
        else:
+           if var.dtype == 'handle':
+               raise ValueError(
+                   f'Handle parameter {var} must be mapped to a buffer.\n'
+                   f'Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer.')
            tensor_types.append(KernelParam.from_var(var))
    return tensor_types
...
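The new check rejects raw handle parameters that are never bound to a buffer. A minimal sketch of the intended usage, modeled on test_jit2_gemm_ptr above (the kernel body and names are illustrative; the tests bind handles with T.make_tensor):

```python
@tilelang.lazy_jit
def copy_ptr(src: T.ptr, dst: T.ptr, M: int, N: int, dtype: T.dtype):
    # Bind the raw handles to buffers first; an unbound handle parameter now
    # triggers the ValueError added in extrac_params above.
    A = T.make_tensor(src, (M, N), dtype)
    B = T.make_tensor(dst, (M, N), dtype)
    with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
        T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128],
               B[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128])
```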
@@ -16,13 +16,15 @@ from typing import (
    Literal,
)
from collections.abc import Iterable
# Python 3.9 compatibility for ParamSpec
try:
    from typing import ParamSpec
except ImportError:  # Python < 3.10
    from typing_extensions import ParamSpec
from tilelang import tvm as tvm
-from tilelang.language.v2 import PrimFunc
+from tilelang.language.v2 import PrimFunc, PrimFuncCreater, prim_func
+from tilelang.language.v2.annot import Annot
from tvm.target import Target
from tilelang.jit.kernel import JITKernel
@@ -40,6 +42,7 @@ logger = getLogger(__name__)
_P = ParamSpec('_P')
_KP = ParamSpec('_KP')
_T = TypeVar('_T')
+_Ret = TypeVar('_Ret')

def compile(
@@ -74,10 +77,19 @@ def compile(
        Additional keyword arguments to pass to the Compiler PassContext.
        Refer to `tilelang.transform.PassConfigKey` for supported options.
    """
    assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
    if isinstance(compile_flags, str):
        compile_flags = [compile_flags]
+   if hasattr(func, 'out_idx_override'):
+       if func.out_idx_override is not None and out_idx is not None:
+           raise ValueError(
+               "Out index conflict: out_idx is specified and the prim_func already returns `T.empty` tensors"
+           )
+       out_idx = func.out_idx_override or out_idx
    # This path is not a performance critical path, so we can afford to convert the target.
    target = Target(determine_target(target))
@@ -176,8 +188,76 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
@dataclass
-class JITImpl(Generic[_P, _KP, _T]):
-    func: Callable[_P, _T] | PrimFunc[_KP, _T]
+class JITImpl(Generic[_P, _KP, _T, _Ret]):
+    '''
Detailed Just-In-Time wrapper for TileLang programs.
This dataclass encapsulates the configuration and runtime helpers used by the
top-level `jit` and `lazy_jit` decorators. It represents a configured JIT
"factory" that can (a) elaborate TileLang/PrimFunc creators into concrete
TIR (PrimFunc), (b) compile those TIR functions into runnable kernels via
the TVM bridge, (c) cache compiled kernels keyed by call-site arguments
(and optional tuning parameters), and (d) provide parallel compilation
helpers for batch autotuning workflows.
Attributes
----------
out_idx : list[int] | int | None
Which output tensor(s) of the compiled kernel should be returned to the
caller. Accepts a single index, a list of indices, or None to return all.
execution_backend : Literal["dlpack", "ctypes", "cython"]
Backend used for exchanging arguments and executing the generated kernel.
target : str | tvm.target.Target
TVM compilation target (e.g. "cuda", "llvm", or "auto").
target_host : str | tvm.target.Target | None
Host target used for cross-compilation, or None to infer/default.
verbose : bool
Enable verbose messages during compilation/build.
pass_configs : dict[str, Any] | None
Extra TVM pass configuration options forwarded to the compiler's
PassContext.
debug_root_path : str | None
If provided, compiled kernel source and the elaborated Python program
are written to this directory to ease debugging and inspection.
compile_flags : list[str] | str | None
Additional flags passed to the compiler. A single string will be converted
to a single-element list.
func_source : str
Original Python source string from which the PrimFunc or creator was
derived. Used for diagnostics and debug dumps.
signature : inspect.Signature
Function signature of the original Python function (useful for tooling).
lazy_jit : bool
Indicates whether the object wraps a "v2" PrimFunc creator (True) or a
plain callable / PrimFunc (False). Lazy-JIT mode enables argument conversion
hooks and a distinct cache keying strategy.
func : Callable | PrimFunc | PrimFuncCreater
The underlying object: either a user function that returns a PrimFunc
(creator), a PrimFuncCreater, or an already-constructed PrimFunc.
For presentation/readability the function is stored last in the dataclass.
Behavioral summary
------------------
- get_tir(*args, **kwargs)
Converts provided call-site arguments into a concrete PrimFunc. If the
wrapped object is a PrimFuncCreater or a user callable, it is invoked
with the given arguments. If the wrapped object is already a PrimFunc,
it is returned as-is.
- compile(...)
A convenience wrapper that elaborates and immediately compiles a single
PrimFunc into a JITKernel using the module-level `compile` function.
When `debug_root_path` is set, the compiled C kernel and the source
Python program are saved for inspection.
- par_compile(configs, ...)
Accepts an iterable of configs (either dicts mapping keyword args or
tuples mapping to positional args). Each config is elaborated to a
PrimFunc and the resulting set is compiled in parallel via the
module-level `par_compile` helper. Returns a list of JITKernel objects
in the same order as the provided configs.
'''
    out_idx: list[int] | int | None
    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
    target: str | Target
@@ -188,6 +268,14 @@ class JITImpl(Generic[_P, _KP, _T]):
    compile_flags: list[str] | str | None
    func_source: str
    signature: inspect.Signature
+   lazy_jit: bool
+   # place func at the last element for better __repr__
+   func: Callable[_P, _T] | PrimFunc[_KP, _T]
@property
def annot(self) -> dict[str, Annot]:
assert self.lazy_jit, "annot is only supported in @tilelang.lazy_jit"
return self.func.func_annot.annots
    def __post_init__(self):
        if self.debug_root_path is not None and not path.isabs(self.debug_root_path):
@@ -197,21 +285,47 @@ class JITImpl(Generic[_P, _KP, _T]):
        except NameError:
            self.debug_root_path = path.abspath(self.debug_root_path)
        self._kernel_cache: dict[tuple, Kernel] = {}
+       self._tuner_cache: dict[tuple, Kernel] = {}
    def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_KP, _T]:
-       program_result_source = self.func
-       if isinstance(program_result_source, PrimFunc):
-           program_result = program_result_source
-       elif callable(program_result_source):
-           program_result = program_result_source(*args, **kwargs)
+       """
+       Retrieve a TIR (Tensor Intermediate Representation) PrimFunc from the stored callable or object.
+       """
+       if isinstance(self.func, PrimFuncCreater):
+           tir = self.func(*args, **kwargs)
+       elif isinstance(self.func, PrimFunc):
+           tir = self.func
+       elif callable(self.func):
+           tir = self.func(*args, **kwargs)
        else:
-           raise ValueError(f"Invalid function type: {type(program_result_source)}")
-       return program_result
+           raise ValueError(f"Invalid function type: {type(self.func)}")
+       assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}"
+       return tir
    def par_compile(self,
                    configs: Iterable[dict[str, Any] | tuple[str, Any]],
                    num_workers: int = None,
                    ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
"""
Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
Parameters
----------
configs : Iterable[Union[dict[str, Any], tuple[Any, ...]]]
The configurations to elaborate and compile. Each config can be either
a dictionary mapping keyword arguments to values, or a tuple of positional
arguments.
num_workers : int, optional
Number of parallel workers to use for compilation. Defaults to None,
which lets the system decide.
ignore_error : bool, optional
If True, compilation errors for individual configs will be logged
as warnings and the corresponding result will be None. If False,
any compilation error will raise an exception. Defaults to False.
Returns
-------
List[JITKernel]
A list of compiled JITKernel objects corresponding to the provided configs.
"""
        configs = list(configs)
        funcs = []
        for cfg in tqdm(configs, desc='Elaborating'):
@@ -233,7 +347,7 @@ class JITImpl(Generic[_P, _KP, _T]):
            num_workers=num_workers,
            ignore_error=ignore_error)
-   def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]:
+   def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
        func = self.get_tir(*args, **kwargs)
        kernel_result = compile(
            func,
@@ -261,12 +375,34 @@ class JITImpl(Generic[_P, _KP, _T]):
        return kernel_result
-   def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]:
+   def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs):
+       if isinstance(self.func, PrimFuncCreater):
+           tune_params = kwargs.pop('__tune_params', {})
+           return self.func.func_annot.parse_key(*args, **kwargs, **tune_params)
+       else:
+           tune_params = kwargs.pop('__tune_params', {})
+           key_args_tuple = args
+           key_kwargs_tuple = tuple(sorted(kwargs.items()))
+           tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
+           key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple)
+           return key
+
+   def convert_kernel_args(self, *args: _P.args, **kwargs: _P.kwargs):
+       if isinstance(self.func, PrimFuncCreater):
+           tune_params = kwargs.pop('__tune_params', {})
+           return self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params)
+       else:
+           raise NotImplementedError(
+               "convert_arg_to_kernel_args is only implemented for PrimFuncCreater.")
+
+   def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
        # Separate out the tuning parameters from the user's kwargs
-       tune_params = kwargs.pop('__tune_params', {})
        # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
        return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
        if return_compile_arguments:
+           logger.warning(
+               "`__return_compile_arguments` is deprecated and will be removed in future versions."
+           )
            compile_args = {
                'out_idx': self.out_idx,
                'execution_backend': self.execution_backend,
@@ -278,19 +414,27 @@ class JITImpl(Generic[_P, _KP, _T]):
            }
            return compile_args
-       key_args_tuple = args
-       key_kwargs_tuple = tuple(sorted(kwargs.items()))
-       tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
-       key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple)
-       if key not in self._kernel_cache:
-           self._kernel_cache[key] = self.compile(*args, **kwargs, **tune_params)
-       return self._kernel_cache[key]
+       key = self.parse_cache_key(*args, **kwargs)
+       tune_params = kwargs.pop('__tune_params', {})
+       kernel = self._kernel_cache.get(key, None)
+       if kernel is None:
+           kernel = self.compile(*args, **kwargs, **tune_params)
+           self._kernel_cache[key] = kernel
+       if self.lazy_jit:
+           args = self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params)
+           return kernel(*args)
+       else:
+           return kernel
+
+ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
@overload
-def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T]:
+def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]:
    ...
@@ -300,13 +444,12 @@ def jit(
    out_idx: Any = None,
    target: str | Target = "auto",
    target_host: str | Target = None,
-   execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
-                              "torch"] = "auto",
+   execution_backend: ExecutionBackend = "auto",
    verbose: bool = False,
    pass_configs: dict[str, Any] | None = None,
    debug_root_path: str | None = None,
    compile_flags: list[str] | str | None = None
-) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T]]:
+) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]:
    ...
@@ -316,8 +459,7 @@ def jit(  # This is the new public interface
    out_idx: Any = None,
    target: str | Target = "auto",
    target_host: str | Target = None,
-   execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
-                              "torch"] = "auto",
+   execution_backend: ExecutionBackend = "auto",
    verbose: bool = False,
    pass_configs: dict[str, Any] | None = None,
    debug_root_path: str | None = None,
@@ -358,12 +500,12 @@ def jit(  # This is the new public interface
        compile_flags = [compile_flags]

    def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]:
-       if isinstance(func, PrimFunc):
+       if isinstance(func, (PrimFunc, PrimFuncCreater)):
            orig_func = func.orig_func
        else:
            orig_func = func
        return JITImpl(
-           func,
+           func=func,
            out_idx=out_idx,
            execution_backend=execution_backend,
            target=target,
@@ -374,9 +516,70 @@ def jit(  # This is the new public interface
            compile_flags=compile_flags,
            func_source=inspect.getsource(orig_func),
            signature=inspect.signature(orig_func),
-       )
+           lazy_jit=False)

    if func is not None:
        return decorator(func)
    else:
        return decorator
@overload
def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]:
...
@overload
def lazy_jit(
*,
out_idx: Any = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: ExecutionBackend = "auto",
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None
) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]:
...
def lazy_jit(
func: Callable[_P, _T] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: ExecutionBackend = "auto",
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None,
):
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
compile_args = dict(
out_idx=None,
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
debug_root_path=debug_root_path,
compile_flags=compile_flags)
def decorator(func: Callable[_P, _T]):
pf: PrimFunc[_P, _T] | PrimFuncCreater[_P, _T] = prim_func(func, generator=True)
# if isinstance(pf, PrimFunc):
# compile_args.pop('debug_root_path', None)
# return compile(pf, **compile_args)
# else:
return JITImpl(
func=pf,
**compile_args,
func_source=inspect.getsource(pf.orig_func),
signature=inspect.signature(pf.orig_func),
lazy_jit=True)
return decorator(func) if func is not None else decorator
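A hedged summary of the call-site difference between the two decorators, based on JITImpl.__call__ above; `gemm` refers to the lazy_jit test function earlier in this diff, and the config dict mirrors its par_compile test:

```python
# @tilelang.jit: calling the wrapper returns a JITKernel that the caller launches.
# @tilelang.lazy_jit: calling the wrapper converts the arguments, runs the kernel,
# and returns any `T.empty` outputs; kernels are cached by parse_cache_key.
C = gemm(A, B)  # first call compiles, subsequent calls hit self._kernel_cache

# Batch pre-compilation (e.g. for autotuning-style sweeps):
kernels = gemm.par_compile([{
    'A': T.Tensor((1024, 1024), dtype=T.float16),
    'B': T.Tensor((1024, 1024), dtype=T.float16),
    'out_dtype': T.float32,
}])
```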
@@ -106,6 +106,9 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
        params = func.params
        buffer_map = func.buffer_map
        dynamic_symbolic_map = {}
+       for i, param in enumerate(params):
+           if isinstance(param, tir.Var) and (param not in dynamic_symbolic_map):
+               dynamic_symbolic_map[param] = (2, i, -1)
        for i, param in enumerate(params):
            if param in buffer_map:
                buffer = buffer_map[param]
@@ -217,7 +220,14 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
                    if (str(s) == str(key)):
                        ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[
                            key]
-                       shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
+                       if ref_id == 2:
+                           shape.append(inputs[ref_tensor_idx])
+                       elif ref_id == 0:
+                           shape.append(
+                               tensor_list[ref_tensor_idx].shape[ref_shape_idx])
+                       elif ref_id == 1:
+                           shape.append(
+                               tensor_list[ref_tensor_idx].stride()[ref_shape_idx])
                    else:  # Already converted to Python int during initialization
                        shape.append(s)
...
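For readability, a small sketch of how the three kinds recorded in dynamic_symbolic_map are resolved at call time (the helper name is illustrative; the adapter inlines this logic as shown above):

```python
def resolve_dynamic_symbol(kind, idx, dim, inputs, tensor_list):
    # (kind, idx, dim) as stored in dynamic_symbolic_map:
    #   kind 2: a scalar argument passed directly at position `idx` (dim is unused, -1)
    #   kind 0: dimension `dim` of the shape of tensor argument `idx`
    #   kind 1: stride `dim` of tensor argument `idx`
    if kind == 2:
        return inputs[idx]
    if kind == 0:
        return tensor_list[idx].shape[dim]
    if kind == 1:
        return tensor_list[idx].stride()[dim]
    raise ValueError(f"unknown dynamic-symbol kind: {kind}")
```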
@@ -13,16 +13,15 @@ from . import overrides as _overrides  # noqa: F401
from .v2 import *  # noqa: F401
from .tir.ir import *  # noqa: F401
from tilelang.layout import Layout, Fragment  # noqa: F401
-from .proxy import (
-    ptr,  # noqa: F401
-    make_tensor,  # noqa: F401
+from .proxy import ptr, make_tensor  # noqa: F401
+from .v2.annot import (
    Buffer,  # noqa: F401
    Tensor,  # noqa: F401
    StridedTensor,  # noqa: F401
    FragmentBuffer,  # noqa: F401
    SharedBuffer,  # noqa: F401
    LocalBuffer,  # noqa: F401
-   Ref,  # noqa: F401
+   dyn,  # noqa: F401
)
from .loop import (
    Parallel,  # noqa: F401
@@ -56,6 +55,7 @@ from .allocate import (
    alloc_wgmma_desc,  # noqa: F401
    alloc_tcgen05_smem_desc,  # noqa: F401
    alloc_tcgen05_instr_desc,  # noqa: F401
+   empty,  # noqa: F401
)
from .copy import copy, c2d_im2col  # noqa: F401
from .gemm import GemmWarpPolicy, gemm, gemm_v1, gemm_v2  # noqa: F401
...
@@ -14,8 +14,7 @@ Each function takes shape and dtype parameters and returns a TVM buffer object
with the appropriate memory scope.
"""
from __future__ import annotations
-from typing import overload, Literal
+from typing import TypeVarTuple, TypeVar, overload, Literal, Unpack, Callable
from tilelang import tvm as tvm
from tvm.script import tir as T
from tvm.tir import PrimExpr
@@ -23,9 +22,16 @@ from tvm.script.parser.tir import block_attr
from tvm.tir.buffer import Buffer
from tvm.tir.expr import FloatImm, IntImm
from .v2.dtypes import dtype as tl_dtype
+from .v2.builder import OutTensor
+from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer
+
+_Shapes = TypeVarTuple('_Shapes')
+_DType = TypeVar('_DType')

-def alloc_shared(shape, dtype, scope="shared.dyn"):
+def alloc_shared(shape: tuple[Unpack[_Shapes]],
+                 dtype: _DType,
+                 scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
    """Allocate a shared memory buffer for inter-thread communication.
    Args:
@@ -43,7 +49,9 @@ def alloc_shared(shape, dtype, scope="shared.dyn"):
    return T.alloc_buffer(shape, dtype, scope=scope)

-def alloc_local(shape, dtype, scope="local"):
+def alloc_local(shape: tuple[Unpack[_Shapes]],
+                dtype: _DType,
+                scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]:
    """Allocate a local memory buffer for thread-private storage.
    Args:
@@ -57,7 +65,9 @@ def alloc_local(shape, dtype, scope="local"):
    return T.alloc_buffer(shape, dtype, scope=scope)

-def alloc_fragment(shape, dtype, scope="local.fragment"):
+def alloc_fragment(shape: tuple[Unpack[_Shapes]],
+                   dtype: _DType,
+                   scope="local.fragment") -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]:
    """Allocate a fragment memory buffer for specialized operations.
    Args:
@@ -256,3 +266,21 @@ def alloc_tcgen05_instruction_desc(dtype: str = "uint32"):
# Alias: short name consistent with imports
def alloc_tcgen05_instr_desc(dtype: str = "uint32"):
    return alloc_tcgen05_instruction_desc(dtype)
@overload
def empty(shape: tuple[Unpack[_Shapes]],
dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
...
def empty(*shape: Unpack[_Shapes],
dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
return OutTensor(shape[0], dtype)
elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str):
return OutTensor(shape[0], shape[1])
elif all([isinstance(x, (int, PrimExpr)) for x in shape]):
return OutTensor(shape, dtype)
else:
raise RuntimeError(f'Invalid shape {shape}')
-from .builder import prim_func, macro, PrimFunc  # noqa: F401
+from .builder import prim_func, macro, PrimFunc, PrimFuncCreater, Ref  # noqa: F401
from .dtypes import *
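For reference, the call forms accepted by the new `T.empty` defined above (a usage sketch; `M`, `N`, and `out_dtype` stand for values available inside a prim_func):

```python
C = T.empty(M, N, dtype=out_dtype)   # unpacked dims + keyword dtype, as in the tests above
C = T.empty((M, N))                  # a shape tuple/list, default dtype 'float32'
C = T.empty((M, N), 'float16')       # a shape tuple/list plus a string dtype
```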
This diff is collapsed.
@@ -6,12 +6,18 @@ import inspect
from tilelang.language.kernel import KernelLaunchFrame
from tvm_ffi.container import Map
from tvm.ir.base import Span
+from tvm.ir.expr import Range
+from tvm.tir.stmt import BufferRegion
from .ast import BaseBuilder, IRGenerator, eval_op, mutate
+from .utils import construct_strides
import tvm
from tvm.tir import Buffer
from tvm.script.ir_builder import tir, IRBuilder
-from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var
+from tvm.tir.expr import BufferLoad, EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var
from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union
+from collections.abc import Sequence
+from .annot import FuncAnnot, ArgVarTable, Annot
+import pprint
# Python 3.9 compatibility for ParamSpec and Self
try:
    from typing import ParamSpec, Self
@@ -31,7 +37,9 @@ def unwrap_expr(expr) -> PrimExpr | int | float:
    '''
    if isinstance(expr, tir.meta_var):
        expr = expr.value
-   elif isinstance(expr, Buffer) and expr.scope() == 'local.var':
+   elif isinstance(expr, Ref):
+       return expr.load()
+   elif is_var(expr):
        expr = tir.BufferLoad(expr, indices=[0])
    elif isinstance(expr, (EqualOp, NotEqualOp)):
        expr = expr.asobject()
...@@ -113,6 +121,30 @@ class SerialForWithStep: ...@@ -113,6 +121,30 @@ class SerialForWithStep:
@dataclass @dataclass
class OutTensor:
shape: Sequence[PrimExpr]
dtype: dt.dtype
@property
def strides(self):
return construct_strides(tuple(self.shape))
@dataclass
class Ref:
bufload: BufferLoad
@property
def buffer(self):
return self.bufload.buffer
def store(self, value):
tir.buffer_store(self.bufload.buffer, value, self.bufload.indices)
def load(self):
return self.bufload
class UnrollForWithStep(SerialForWithStep): class UnrollForWithStep(SerialForWithStep):
... ...
@@ -145,11 +177,15 @@ def is_var(v: Any) -> bool:
class Builder(BaseBuilder):
-   def __init__(self):
+   def __init__(self, func_annot: FuncAnnot = None):
        self.frames: list[AnyFrame] = []
        self.ir_builder = IRBuilder()
        self.name_inside_frame: dict[str, AnyFrame] = {}
-       self.arg_annotations = {}
+       self.macro_arg_annot = {}
+       self.func_annot = func_annot
+       self.out_idx = []
+       self.out_tensor_cnt = 0
+       self.arg_vt = ArgVarTable()
    @classmethod
    def current(cls) -> Self:
@@ -162,6 +198,8 @@ class Builder(BaseBuilder):
with self.ir_builder, self.with_frame(tir.prim_func()): with self.ir_builder, self.with_frame(tir.prim_func()):
tir.func_name(name) tir.func_name(name)
yield yield
if len(self.out_idx) != self.out_tensor_cnt:
raise RuntimeError('Not all tensor allocated from `T.empty` are returned')
    @contextmanager
    def macro(self, name=None, annotations=None):
@@ -169,9 +207,9 @@ class Builder(BaseBuilder):
            raise RuntimeError(
                f"Macro `{name}` is used inside boolean expressions, "
                "please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs")
-       save = self.name_inside_frame, self.arg_annotations
+       save = self.name_inside_frame, self.macro_arg_annot
        self.name_inside_frame = {}
-       self.arg_annotations = annotations or {}
+       self.macro_arg_annot = annotations or {}
        pos = len(self.frames)
        # here we add a ExitedMacroFrame to preserve the frame stack inside macro
        # because macro may bind some variable, and return it
@@ -188,7 +226,7 @@ class Builder(BaseBuilder):
        self.frames.append(MacroFrame())
        yield
        self.frames[pos] = ExitedMacroFrame()
-       self.name_inside_frame, self.arg_annotations = save
+       self.name_inside_frame, self.macro_arg_annot = save
    def get(self):
        return self.ir_builder.get()
@@ -269,8 +307,11 @@ class Builder(BaseBuilder):
            pass
        elif isinstance(val, tvm.tir.stmt.BufferStore):
            tir.buffer_store(val.buffer, val.value, val.indices, val.predicate)
-       elif not isinstance(val, tvm.tir.Buffer):
-           raise TypeError(f"Unsupported eval value: {val} of type {type(val)}")
+       elif isinstance(val, (Buffer, Var)):
+           pass
+       else:
+           logger.warning(
+               f"Unused return value: {val}({type(val)})", stack_info=True, stacklevel=2)
def ctx_for(self, it): def ctx_for(self, it):
self.check_continue_break() self.check_continue_break()
...@@ -355,10 +396,26 @@ class Builder(BaseBuilder): ...@@ -355,10 +396,26 @@ class Builder(BaseBuilder):
# c = tl.alloc_var('float32') # bind var `c` # c = tl.alloc_var('float32') # bind var `c`
# c = a # get and assign `c[0] = a_1[0]` # c = a # get and assign `c[0] = a_1[0]`
# ``` # ```
if isinstance(orig_value, Ref) and isinstance(value, (int, float, PrimExpr)):
orig_value.store(value)
return orig_value
if is_var(orig_value) and isinstance(value, (int, float, PrimExpr)): if is_var(orig_value) and isinstance(value, (int, float, PrimExpr)):
tir.buffer_store(orig_value, value, 0) tir.buffer_store(orig_value, value, 0)
return orig_value return orig_value
# 2. Quick return for trivial types
if isinstance(value, (tuple, list, tvm.ffi.Array, int, float, str)):
return value
if isinstance(value, tir.IntImm) and value.dtype == 'int32':
return value.value
if isinstance(value, (Var, Buffer)):
IRBuilder.name(name, value)
return value
# 3. Bind immutable tilelang objects
res = self.bind_immutable(name, value) res = self.bind_immutable(name, value)
# 4. Check variable scope and shadowing
if name != '_': if name != '_':
frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME) frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME)
assert frame is not None, f"Variable `{name}` is not defined inside any control flow." assert frame is not None, f"Variable `{name}` is not defined inside any control flow."
...@@ -372,6 +429,9 @@ class Builder(BaseBuilder): ...@@ -372,6 +429,9 @@ class Builder(BaseBuilder):
return res return res
def unwrap_value(self, value): def unwrap_value(self, value):
'''
Unwrap some tilelang objects to get their inner value
'''
value = unwrap_expr(value) value = unwrap_expr(value)
# handle bx, by = tl.Kernel(128, 128), rval is frame # handle bx, by = tl.Kernel(128, 128), rval is frame
if isinstance(value, tir.frame.IRBuilderFrame): if isinstance(value, tir.frame.IRBuilderFrame):
...@@ -380,6 +440,10 @@ class Builder(BaseBuilder): ...@@ -380,6 +440,10 @@ class Builder(BaseBuilder):
return value return value
def bind_immutable(self, name, value): def bind_immutable(self, name, value):
'''
Bind an immutable tilelang objects.
The immutability means the result is usually not changed or re-assigned in a python block.
'''
if name == '_': if name == '_':
# use _tmp to make the generated tir more readable # use _tmp to make the generated tir more readable
name = "_tmp" name = "_tmp"
...@@ -393,11 +457,19 @@ class Builder(BaseBuilder): ...@@ -393,11 +457,19 @@ class Builder(BaseBuilder):
stacklevel=2, stacklevel=2,
) )
return self.enter_frame(value) return self.enter_frame(value)
elif isinstance(value, OutTensor):
arg = tir.arg(name,
tir.buffer(
shape=value.shape,
dtype=value.dtype,
strides=value.strides,
))
arg._out_idx = self.out_tensor_cnt
self.out_tensor_cnt += 1
return arg
elif isinstance(value, (Buffer, tir.IterVar, tir.Var)): elif isinstance(value, (Buffer, tir.IterVar, tir.Var)):
IRBuilder.name(name, value) IRBuilder.name(name, value)
return value return value
elif isinstance(value, (tuple, list, tvm.ffi.Array)):
return value
else: else:
try: try:
value = tvm.runtime.convert(value) value = tvm.runtime.convert(value)
...@@ -420,7 +492,10 @@ class Builder(BaseBuilder): ...@@ -420,7 +492,10 @@ class Builder(BaseBuilder):
    def aug_assign(self, op, target, aug_value):
        self.check_continue_break()
-       if is_var(target):
+       if isinstance(target, Ref):
+           target.store(eval_op(op, target.bufload, aug_value))
+           return target
+       elif is_var(target):
            tir.buffer_store(target, eval_op(op, target[0], aug_value), 0)
            return target
        elif isinstance(target, Buffer):
...@@ -457,10 +532,15 @@ class Builder(BaseBuilder): ...@@ -457,10 +532,15 @@ class Builder(BaseBuilder):
else: else:
return super().ifexp(cond, then, otherwise) return super().ifexp(cond, then, otherwise)
-   def ret(self, value):
+   def ret(self, value=None):
        self.check_continue_break()
        # handle return T.alloc_var()
-       value = self.unwrap_value(value)
+       if value is None:
+           value = tuple()
+       elif isinstance(value, tuple):
+           value = tuple(self.unwrap_value(v) for v in value)
+       else:
+           value = self.unwrap_value(value)
last_macro = self.find_frame_idx(MacroFrame) last_macro = self.find_frame_idx(MacroFrame)
if last_macro is not None: if last_macro is not None:
frame = self.find_frame_idx(TIR_CONTROL_FRAME, start=last_macro) frame = self.find_frame_idx(TIR_CONTROL_FRAME, start=last_macro)
...@@ -478,7 +558,20 @@ class Builder(BaseBuilder): ...@@ -478,7 +558,20 @@ class Builder(BaseBuilder):
" return a\n" " return a\n"
"```" "```"
) )
return value return value
else:
if not isinstance(value, tuple):
value = (value,)
for v in value:
if not isinstance(v, Buffer) or not hasattr(v, '_out_idx'):
raise RuntimeError(
f'Only tensor allocated from `T.empty` can be returned in a prim_func, got {v}({type(v)})'
)
# convert 0, 1, 2 => -3, -2, -1 as the out tensor index
self.out_idx.append(v._out_idx - self.out_tensor_cnt)
if len(self.out_idx) != self.out_tensor_cnt:
raise RuntimeError(f'Not all tensor from `T.empty` are returned, only got {value}')
return NotImplemented
def ctx_with(self, ctx): def ctx_with(self, ctx):
self.check_continue_break() self.check_continue_break()
...@@ -487,9 +580,11 @@ class Builder(BaseBuilder): ...@@ -487,9 +580,11 @@ class Builder(BaseBuilder):
else: else:
return super().ctx_with(ctx) return super().ctx_with(ctx)
-   def assert_expr(self, cond, msg):
+   def assert_expr(self, cond, msg=None):
        self.check_continue_break()
        cond = unwrap_cond(cond)
+       if msg is None:
+           msg = 'Assertion failed'
if isinstance(cond, PrimExpr): if isinstance(cond, PrimExpr):
self.enter_frame(tir.Assert(cond, msg)) self.enter_frame(tir.Assert(cond, msg))
elif not cond: elif not cond:
...@@ -506,30 +601,41 @@ class Builder(BaseBuilder): ...@@ -506,30 +601,41 @@ class Builder(BaseBuilder):
        return self.unwrap_value(value)

    def macro_arg(self, name, value):
-       from tilelang.language.proxy import Ref
-       annot_value = self.arg_annotations.get(name, None)
+       annot_value = self.macro_arg_annot.get(name, None)
        if annot_value is Var or annot_value is Ref:
            if annot_value is Var:
                logger.warning('Use `T.Var` as macro annotations is deprecated, please use `T.Ref`')
-           is_var = isinstance(value, tvm.tir.BufferLoad) and value.buffer.scope() == 'local.var'
-           if not is_var:
-               raise ValueError(
-                   f'Argument `{name}` is expected to be a variable allocated by `T.alloc_var`, but got {value}({type(value)})'
-               )
-           return value.buffer
+           if isinstance(value, BufferLoad):
+               if is_var(value.buffer):
+                   return value.buffer
+               idx = [self.bind('_', idx) for idx in value.indices]
+               # indices = self.bind(f'_', value.indices)
+               return Ref(BufferLoad(value.buffer, indices=idx))
+           if isinstance(value, BufferRegion):
+               region = [
+                   Range(
+                       self.bind('_', x.begin),
+                       end=self.bind('_', x.end) if x.end is not None else None)
+                   for x in value.region
+               ]
+               return BufferRegion(value.buffer, region=region)
+           raise ValueError(
+               f'To pass as reference, argument `{name}` is expected to be a variable or a buffer region, but got {value}({type(value)})'
+           )
        elif isinstance(value, (PrimExpr, int, float)):
            return self.bind(name, value)
        else:
            return value

    def prim_func_arg(self, name, value):
-       if isinstance(value, (Buffer, Var)):
-           return tir.arg(name, value)
-       elif value is self.empty:
-           raise ValueError(f'Argument `{name}` is not annotated')
-       else:
-           raise TypeError(
-               f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")
+       return self.func_annot.create_argument(name, value, self.arg_vt)
+       # if isinstance(value, (Buffer, Var)):
+       #     return tir.arg(name, value)
+       # elif value is self.empty:
+       #     raise ValueError(f'Argument `{name}` is not annotated')
+       # else:
+       #     raise TypeError(
+       #         f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")

    def arg(self, name, value):
        if self.find_frame_idx(MacroFrame) is not None:
@@ -547,6 +653,39 @@ class Builder(BaseBuilder):
_P = ParamSpec('_P') _P = ParamSpec('_P')
_T = TypeVar('_T') _T = TypeVar('_T')
@dataclass
class PrimFuncCreater(Generic[_P, _T]):
func_annot: FuncAnnot
ir_gen: IRGenerator[_P, _T]
orig_func: Callable[_P, _T]
@property
def annot(self) -> dict[str, Annot]:
return self.func_annot.annots
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_P, _T]:
builder = Builder(self.func_annot)
with builder.prim_func(self.orig_func.__name__):
self.ir_gen.gen(builder)(*args, **kwargs)
res: PrimFunc = builder.get()
res.ir_gen = self.ir_gen
res.orig_func = self.orig_func
res.func_annot = self.func_annot
res.out_idx_override = builder.out_idx or None
return res
def __repr__(self):
fmt = pprint.pformat(
{
'annot': self.func_annot.annots,
'ir_gen': self.ir_gen,
'orig_func': self.orig_func
},
indent=2)
return f'{self.__class__.__name__}(\n{fmt}\n)'
if TYPE_CHECKING: if TYPE_CHECKING:
class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc): class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc):
...@@ -557,8 +696,10 @@ if TYPE_CHECKING: ...@@ -557,8 +696,10 @@ if TYPE_CHECKING:
attrs: tvm.Attrs | None attrs: tvm.Attrs | None
span: Span | None span: Span | None
ir_gen: IRGenerator[_P, _T] | None ir_gen: IRGenerator[_P, _T] | None
source: str | None
orig_func: Callable[_P, _T] | None orig_func: Callable[_P, _T] | None
func_annot: FuncAnnot | None
out_idx_override: list[int] | None
else: else:
PrimFunc = tvm.tir.PrimFunc PrimFunc = tvm.tir.PrimFunc
...@@ -580,6 +721,12 @@ class Macro(Generic[_P, _T]): ...@@ -580,6 +721,12 @@ class Macro(Generic[_P, _T]):
res = self.ir_gen.gen(builder)(*args, **kwargs) res = self.ir_gen.gen(builder)(*args, **kwargs)
return res return res
def __hash__(self):
return id(self)
def __eq__(self, other):
return id(self) == id(other)
def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]: def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]:
""" """
...@@ -683,13 +830,9 @@ def get_type_hints(func): ...@@ -683,13 +830,9 @@ def get_type_hints(func):
return hints return hints
-def _is_static_annot(annot: Any) -> bool:
-    return isinstance(annot, (dt.dtype, Buffer, Var))

def prim_func(func: Callable[_P, _T] = None,
              *,
-             generator: bool = False) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]:
+             generator: bool = False) -> PrimFunc[_P, _T] | PrimFuncCreater[_P, _T]:
    """
    Decorator to create a primitive function (PrimFunc) for TileLang IR generation.
    This decorator transforms a Python function into a TileLang primitive function by analyzing
@@ -739,45 +882,21 @@ def prim_func(func: Callable[_P, _T] = None,
        sig = inspect.signature(func)
        annot = get_type_hints(func)
-       for k in annot:
-           if callable(annot[k]):
-               annot[k] = annot[k]()
-       # check whether all arguments are annotated
-       all_arg_annotated = all([x in annot for x in sig.parameters])
-       # check whether all annotations are Buffer/Var/dtype
-       all_annot_are_static = all([_is_static_annot(x) for x in annot.values()])
+       func_annot = FuncAnnot.from_sig_annots(sig, annot)
        ir_gen = mutate(func)
-       def prim_func_generator(*args, **kwargs):
-           builder = Builder()
-           with builder.prim_func(func.__name__):
-               ir_gen.gen(builder)(*args, **kwargs)
-           res = builder.get()
-           res.ir_gen = ir_gen
-           res.source = ir_gen.source
-           res.orig_func = func
-           return res
-       prim_func_generator.ir_gen = ir_gen
-       prim_func_generator.source = ir_gen.source
-       prim_func_generator.orig_func = func
-       if generator:
-           return prim_func_generator
+       prim_func_generator = PrimFuncCreater(func_annot, ir_gen, orig_func=func)
-       if all_arg_annotated and all_annot_are_static:
-           return prim_func_generator(**annot)
+       if func_annot.is_all_static():
+           args = func_annot.get_all_static_args()
+           return prim_func_generator(**args)
        else:
-           raise ValueError(
-               "Some arguments are not supported or statically annotated, \n"
-               "please check the annotations or set generator=True to get a prim_func generator.\n"
-               f"Argument Annotations: {annot}\n"
-               "Example usage of generator:\n"
-               "```py\n"
-               "@prim_func(generator=True)\n"
-               "def my_func(a=T.Tensor((128,), T.float32)): ...\n"
-               "return my_func()\n"
-               "```")
+           if generator is False:
+               unknown_args = func_annot.get_compile_time_unknown_args()
+               raise ValueError(
+                   f"Cannot create PrimFunc for `{func.__name__}`, some arguments are not compile-time known, \n"
+                   f"Annotations:\n{func_annot.annots}"
+                   f"Unknown Args: {unknown_args}")
+           return prim_func_generator
    return impl(func) if func is not None else impl
from tilelang import tvm
from tvm import ir
import torch
-from typing import TYPE_CHECKING, Union
+from typing import Generic, TypeVar, Union, TYPE_CHECKING
from tvm import tir
import tvm.script.ir_builder.tir._ffi_api as tb_ffi
import numpy as np
-dtype = tvm.DataType
+_T = TypeVar('_T')
+
+if TYPE_CHECKING:
+
+    class dtype(Generic[_T]):
+
+        def torch(self) -> torch.dtype:
+            ...
+else:
+    dtype = tvm.DataType
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]
...
@@ -4,6 +4,7 @@ import inspect
from typing import Any, Callable, Literal
from tilelang import env
from hashlib import sha256
+from tvm import tir
import linecache
@@ -84,3 +85,17 @@ def get_compiled_object(source: str | ast.AST,
    locs = {}
    exec(compiled, globals, locs)
    return locs[name]
def construct_strides(shape: tuple[Any, ...], allow_prim_expr: bool = True) -> tuple[Any, ...]:
"""Construct row-major strides from shape."""
strides = []
stride = 1
for s in shape[::-1]:
strides.append(stride)
stride *= s
if not allow_prim_expr and isinstance(stride, tir.PrimExpr):
raise ValueError(
"Cannot construct strides with PrimExpr when allow_prim_expr is False.")
strides = tuple(reversed(strides))
return strides
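A small worked example of the row-major stride construction above (illustrative):

```python
# innermost dimension gets stride 1, then each outer stride multiplies by the inner extent
assert construct_strides((2, 3, 4)) == (12, 4, 1)
assert construct_strides((128, 256)) == (256, 1)  # the strides T.empty attaches to its OutTensor
```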