Unverified Commit 5f202fe5 authored by Kurisu's avatar Kurisu Committed by GitHub
Browse files

[Language] Initial version of tilelang frontend v2 (#1120)



* tilelang frontend v2

* syntax sugar: defining a local var by annotation

* [Refactor] fix type linting warning like `T.float32`

* Add tl.local_var_init for new tl.float32

* allow passing default argument as function annotation

* allow default arguments as annotation

* fix lint error

* minor fix

* [Refactor] refactor tilelang.jit and tilelang.autotune

* minor fix

* minor fix

* minor fix

* fix metal get function name

* add par_compile impl and tests

* Type consistency on tvm datatype
1. isinstance(tl.float32, tvm.DataType) == True
2. Allow `tl.float32` as function annotations
3. Allow `tl.float32` as argument to be passed to `tl.alloc` or other functions

* fix lint error

* add more warning in frontend

* update tvm version

* Minor fix on tvm_ffi annotations

* add document and examples

* fix lint error

* Simplify index calculations in example_chunk_o_bwd.py

Refactor index calculations for dg_last_fragment assignment.

* minor fix

* lint fix

---------
Co-authored-by: Lei Wang <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent ba390756
...@@ -7,8 +7,6 @@ import tilelang ...@@ -7,8 +7,6 @@ import tilelang
import tilelang.language as T import tilelang.language as T
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
print(tilelang.__file__)
# Add your fla repository path to sys.path # Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
...@@ -256,8 +254,9 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -256,8 +254,9 @@ def tilelang_chunk_o_bwd_dqkwg(
# for i_kv in T.Parallel(block_DK * block_DV): # for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV): for i_kv in T.Parallel(block_DK * block_DV):
i_k, i_v = i_kv // block_DV, i_kv % block_DV dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv %
dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v] block_DV] * dh_shared[i_kv // block_DV,
i_kv % block_DV]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0] dg_last_local[0] += dg_last_fragment_scalar[0]
......
import tilelang.testing
import tilelang
import torch
@tilelang.jit(
    out_idx=-1,  # create the output tensor during runtime
    verbose=True,
)
def matmul_kernel_jit(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A=False,
    trans_B=True,
    in_dtype='float16',
    out_dtype='float32',
    accum_dtype='float32',
    num_stages=2,
    threads=128,
):
    """Build a tiled GEMM kernel computing C = A @ B as a tilelang prim_func.

    Args:
        M, N, K: Global GEMM problem dimensions.
        block_M, block_N, block_K: Tile sizes processed per thread block.
        trans_A, trans_B: Whether A / B are stored transposed.
        in_dtype, out_dtype, accum_dtype: Input / output / accumulator dtypes.
        num_stages: Software-pipelining depth of the K reduction loop.
        threads: Threads per thread block.

    Returns:
        The traced `main` prim_func; `@tilelang.jit` turns the whole factory
        into a JIT-compiled kernel handle.
    """
    # Logical operand layouts flip when an operand is marked transposed.
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        # One thread block per (block_N, block_M) output tile.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # Pipelined reduction over K tiles; copy direction depends on transposition.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
def test_par_compile():
    """Compile several matmul problem sizes in parallel and validate each kernel."""
    # (M, N, K, block_M, block_N, block_K) problem configurations.
    shapes = [
        (1024, 1024, 1024, 128, 128, 32),
        (2048, 2048, 2048, 256, 256, 64),
        (4096, 4096, 4096, 64, 64, 128),
    ]
    compiled = matmul_kernel_jit.par_compile(shapes)
    for cfg, kernel in zip(shapes, compiled):
        m, n, k = cfg[0], cfg[1], cfg[2]
        lhs = torch.randn(m, k, dtype=torch.float16).cuda()
        rhs = torch.randn(n, k, dtype=torch.float16).cuda()
        expected = (lhs @ rhs.T).float()
        produced = kernel(lhs, rhs)
        tilelang.testing.torch_assert_close(produced, expected, rtol=1e-2, atol=1e-2)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    tilelang.testing.main()
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import tvm
def test_argument():
    """All frontend-v2 dtype aliases are accepted as prim_func argument annotations.

    Only checks that tracing the function succeeds; the body is intentionally empty.
    """

    @T.prim_func
    def test_argument(
            t_1: T.bool,
            t_2: T.short,
            t_3: T.int,
            t_4: T.long,
            t_5: T.half,
            t_6: T.float,
            t_7: T.long,
            t_8: T.int8,
            t_9: T.int16,
            t_10: T.int32,
            t_11: T.int64,
            t_12: T.uint8,
            t_13: T.uint16,
            t_14: T.uint32,
            t_15: T.uint64,
            t_16: T.float8_e4m3fn,
            t_17: T.float8_e4m3fnuz,
            t_18: T.float8_e5m2,
            t_19: T.float8_e5m2fnuz,
            t_20: T.float8_e8m0fnu,
            t_21: T.float16,
            t_22: T.bfloat16,
            t_23: T.float32,
            t_24: T.float64,
    ):
        pass
def test_expr():
    """Every registered dtype alias is a tvm.DataType and callable without crashing."""
    from tilelang.language.v2.dtypes import _all_dtypes
    failed = []
    for dtype_name in _all_dtypes:
        dt = getattr(T, dtype_name)
        assert isinstance(dt, tvm.DataType), f"{dt} is not tvm.DataType"
        try:
            dt(1.0)
            dt()
        except TypeError:
            # A TypeError is an acceptable rejection of the call signature.
            pass
        except Exception:
            # Any other exception is a real failure; record the dtype name.
            failed.append(dtype_name)
    assert not failed
# def test_var_decl_sugar():
# @T.prim_func
# def test_var_decl_sugar():
# with T.Kernel(128, 128) as (bx, by):
# var_1: T.bool = 1.0
# var_2: T.short = 1.0
# var_3: T.int = 1.0
# var_4: T.long = 1.0
# var_5: T.half = 1.0
# var_6: T.float = 1.0
# var_7: T.long = 1.0
# var_8: T.int8 = 1.0
# var_9: T.int16 = 1.0
# var_10: T.int32 = 1.0
# var_11: T.int64 = 1.0
# var_12: T.uint8 = 1.0
# var_13: T.uint16 = 1.0
# var_14: T.uint32 = 1.0
# var_15: T.uint64 = 1.0
# var_16: T.float8_e4m3fn = 1.0
# var_17: T.float8_e4m3fnuz = 1.0
# var_18: T.float8_e5m2 = 1.0
# var_19: T.float8_e5m2fnuz = 1.0
# var_20: T.float8_e8m0fnu = 1.0
# var_21: T.float16 = 1.0
# var_22: T.bfloat16 = 1.0
# var_23: T.float32 = 1.0
# var_24: T.float64 = 1.0
# var_1: T.bool = var_1
# var_2: T.short = var_2
# var_3: T.int = var_3
# var_4: T.long = var_4
# var_5: T.half = var_5
# var_6: T.float = var_6
# var_7: T.long = var_7
# var_8: T.int8 = var_8
# var_9: T.int16 = var_9
# var_10: T.int32 = var_10
# var_11: T.int64 = var_11
# var_12: T.uint8 = var_12
# var_13: T.uint16 = var_13
# var_14: T.uint32 = var_14
# var_15: T.uint64 = var_15
# var_16: T.float8_e4m3fn = var_16
# var_17: T.float8_e4m3fnuz = var_17
# var_18: T.float8_e5m2 = var_18
# var_19: T.float8_e5m2fnuz = var_19
# var_20: T.float8_e8m0fnu = var_20
# var_21: T.float16 = var_21
# var_22: T.bfloat16 = var_22
# var_23: T.float32 = var_23
# var_24: T.float64 = var_24
# s = test_var_decl_sugar.script()
# for i in range(1, 25):
# assert f'var_{i}_1' in s
# assert 'tl.local_var_init' in s
def test_dtype_str_repr():
    """All dtype aliases can be passed as `dtype=` to T.alloc_buffer during tracing."""

    @T.prim_func
    def test_str_repr():
        buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared')  # noqa F841
        buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared')  # noqa F841
        buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared')  # noqa F841
        buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared')  # noqa F841
        buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared')  # noqa F841
        buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared')  # noqa F841
        buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared')  # noqa F841
        buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared')  # noqa F841
        buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared')  # noqa F841
        buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared')  # noqa F841
        buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared')  # noqa F841
        buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared')  # noqa F841
        buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared')  # noqa F841
        buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared')  # noqa F841
        buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared')  # noqa F841
        buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared')  # noqa F841
        buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared')  # noqa F841
        buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared')  # noqa F841
        buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared')  # noqa F841
        buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared')  # noqa F841
        buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared')  # noqa F841
        buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared')  # noqa F841
        buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared')  # noqa F841
        buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared')  # noqa F841
def test_torch_eq():
    """tilelang dtype aliases compare equal to their torch counterparts, both ways."""
    # (tilelang dtype, torch dtype) pairs; the duplicated `long` entry mirrors
    # the original parallel lists.
    pairs = [
        (T.bool, torch.bool),
        (T.short, torch.short),
        (T.int, torch.int),
        (T.long, torch.long),
        (T.half, torch.half),
        (T.float, torch.float),
        (T.long, torch.long),
        (T.int8, torch.int8),
        (T.int16, torch.int16),
        (T.int32, torch.int32),
        (T.int64, torch.int64),
        (T.uint8, torch.uint8),
        (T.uint16, torch.uint16),
        (T.uint32, torch.uint32),
        (T.uint64, torch.uint64),
        (T.float8_e4m3fn, torch.float8_e4m3fn),
        (T.float8_e4m3fnuz, torch.float8_e4m3fnuz),
        (T.float8_e5m2, torch.float8_e5m2),
        (T.float8_e5m2fnuz, torch.float8_e5m2fnuz),
        (T.float8_e8m0fnu, torch.float8_e8m0fnu),
        (T.float16, torch.float16),
        (T.bfloat16, torch.bfloat16),
        (T.float32, torch.float32),
        (T.float64, torch.float64),
    ]
    for tl_dtype, torch_dtype in pairs:
        assert tl_dtype == torch_dtype, f"{tl_dtype} and {torch_dtype} are not equal"
        # Converting the torch dtype back must round-trip to the tilelang dtype.
        assert T.dtype(torch_dtype) == tl_dtype, "dtype conversion error"
def test_var_assign():
    """Annotated local vars create fresh bindings: `b` keeps `a`'s value from
    before the reassignment, while `d` sees the updated value."""

    @tilelang.jit(out_idx=-1)
    @T.prim_func
    def test_var_assign(A: T.Tensor((2,), T.int32)):
        with T.Kernel(1) as _:
            a: T.int32 = 1
            b: T.int32 = a
            a = 2
            d: T.int32 = a
            A[0] = b
            A[1] = d

    # First call compiles the kernel (out_idx=-1 auto-creates the output);
    # the second call executes it.
    res = test_var_assign()()
    assert res[0] == 1
    assert res[1] == 2
# NOTE(review): "marco" looks like a typo for "macro"; renaming would change the
# collected test ID, so the name is kept as-is.
def test_marco_return():
    """T.macro bodies may return constants, frames, expressions, or callable results."""

    @T.macro
    def macro_return_constant():
        return 0

    @T.macro
    def macro_return_frame(x):
        return T.alloc_var(T.float32, init=x)

    @T.macro
    def macro_return_expr(x):
        y = x + 1.0
        return y

    @T.macro
    def macro_apply_func(x, fn):
        return fn(x)

    def check(x, ty):
        # Plain Python assertion evaluated during tracing.
        assert isinstance(x, ty)

    @T.prim_func
    def test_macro_return():
        with T.Kernel(1) as _:
            a = macro_return_constant()
            b = macro_return_frame(3.0)
            c = macro_return_expr(4.0)
            d = macro_apply_func(5.0, lambda x: x * 2.0)
            # A constant return may stay a Python scalar or become a PrimExpr.
            check(a, (int, float, T.PrimExpr))
            check(b, T.PrimExpr)
            check(c, T.PrimExpr)
            check(d, T.PrimExpr)
def test_prim_func_generator():
    """`generator=True` lets default-argument annotations act as parameter specs,
    and a bare prim_func with only a return annotation still traces to a PrimFunc."""

    @T.prim_func(generator=True)
    def prim_func_gen(
            A=T.Tensor((128,), T.float32),  # noqa: B008
            B=T.Tensor((128,), T.float32),  # noqa: B008
    ):
        with T.Kernel(128) as (tx,):
            T.copy(A[tx], B[tx])

    # Calling the generator triggers tracing with the annotated defaults.
    prim_func_gen()

    @T.prim_func
    def foo() -> T.Tensor((128,), T.float32):
        pass

    assert isinstance(foo, T.PrimFunc)
# Allow running this test module directly as a script.
if __name__ == '__main__':
    tilelang.testing.main()
...@@ -11,7 +11,7 @@ def test_let_vectorize_load(): ...@@ -11,7 +11,7 @@ def test_let_vectorize_load():
for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
b: T.float32x4 = A[0, 0:4] b = A[0, 0:4]
A[0, 4:8] = b A[0, 4:8] = b
mod = tvm.IRModule({"main": main}) mod = tvm.IRModule({"main": main})
......
...@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n") N = tvm.te.var("n")
K = tvm.te.var("k") K = tvm.te.var("k")
@tvm.script.ir.ir_module def before():
class Before:
@T.prim_func @T.prim_func
def main(B: T.Tensor((K, N), dtype),): def main(B: T.Tensor((K, N), dtype),):
...@@ -38,8 +37,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -38,8 +37,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
(block_N // vec_load_b) * (block_N // vec_load_b) + vec], (block_N // vec_load_b) * (block_N // vec_load_b) + vec],
T.float16(0)) T.float16(0))
@tvm.script.ir.ir_module return tvm.IRModule({'main': main})
class After:
def after():
@T.prim_func @T.prim_func
def main(B: T.Tensor((K, N), dtype),): def main(B: T.Tensor((K, N), dtype),):
...@@ -77,11 +77,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -77,11 +77,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
bx * block_N + t % (block_N // vec_load_b) * bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0)) (block_N // vec_load_b) + vec], T.float16(0))
return tvm.IRModule({'main': main})
with tvm.target.Target(auto_target): with tvm.target.Target(auto_target):
mod = tvm.tir.transform.BindTarget(auto_target)(Before) mod = tvm.tir.transform.BindTarget(auto_target)(before())
mod = tl.transform.LayoutInference()(mod) mod = tl.transform.LayoutInference()(mod)
mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.Simplify()(mod)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(After) ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
ref_mod = tvm.tir.transform.Simplify()(ref_mod) ref_mod = tvm.tir.transform.Simplify()(ref_mod)
# Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass # Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass
# This loop is "for vec in T.parallel(1)", # This loop is "for vec in T.parallel(1)",
......
...@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n") N = tvm.te.var("n")
K = tvm.te.var("k") K = tvm.te.var("k")
@tvm.script.ir.ir_module def before():
class Before:
@T.prim_func @T.prim_func
def main(B: T.Tensor((K, N), dtype),): def main(B: T.Tensor((K, N), dtype),):
...@@ -25,8 +24,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -25,8 +24,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(B[k * block_K, bx * block_N], B_shared) T.copy(B[k * block_K, bx * block_N], B_shared)
@tvm.script.ir.ir_module return tvm.IRModule({'main': main})
class After:
def after():
@T.prim_func @T.prim_func
def main(B: T.Tensor((K, N), dtype),): def main(B: T.Tensor((K, N), dtype),):
...@@ -64,11 +64,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -64,11 +64,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
bx * block_N + t % (block_N // vec_load_b) * bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0)) (block_N // vec_load_b) + vec], T.float16(0))
return tvm.IRModule({'main': main})
with tvm.transform.PassContext(): with tvm.transform.PassContext():
mod = tvm.tir.transform.BindTarget(auto_target)(Before) mod = tvm.tir.transform.BindTarget(auto_target)(before())
mod = tl.transform.LowerTileOp()(mod) mod = tl.transform.LowerTileOp()(mod)
mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.Simplify()(mod)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(After) ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
ref_mod = tvm.tir.transform.Simplify()(ref_mod) ref_mod = tvm.tir.transform.Simplify()(ref_mod)
# Note(tzj): The structures are equal except the argument in "T.reads" function. # Note(tzj): The structures are equal except the argument in "T.reads" function.
# The difference is just between the first index and the indices range, which is totally equivalent # The difference is just between the first index and the indices range, which is totally equivalent
......
...@@ -113,7 +113,7 @@ def test_multi_version_buffer_with_let(): ...@@ -113,7 +113,7 @@ def test_multi_version_buffer_with_let():
shared = T.alloc_buffer((8,), "float32", scope="shared.dyn") shared = T.alloc_buffer((8,), "float32", scope="shared.dyn")
accum = T.alloc_buffer((8,), "float32", scope="local") accum = T.alloc_buffer((8,), "float32", scope="local")
for k in T.serial(4, annotations={"num_stages": T.int32(2)}): for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
value: T.float32 = scales[k] value = scales[k]
for i in T.serial(8): for i in T.serial(8):
shared[i] = value shared[i] = value
for i in T.serial(8): for i in T.serial(8):
...@@ -125,7 +125,7 @@ def test_multi_version_buffer_with_let(): ...@@ -125,7 +125,7 @@ def test_multi_version_buffer_with_let():
shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn") shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn")
accum = T.alloc_buffer((8,), "float32", scope="local") accum = T.alloc_buffer((8,), "float32", scope="local")
for k in T.serial(4, annotations={"num_stages": T.int32(2)}): for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
value: T.float32 = scales[k] value = scales[k]
for i in T.serial(8): for i in T.serial(8):
shared[k % 2, i] = value shared[k % 2, i] = value
for i in T.serial(8): for i in T.serial(8):
......
...@@ -4,7 +4,7 @@ import ctypes ...@@ -4,7 +4,7 @@ import ctypes
import logging import logging
import warnings import warnings
from tqdm import tqdm from tqdm.auto import tqdm
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
......
...@@ -4,17 +4,19 @@ This module provides functionality for auto-tuning tilelang programs, including ...@@ -4,17 +4,19 @@ This module provides functionality for auto-tuning tilelang programs, including
and performance optimization through configuration search. and performance optimization through configuration search.
""" """
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
import tilelang import tilelang
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.jit import JITImpl
from tilelang.jit.kernel import JITKernel
from tvm.tir import PrimFunc, Var from tvm.tir import PrimFunc, Var
from tvm.target import Target from tvm.target import Target
import inspect import inspect
from functools import partial from functools import partial
from typing import (Callable, Literal, Any, overload) from typing import (Callable, Generic, Literal, Any, ParamSpec, TypeVar)
from tqdm import tqdm from tqdm.auto import tqdm
import logging import logging
import functools
import concurrent.futures import concurrent.futures
import torch import torch
import os import os
...@@ -30,7 +32,6 @@ from tilelang import env ...@@ -30,7 +32,6 @@ from tilelang import env
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.autotuner.capture import get_autotune_inputs from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang.jit.param import _P, _RProg
from tilelang import __version__ from tilelang import __version__
...@@ -524,12 +525,12 @@ class AutoTuner: ...@@ -524,12 +525,12 @@ class AutoTuner:
# latency, ref_latency = target_fn(jit_kernel) # latency, ref_latency = target_fn(jit_kernel)
latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel) latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
except TimeoutException: except TimeoutException:
logger.info( logger.warning(
f"A timeout occurred while testing config {config}, checkout autotuner.log for more details" f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
) )
continue continue
except Exception: except Exception:
logger.info( logger.warning(
f"An error occurred while testing config {config}, checkout autotuner.log for more details" f"An error occurred while testing config {config}, checkout autotuner.log for more details"
) )
logger.debug(f"Error: {traceback.format_exc()}") logger.debug(f"Error: {traceback.format_exc()}")
...@@ -585,9 +586,13 @@ class AutoTuner: ...@@ -585,9 +586,13 @@ class AutoTuner:
return self.run() return self.run()
class _AutoTunerImplementation: _P = ParamSpec('_P')
# Overload __init__ to help type checkers understand the effect of return_program _T = TypeVar('_T')
# The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
@dataclass
class AutoTuneImpl(Generic[_P, _T]):
jit_impl: JITImpl
warmup: int = 25 warmup: int = 25
rep: int = 100 rep: int = 100
...@@ -603,125 +608,51 @@ class _AutoTunerImplementation: ...@@ -603,125 +608,51 @@ class _AutoTunerImplementation:
manual_check_prog: Callable = None manual_check_prog: Callable = None
cache_input_tensors: bool = False cache_input_tensors: bool = False
def __init__(self, def __post_init__(self):
configs: dict | Callable, self._tuner_cache = {}
warmup: int = 25,
rep: int = 100, def get_tunner(self):
timeout: int = 100, autotuner = AutoTuner(
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, self.jit_impl.func, configs=self.configs).set_profile_args(
ref_prog: Callable = None, supply_type=self.supply_type,
supply_prog: Callable = None, ref_prog=self.ref_prog,
rtol: float = 1e-2, supply_prog=self.supply_prog,
atol: float = 1e-2, rtol=self.rtol,
max_mismatched_ratio: float = 0.01, atol=self.atol,
skip_check: bool = False, max_mismatched_ratio=self.max_mismatched_ratio,
manual_check_prog: Callable = None, skip_check=self.skip_check,
cache_input_tensors: bool = False) -> None: manual_check_prog=self.manual_check_prog,
"""Initialize the AutoTunerImplementation. cache_input_tensors=self.cache_input_tensors,
).set_compile_args(
out_idx=self.jit_impl.out_idx,
execution_backend=self.jit_impl.execution_backend,
target=self.jit_impl.target,
target_host=self.jit_impl.target_host,
verbose=self.jit_impl.verbose,
pass_configs=self.jit_impl.pass_configs,
)
autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout)
return autotuner
Args: def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel:
configs: Configuration space to explore during auto-tuning. key_args_tuple = args
warmup: Number of warmup iterations before timing. key_kwargs_tuple = tuple(sorted(kwargs.items()))
rep: Number of repetitions for timing measurements. key = (key_args_tuple, key_kwargs_tuple)
timeout: Maximum time (in seconds) allowed for each configuration. if key not in self._tuner_cache:
supply_type: Strategy for generating input tensors (random/zeros/etc)
ref_prog: Reference implementation for validation def jit_compile(**config_arg):
supply_prog: Custom function to provide input tensors return self.jit_impl(*args, **kwargs, __tune_params=config_arg)
rtol: Relative tolerance for numerical validation
atol: Absolute tolerance for numerical validation autotuner = self.get_tunner()
max_mismatched_ratio: Allowed percentage of mismatched values autotuner.jit_compile = jit_compile
skip_check: Bypass validation against reference implementation autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)
manual_check_prog: Custom validation function artifact = autotuner.run()
cache_input_tensors: Reuse input tensors across trials self._tuner_cache[key] = artifact.kernel
""" return self._tuner_cache[key]
# Configuration and benchmarking parameters
self.configs = configs # Search space of tuning configurations
self.warmup = warmup # Warmup iterations for stable measurements
self.rep = rep # Measurement repetitions for statistics
self.timeout = timeout # Per-configuration timeout threshold
# Tensor handling and validation setup
self.supply_type = supply_type # Input tensor generation strategy
self.ref_prog = ref_prog # Ground truth implementation
self.supply_prog = supply_prog # Custom input data provider
self.rtol = rtol # Relative error tolerance
self.atol = atol # Absolute error tolerance
self.max_mismatched_ratio = max_mismatched_ratio # Allowed mismatch
# Validation control flags
self.skip_check = skip_check # Bypass accuracy verification
self.manual_check_prog = manual_check_prog # Custom validation
self.cache_input_tensors = cache_input_tensors # Reuse inputs
# Cache for storing tuned kernel implementations
self._tuner_cache: dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]:
...
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]:
...
# Actual implementation of __call__
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]:
warmup = self.warmup
rep = self.rep
timeout = self.timeout
@functools.wraps(fn)
def wrapper(*args, **kwargs):
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
if key not in self._tuner_cache:
def jit_compile(**config_arg):
return fn(*args, **kwargs, __tune_params=config_arg)
compile_arguments = fn(__return_compile_arguments=True)
autotuner = AutoTuner(
fn, configs=self.configs).set_profile_args(
supply_type=self.supply_type,
ref_prog=self.ref_prog,
supply_prog=self.supply_prog,
rtol=self.rtol,
atol=self.atol,
max_mismatched_ratio=self.max_mismatched_ratio,
skip_check=self.skip_check,
manual_check_prog=self.manual_check_prog,
cache_input_tensors=self.cache_input_tensors,
).set_compile_args(
out_idx=compile_arguments['out_idx'],
execution_backend=compile_arguments['execution_backend'],
target=compile_arguments['target'],
target_host=compile_arguments['target_host'],
verbose=compile_arguments['verbose'],
pass_configs=compile_arguments['pass_configs'],
)
autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters)
autotuner.run = partial(autotuner.run, warmup, rep, timeout)
artifact = autotuner.run()
self._tuner_cache[key] = artifact.kernel
return self._tuner_cache[key]
return wrapper
def autotune( # This is the new public interface def autotune( # This is the new public interface
func: Callable[_P, _RProg] | PrimFunc | None = None, func: Callable[_P, _T] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only *, # Indicates subsequent arguments are keyword-only
configs: dict | Callable, configs: dict | Callable,
# profile arguments # profile arguments
...@@ -795,22 +726,26 @@ def autotune( # This is the new public interface ...@@ -795,22 +726,26 @@ def autotune( # This is the new public interface
elif isinstance(func, PrimFunc): elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
else: else:
# Case 2: Used as @autotune(...) to configure, or func_or_out_idx is meant as out_idx.
# Create a _AutoTunerImplementation instance with the provided/defaulted arguments. def decorator(impl):
# This instance is a decorator that will be applied to the function later. assert isinstance(
configured_decorator = _AutoTunerImplementation( impl, JITImpl
configs=configs, ), "The @autotune decorator can only be applied to @tilelang.jit decorated instances."
warmup=warmup, return AutoTuneImpl(
rep=rep, jit_impl=impl,
timeout=timeout, configs=configs,
supply_type=supply_type, warmup=warmup,
ref_prog=ref_prog, rep=rep,
supply_prog=supply_prog, timeout=timeout,
rtol=rtol, supply_type=supply_type,
atol=atol, ref_prog=ref_prog,
max_mismatched_ratio=max_mismatched_ratio, supply_prog=supply_prog,
skip_check=skip_check, rtol=rtol,
manual_check_prog=manual_check_prog, atol=atol,
cache_input_tensors=cache_input_tensors, max_mismatched_ratio=max_mismatched_ratio,
) skip_check=skip_check,
return configured_decorator manual_check_prog=manual_check_prog,
cache_input_tensors=cache_input_tensors,
)
return decorator
...@@ -5,15 +5,21 @@ kernel adapter using TVM. ...@@ -5,15 +5,21 @@ kernel adapter using TVM.
""" """
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
import inspect
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Generic,
Iterable,
ParamSpec,
TypeVar,
overload, overload,
Literal, Literal,
) )
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.language.v2 import PrimFunc
from tilelang.jit.adapter.utils import is_metal_target from tilelang.jit.adapter.utils import is_metal_target
from tvm.tir import PrimFunc
from tvm.target import Target from tvm.target import Target
from tilelang.jit.kernel import JITKernel from tilelang.jit.kernel import JITKernel
...@@ -21,14 +27,20 @@ from tilelang.utils.target import determine_target ...@@ -21,14 +27,20 @@ from tilelang.utils.target import determine_target
from tilelang.cache import cached from tilelang.cache import cached
from os import path, makedirs from os import path, makedirs
from logging import getLogger from logging import getLogger
import functools from tilelang.jit.param import Kernel
from tilelang.jit.param import Kernel, _P, _RProg import concurrent.futures
from tqdm.auto import tqdm
logger = getLogger(__name__) logger = getLogger(__name__)
_P = ParamSpec('_P')
_KP = ParamSpec('_KP')
_T = TypeVar('_T')
def compile( def compile(
func: PrimFunc = None, func: PrimFunc[_KP, _T] = None,
out_idx: list[int] | int | None = None, out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: str | Target = "auto", target: str | Target = "auto",
...@@ -36,7 +48,7 @@ def compile( ...@@ -36,7 +48,7 @@ def compile(
verbose: bool = False, verbose: bool = False,
pass_configs: dict[str, Any] | None = None, pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | str | None = None, compile_flags: list[str] | str | None = None,
) -> JITKernel: ) -> JITKernel[_KP, _T]:
""" """
Compile the given TileLang PrimFunc with TVM and build a JITKernel. Compile the given TileLang PrimFunc with TVM and build a JITKernel.
Parameters Parameters
...@@ -79,159 +91,208 @@ def compile( ...@@ -79,159 +91,208 @@ def compile(
) )
class _JitImplementation: def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: str | Target = "auto",
target_host: str | Target | None = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | str | None = None,
num_workers: int = None,
ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
"""
Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
Parameters
----------
funcs : Iterable[tvm.tir.PrimFunc]
The TileLang TIR functions to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
Execution backend to use for kernel execution (default: "cython").
target : Union[str, Target], optional
Compilation target, either as a string or a TVM Target object (default: "auto").
target_host : Union[str, Target], optional
Target host for cross-compilation (default: None).
verbose : bool, optional
Whether to enable verbose output (default: False).
pass_configs : dict, optional
Additional keyword arguments to pass to the Compiler PassContext.
Refer to `tilelang.transform.PassConfigKey` for supported options.
"""
with concurrent.futures.ThreadPoolExecutor(num_workers, 'tl-par-comp') as executor:
futures = []
future_map = {}
for i, func in enumerate(funcs):
future = executor.submit(
compile,
func=func,
out_idx=out_idx,
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
future_map[future] = i
futures.append(future)
results = [... for _ in futures]
for future in tqdm(
concurrent.futures.as_completed(futures),
total=len(futures),
desc="Parallel Compiling",
):
idx = future_map[future]
if ignore_error:
try:
results[idx] = future.result()
except Exception as e:
logger.warning(f"Error compiling function at index {idx}: {e}")
results[idx] = None
else:
results[idx] = future.result()
return results
return results
@dataclass
class JITImpl(Generic[_P, _KP, _T]):
func: Callable[_P, _T] | PrimFunc[_KP, _T]
out_idx: list[int] | int | None out_idx: list[int] | int | None
execution_backend: Literal["dlpack", "ctypes", "cython"]
target: str | Target target: str | Target
target_host: str | Target target_host: str | Target
execution_backend: Literal["dlpack", "ctypes", "cython"]
verbose: bool verbose: bool
pass_configs: dict[str, Any] | None pass_configs: dict[str, Any] | None
debug_root_path: str | None debug_root_path: str | None
compile_flags: list[str] | str | None compile_flags: list[str] | str | None
func_source: str
signature: inspect.Signature
def __init__(self, def __post_init__(self):
out_idx: Any = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None):
"""
Initializes the JIT compiler decorator.
Parameters
----------
out_idx : Any, optional
Index(es) of the output tensors to return from the compiled kernel
(default: None, meaning all outputs are returned or determined by the kernel itself).
target : Union[str, Target], optional
Compilation target for TVM. Can be a string (e.g., "cuda", "llvm")
or a TVM Target object. If "auto", the target is determined automatically
(default: "auto").
target_host : Union[str, Target], optional
Target host for cross-compilation, similar to `target` (default: None).
execution_backend : Literal["dlpack", "ctypes", "cython"], optional
The backend used for kernel execution and argument passing.
"dlpack" is generally preferred for zero-copy tensor passing with compatible frameworks.
"ctypes" uses standard C types. "cython" uses Cython for potentially faster execution.
(default: "cython").
verbose : bool, optional
If True, enables verbose logging during compilation (default: False).
pass_configs : Optional[Dict[str, Any]], optional
A dictionary of configurations for TVM's pass context. These can fine-tune
the compilation process. Examples include "tir.disable_vectorize"
(default: None).
debug_root_path : Optional[str], optional
If provided, the compiled kernel's source code will be saved to a file
in this directory. This is useful for debugging the generated code.
If None, no debug information is saved (default: None).
If a relative path is given, it's made absolute relative to the project root
or current working directory.
compile_flags : Optional[Union[List[str], str]], optional
Additional compilation flags to pass to the compiler.
If None, no additional compilation flags are passed (default: None).
"""
self.out_idx = out_idx
self.execution_backend = execution_backend
self.target = target
self.target_host = target_host
self.verbose = verbose
self.pass_configs = pass_configs
self.compile_flags = compile_flags
# Corrected debug_root_path handling
self.debug_root_path = debug_root_path
if self.debug_root_path is not None and not path.isabs(self.debug_root_path): if self.debug_root_path is not None and not path.isabs(self.debug_root_path):
try: try:
base_path = path.dirname(path.dirname(path.dirname(__file__))) base_path = path.dirname(path.dirname(path.dirname(__file__)))
self.debug_root_path = path.join(base_path, self.debug_root_path) self.debug_root_path = path.join(base_path, self.debug_root_path)
except NameError: except NameError:
self.debug_root_path = path.abspath(self.debug_root_path) self.debug_root_path = path.abspath(self.debug_root_path)
self._kernel_cache: dict[tuple, Kernel] = {} self._kernel_cache: dict[tuple, Kernel] = {}
# This tells the type checker what the *wrapper* function will return. def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_KP, _T]:
# this is for linting, please do not remove it. program_result_source = self.func
@overload if isinstance(program_result_source, PrimFunc):
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, Kernel]]: program_result = program_result_source
... elif callable(program_result_source):
program_result = program_result_source(*args, **kwargs)
@overload else:
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]: raise ValueError(f"Invalid function type: {type(program_result_source)}")
... return program_result
# Actual implementation of __call__ def par_compile(self,
def __call__( configs: Iterable[dict[str, Any] | tuple[str, Any]],
self, num_workers: int = None,
func: Callable[_P, _RProg] # func is Union[Callable[_P, _RProg], PrimFunc] in original ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
) -> Callable[_P, Any]: configs = list(configs)
funcs = []
@functools.wraps(func) for cfg in tqdm(configs, desc='Elaborating'):
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: if isinstance(cfg, tuple):
# Separate out the tuning parameters from the user's kwargs funcs.append(self.get_tir(*cfg))
tune_params = kwargs.pop('__tune_params', {}) elif isinstance(cfg, dict):
# Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache funcs.append(self.get_tir(**cfg))
return_compile_arguments = kwargs.pop('__return_compile_arguments', False) else:
if return_compile_arguments: raise ValueError(f"Invalid config type: {type(cfg)}, expected tuple or dict.")
compile_args = { return par_compile(
'out_idx': self.out_idx, funcs,
'execution_backend': self.execution_backend, out_idx=self.out_idx,
'target': self.target, execution_backend=self.execution_backend,
'target_host': self.target_host, target=self.target,
'verbose': self.verbose, target_host=self.target_host,
'pass_configs': self.pass_configs, verbose=self.verbose,
'compile_flags': self.compile_flags, pass_configs=self.pass_configs,
} compile_flags=self.compile_flags,
return compile_args num_workers=num_workers,
ignore_error=ignore_error)
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items())) def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]:
tuned_key_kwargs_tuple = tuple(sorted(tune_params.items())) func = self.get_tir(*args, **kwargs)
key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple) kernel_result = compile(
func,
if key not in self._kernel_cache: out_idx=self.out_idx,
# Ensure 'func' (the original user function) is used correctly execution_backend=self.execution_backend,
program_result_source = func target=self.target,
if isinstance(program_result_source, PrimFunc): target_host=self.target_host,
program_result = program_result_source verbose=self.verbose,
elif callable(program_result_source): pass_configs=self.pass_configs,
program_result = program_result_source(*args, **kwargs, **tune_params) compile_flags=self.compile_flags,
else: )
raise ValueError(f"Invalid function type: {type(program_result_source)}")
if self.debug_root_path:
kernel_result = compile( if isinstance(self.func, PrimFunc):
program_result, func_name = self.func.attrs['global_symbol']
out_idx=self.out_idx, else:
execution_backend=self.execution_backend, func_name = getattr(self.func, '__name__', 'jit_kernel')
target=self.target, kernel_file = f'tilelang_jit_kernel_{func_name}.c'
target_host=self.target_host, program_file = f'tilelang_jit_program_{func_name}.py'
verbose=self.verbose, makedirs(self.debug_root_path, exist_ok=True)
pass_configs=self.pass_configs, with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
compile_flags=self.compile_flags, print(kernel_result.get_kernel_source(), file=f)
) with open(path.join(self.debug_root_path, program_file), 'w') as f:
print(func.script(), file=f)
if self.debug_root_path:
func_name = getattr(func, '__name__', 'jit_kernel') # Use func for name return kernel_result
kernel_file = f'tilelang_jit_kernel_{func_name}.c'
program_file = f'tilelang_jit_program_{func_name}.py' def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]:
makedirs(self.debug_root_path, exist_ok=True) # Separate out the tuning parameters from the user's kwargs
with open(path.join(self.debug_root_path, kernel_file), 'w') as f: tune_params = kwargs.pop('__tune_params', {})
print(kernel_result.get_kernel_source(), file=f) # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
with open(path.join(self.debug_root_path, program_file), 'w') as f: return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
print(program_result.script(), file=f) if return_compile_arguments:
compile_args = {
self._kernel_cache[key] = kernel_result 'out_idx': self.out_idx,
'execution_backend': self.execution_backend,
return self._kernel_cache[key] 'target': self.target,
'target_host': self.target_host,
return wrapper 'verbose': self.verbose,
'pass_configs': self.pass_configs,
'compile_flags': self.compile_flags,
}
return compile_args
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple)
if key not in self._kernel_cache:
self._kernel_cache[key] = self.compile(*args, **kwargs, **tune_params)
return self._kernel_cache[key]
@overload
def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T]:
...
@overload
def jit(
*, # Indicates subsequent arguments are keyword-only
out_idx: Any = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None
) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T]]:
...
def jit( # This is the new public interface def jit( # This is the new public interface
func: Callable[_P, _RProg] | PrimFunc | None = None, func: Callable[_P, _T] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only *, # Indicates subsequent arguments are keyword-only
out_idx: Any = None, out_idx: Any = None,
target: str | Target = "auto", target: str | Target = "auto",
...@@ -275,32 +336,26 @@ def jit( # This is the new public interface ...@@ -275,32 +336,26 @@ def jit( # This is the new public interface
if isinstance(compile_flags, str): if isinstance(compile_flags, str):
compile_flags = [compile_flags] compile_flags = [compile_flags]
if callable(func): def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]:
# Case 1: Used as @jit (func_or_out_idx is the function, others are defaults) if isinstance(func, PrimFunc):
# Create a default _JitImplementation instance and apply it to the function. orig_func = func.orig_func
default_decorator = _JitImplementation( else:
out_idx=out_idx, # Explicitly None for the default case orig_func = func
target=target, return JITImpl(
target_host=target_host, func,
out_idx=out_idx,
execution_backend=execution_backend, execution_backend=execution_backend,
verbose=verbose,
pass_configs=pass_configs,
debug_root_path=debug_root_path,
compile_flags=compile_flags)
return default_decorator(func)
elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
else:
# Case 2: Used as @jit(...) to configure, or func_or_out_idx is meant as out_idx.
# Create a _JitImplementation instance with the provided/defaulted arguments.
# This instance is a decorator that will be applied to the function later.
configured_decorator = _JitImplementation(
out_idx=out_idx, # Pass along; could be an actual out_idx or None
target=target, target=target,
target_host=target_host, target_host=target_host,
execution_backend=execution_backend,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
debug_root_path=debug_root_path, debug_root_path=debug_root_path,
compile_flags=compile_flags) compile_flags=compile_flags,
return configured_decorator func_source=inspect.getsource(orig_func),
signature=inspect.signature(orig_func),
)
if func is not None:
return decorator(func)
else:
return decorator
...@@ -27,7 +27,11 @@ class MetalKernelAdapter(BaseKernelAdapter): ...@@ -27,7 +27,11 @@ class MetalKernelAdapter(BaseKernelAdapter):
# compile_flags: Optional[List[str]] = None # compile_flags: Optional[List[str]] = None
): ):
self.kernel_global_source = kernel_global_source self.kernel_global_source = kernel_global_source
self.kernel_name = func_or_mod.__name__ + '_kernel' if isinstance(func_or_mod, tir.PrimFunc):
func_name = func_or_mod.attrs['global_symbol']
else:
func_name = func_or_mod.__name__
self.kernel_name = func_name + '_kernel'
self.verbose = verbose self.verbose = verbose
self.block_info = [1, 1, 1] self.block_info = [1, 1, 1]
...@@ -43,7 +47,7 @@ class MetalKernelAdapter(BaseKernelAdapter): ...@@ -43,7 +47,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
self.grid_info["xyz".index(tag[-1])] = extent self.grid_info["xyz".index(tag[-1])] = extent
break break
else: else:
raise AssertionError(f'no kernel with name {func_or_mod.__name__}') raise AssertionError(f'no kernel with name {func_name}')
# print(self.block_info, self.grid_info) # print(self.block_info, self.grid_info)
super().__init__(func_or_mod, result_idx=result_idx, params=params) super().__init__(func_or_mod, result_idx=result_idx, params=params)
......
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Literal from typing import Any, Callable, Generic, Literal, ParamSpec, TypeVar
from tilelang.jit.adapter.utils import is_metal_target from tilelang.jit.adapter.utils import is_metal_target
from tvm.target import Target from tvm.target import Target
...@@ -17,8 +17,11 @@ import logging ...@@ -17,8 +17,11 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_P = ParamSpec('_P')
_T = TypeVar('_T')
class JITKernel:
class JITKernel(Generic[_P, _T]):
""" """
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions. A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
...@@ -170,7 +173,7 @@ class JITKernel: ...@@ -170,7 +173,7 @@ class JITKernel:
instance.torch_function = instance.adapter.func instance.torch_function = instance.adapter.func
return instance return instance
def __call__(self, *args: Any, **kwds: Any) -> Any: def __call__(self, *args: _P.args, **kwds: _P.kwargs) -> _T:
""" """
Invokes the compiled function with the given arguments. Invokes the compiled function with the given arguments.
......
...@@ -8,9 +8,9 @@ from __future__ import annotations ...@@ -8,9 +8,9 @@ from __future__ import annotations
# upstream tir script is fully compatible # upstream tir script is fully compatible
from tvm.script.parser.tir import * from tvm.script.parser.tir import *
from . import overrides as _overrides # noqa: F401 from . import overrides as _overrides # noqa: F401
from .tir import (
prim_func, # noqa: F401 # from .tir import prim_func, macro, # noqa: F401
) from .v2 import * # noqa: F401
from .tir.ir import * # noqa: F401 from .tir.ir import * # noqa: F401
from tilelang.layout import Layout, Fragment # noqa: F401 from tilelang.layout import Layout, Fragment # noqa: F401
from .proxy import ( from .proxy import (
......
...@@ -7,7 +7,6 @@ from tilelang.utils import deprecated ...@@ -7,7 +7,6 @@ from tilelang.utils import deprecated
__all__ = ["dynamic", "symbolic"] __all__ = ["dynamic", "symbolic"]
@deprecated("T.dynamic(...)", "tir.Var(...)", "v0.1.9")
def dynamic(name: str, dtype: str = "int32"): def dynamic(name: str, dtype: str = "int32"):
""" """
Create a TIR dynamic symbolic variable. Create a TIR dynamic symbolic variable.
...@@ -22,7 +21,7 @@ def dynamic(name: str, dtype: str = "int32"): ...@@ -22,7 +21,7 @@ def dynamic(name: str, dtype: str = "int32"):
return tir.Var(name, dtype) return tir.Var(name, dtype)
@deprecated("T.symbolic(...)", "T.dynamic(...)") @deprecated("T.symbolic(...)", "T.dynamic(...)", "v0.1.9")
def symbolic(name: str, dtype: str = "int32"): def symbolic(name: str, dtype: str = "int32"):
"""Deprecated alias for `T.dynamic`.""" """Deprecated alias for `T.dynamic`."""
return tir.Var(name, dtype) return tir.Var(name, dtype)
from .builder import prim_func, macro, PrimFunc # noqa: F401
from .dtypes import *
from __future__ import annotations

import ast
import inspect
import operator
from dataclasses import dataclass
from typing import Callable, ContextManager, Generic, Iterable, Any, Literal, ParamSpec, TypeVar

# from .utils import get_ast, get_compiled_object
from . import utils
_span_attrs = ['lineno', 'col_offset', 'end_lineno', 'end_col_offset']
def ast_has_span(ast: ast.AST) -> bool:
    """Return True iff *ast* carries a complete source span (all four
    location attributes are present)."""
    span_attrs = ('lineno', 'col_offset', 'end_lineno', 'end_col_offset')
    return all(hasattr(ast, attr) for attr in span_attrs)
def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int]:
    """Return the node's (lineno, col_offset, end_lineno, end_col_offset)
    span tuple, or None when the node carries no location info."""
    span_attrs = ('lineno', 'col_offset', 'end_lineno', 'end_col_offset')
    if not all(hasattr(ast, attr) for attr in span_attrs):
        return None
    return tuple(getattr(ast, attr) for attr in span_attrs)
def ast_set_span(ast: ast.AST, span: tuple[int, int, int, int]):
    """Overwrite the node's span with *span* — but only when the node
    already has all four location attributes (fresh, unpositioned nodes
    are deliberately left untouched)."""
    span_attrs = ('lineno', 'col_offset', 'end_lineno', 'end_col_offset')
    if not all(hasattr(ast, attr) for attr in span_attrs):
        return
    for attr, value in zip(span_attrs, span):
        setattr(ast, attr, value)
class QuoteVisitor(ast.NodeTransformer):
    """Template substitution visitor used by `quote`.

    Replaces placeholder Name nodes with caller-supplied AST fragments,
    splices statement lists in place of `pass` placeholders, and (when a
    span is given) re-stamps every visited node's source location.
    """

    def __init__(self, names: dict[str, ast.AST], passes: list[Any] | None = None, span=None):
        self.names = names
        self.passes = passes or []
        self.span = span

    def generic_visit(self, node: ast.AST):
        # Re-stamp template nodes so errors point at the original source.
        if self.span is not None:
            ast_set_span(node, self.span)
        return super().generic_visit(node)

    def visit_Name(self, node: ast.Name) -> Any:
        # Placeholder names are swapped for the supplied fragment.
        return self.names.get(node.id, node)

    def visit_Pass(self, node: ast.Pass) -> Any:
        # Consume the next statement-list placeholder; a falsy entry keeps
        # the `pass` so the suite stays non-empty.
        replacement = self.passes.pop(0)
        return replacement if replacement else node
def quote(expr: str, *, passes: list[Any] | None = None, span=None, **kws) -> list[ast.AST]:
    """Parse *expr* as a statement template and substitute placeholders.

    Keyword arguments name placeholder identifiers to replace; *passes*
    supplies statement lists spliced over `pass` placeholders; *span*
    (a tuple, or an AST node to copy from) re-stamps source locations.
    Returns the resulting list of statements.
    """
    template = ast.parse(expr)
    if isinstance(span, ast.AST):
        span = ast_get_span(span)
    template = QuoteVisitor(kws, passes, span).visit(template)
    return template.body
def quote1(expr: str, *, passes: list[Any] | None = None, span=None, **kws) -> ast.AST:
    """Like `quote`, but the template must yield exactly one statement."""
    stmts = quote(expr, passes=passes, span=span, **kws)
    assert len(stmts) == 1
    return stmts[0]
def quote_expr(expr: str, **kws) -> ast.expr:
    """Like `quote1`, but the single statement must be an expression;
    returns the bare expression node."""
    stmt = quote1(expr, **kws)
    assert isinstance(stmt, ast.Expr)
    return stmt.value
# Class names of ast.operator subclasses (binary operators); used as string
# keys by eval_op / eval_aug_assign and the BaseBuilder hooks.
Operator = Literal['Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow', 'LShift', 'RShift',
                   'BitOr', 'BitXor', 'BitAnd', 'FloorDiv']
# Class names of ast.boolop subclasses (short-circuit boolean operators).
BoolOp = Literal['And', 'Or']
def get_operator_name(operator: ast.operator) -> Operator:
    """Name of the binary-operator node's class (e.g. ast.Add() -> 'Add')."""
    return type(operator).__name__
def get_boolop_name(boolop: ast.boolop) -> BoolOp:
    """Name of the boolean-operator node's class (e.g. ast.And() -> 'And')."""
    return type(boolop).__name__
_T = TypeVar('_T')
# Binary AST-operator name -> stdlib implementation. Table-driven dispatch
# replaces the original 13-branch if-chain with a single O(1) lookup.
_BIN_OP_FUNCS: dict[str, Callable[[Any, Any], Any]] = {
    'Add': operator.add,
    'Sub': operator.sub,
    'Mult': operator.mul,
    'MatMult': operator.matmul,
    'Div': operator.truediv,
    'Mod': operator.mod,
    'Pow': operator.pow,
    'LShift': operator.lshift,
    'RShift': operator.rshift,
    'BitOr': operator.or_,
    'BitXor': operator.xor,
    'BitAnd': operator.and_,
    'FloorDiv': operator.floordiv,
}


def eval_op(op: Operator, left: Any, right: Any) -> Any:
    """Apply the binary operator named *op* (an ast.operator class name,
    e.g. 'Add') to *left* and *right* and return the result.

    Raises
    ------
    ValueError
        If *op* is not a recognised binary operator name.
    """
    fn = _BIN_OP_FUNCS.get(op)
    if fn is None:
        raise ValueError(f'Unknown operator: {op}')
    return fn(left, right)
# Augmented-assignment counterparts (operator.iadd & co.) so mutable values
# keep their in-place semantics, exactly like `left[sl] <op>= right`.
_AUG_OP_FUNCS: dict[str, Callable[[Any, Any], Any]] = {
    'Add': operator.iadd,
    'Sub': operator.isub,
    'Mult': operator.imul,
    'MatMult': operator.imatmul,
    'Div': operator.itruediv,
    'Mod': operator.imod,
    'Pow': operator.ipow,
    'LShift': operator.ilshift,
    'RShift': operator.irshift,
    'BitOr': operator.ior,
    'BitXor': operator.ixor,
    'BitAnd': operator.iand,
    'FloorDiv': operator.ifloordiv,
}


def eval_aug_assign(op: Operator, left: Any, sl: slice, right: Any) -> Any:
    """Perform ``left[sl] <op>= right`` in place and return *left*.

    *op* is an ast.operator class name (e.g. 'Add'). Python desugars
    ``x[i] += v`` to ``x[i] = type(x[i]).__iadd__(x[i], v)`` (falling back
    to ``__add__``); ``operator.iadd`` and friends reproduce exactly that,
    so mutable elements are updated in place.

    Raises
    ------
    ValueError
        If *op* is not a recognised binary operator name.
    """
    fn = _AUG_OP_FUNCS.get(op)
    if fn is None:
        raise ValueError(f'Unknown operator: {op}')
    left[sl] = fn(left[sl], right)
    return left
class _empty:
    """Sentinel marking an absent value or annotation (distinct from None,
    which is a legitimate user value)."""
    pass
class BaseBuilder:
    """Pass-through implementation of the hook protocol that DSLMutator-
    generated code drives via the ``__tb`` object.

    Every hook here simply mirrors native Python semantics; subclasses
    override individual hooks to record or reinterpret the statements of
    the user function instead of executing them directly.
    """

    # Sentinel for "no value / no annotation supplied".
    empty = _empty

    def get_parent_locals(self):
        """Locals of the frame that called the method invoking this one."""
        frame = inspect.currentframe()
        return frame.f_back.f_back.f_locals

    def ctx_if(self, cond) -> Iterable[_T]:
        """Open a lowered `if`; yields the branch token consumed by
        ctx_then / ctx_else."""
        yield cond

    def ctx_then(self, val: _T) -> Iterable[None]:
        """Run the `then` suite exactly when the branch token is truthy."""
        if not val:
            return
        yield

    def ctx_else(self, val: _T) -> Iterable[None]:
        """Run the `else` suite exactly when the branch token is falsy."""
        if val:
            return
        yield

    def eval(self, val: Any):  # noqa: B027
        """Observe a bare expression statement (no-op by default)."""

    def ctx_for(self, range: Iterable[Any]) -> Iterable[Any]:
        """Return the iterable that drives a lowered `for` loop."""
        return range

    def ctx_continue(self) -> bool:
        """Whether a native `continue` should actually execute."""
        return True

    def ctx_break(self) -> bool:
        """Whether a native `break` should actually execute."""
        return True

    def ctx_while(self, cond: Callable[[], Any]) -> Iterable[None]:
        """Drive a lowered `while` loop; *cond* is a thunk re-evaluated
        before each iteration."""
        while cond():
            yield

    def bind(self, name: str, value: Any, annot: Any = empty) -> Any:
        """Bind *name* to *value*; *annot* is the declared annotation."""
        return value

    def unwrap_value(self, value):
        """Hook applied to the RHS before tuple unpacking."""
        return value

    def assign_slice(self, lval: Any, sl: slice, value: Any, annot: Any = empty):
        """Perform ``lval[sl] = value``."""
        lval[sl] = value

    def aug_assign(self, op: Operator, target: Any, aug_value: Any) -> Any:
        """Compute ``target <op> aug_value`` for ``name <op>= value``."""
        return eval_op(op, target, aug_value)

    def aug_assign_slice(self, op: Operator, target: Any, sl: slice, aug_value: Any):
        """Perform ``target[sl] <op>= aug_value`` in place."""
        eval_aug_assign(op, target, sl, aug_value)

    def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any]) -> Any:
        """Short-circuit `and`/`or`; *right* is a lazily evaluated thunk."""
        if op == 'Or':
            return left or right()
        if op == 'And':
            return left and right()
        raise ValueError(f'Unknown boolop: {op}')

    def ifexp(self, cond: Any, then: Callable[[], Any], otherwise: Callable[[], Any]) -> Any:
        """Evaluate a conditional expression with lazy branches."""
        if cond:
            return then()
        return otherwise()

    def ret(self, value: Any) -> Any:
        """Hook applied to a `return` statement's value."""
        return value

    def ctx_with(self, ctx: ContextManager[Any]) -> ContextManager[Any]:
        """Hook applied to each `with` context expression."""
        return ctx

    def assert_expr(self, cond: Any, msg: Any):
        """Execute a lowered `assert` statement."""
        assert cond, msg

    def rval(self, name: str, value: Any):
        """Hook applied to every name read (Load context)."""
        return value

    def arg(self, name: str, value: Any):
        """Hook applied to each function parameter at entry."""
        return value

    def override(self, name: str):
        """Resolve a builder-overridable global of this module (e.g. the
        `range` replacement installed by the function wrapper)."""
        return globals()[name]
class DSLMutator(ast.NodeTransformer):
    """AST transformer that reroutes Python syntax through a builder object.

    Every ``visit_*`` method replaces a native construct (control flow,
    assignment, name reads, ...) with calls on a builder named ``__tb``
    (see BaseBuilder for the hook protocol), so a builder implementation
    can record or reinterpret the user function instead of running it
    natively.
    """

    def __init__(self):
        # Monotonic counter backing get_tmp(); guarantees unique temp names.
        self.tmp_counter = 0

    def get_tmp(self) -> str:
        """Return a fresh, collision-free temporary identifier (``__N``)."""
        name = f"__{self.tmp_counter}"
        self.tmp_counter += 1
        return name

    def visit_If(self, node: ast.If):
        """Lower `if`/`else` onto ``ctx_if`` / ``ctx_then`` / ``ctx_else``."""
        node = self.generic_visit(node)
        br = self.get_tmp()
        if len(node.orelse) == 0:
            return quote(
                f"for {br} in __tb.ctx_if(cond):\n"
                f" for _ in __tb.ctx_then({br}):\n"
                "  pass\n",
                cond=node.test,
                passes=[node.body],
                span=node,
            )
        return quote(
            f"for {br} in __tb.ctx_if(cond):\n"
            f" for _ in __tb.ctx_then({br}):\n"
            f"  pass\n"
            f" for _ in __tb.ctx_else({br}):\n"
            f"  pass\n",
            cond=node.test,
            passes=[node.body, node.orelse],
            span=node,
        )

    def visit_Expr(self, node: ast.Expr):
        """Feed bare expression statements to ``__tb.eval``."""
        node = self.generic_visit(node)
        return quote("__tb.eval(value)", value=node.value, span=node)

    def _parse_names(self, target: ast.expr):
        # Currently unused (see the commented call in visit_For); renders a
        # possibly-nested assignment target as a literal of name strings.
        if isinstance(target, ast.Name):
            return f"'{target.id}'"
        elif isinstance(target, ast.Tuple):
            return ("(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)")
        else:
            s = ast.unparse(target)
            raise NotImplementedError(f"Unsupported for target `{s}`")

    def visit_For(self, node: ast.For):
        """Lower `for` loops onto ``ctx_for``, rebinding the loop target
        through the regular assignment machinery each iteration."""
        node = self.generic_visit(node)
        tmp = self.get_tmp()
        # names = self._parse_names(node.target)
        var = ast.Name(tmp, ctx=ast.Load())
        ast_set_span(var, ast_get_span(node.target))
        stmts = self._emit_assign_target(node.target, var)
        return quote(
            f"for {tmp} in __tb.ctx_for(range):\n"
            " pass\n",
            target=node.target,
            range=node.iter,
            passes=[stmts + node.body],
            span=node,
        )

    def visit_Continue(self, node: ast.Continue):
        """Gate native `continue` behind ``__tb.ctx_continue``."""
        node = self.generic_visit(node)
        return quote("if __tb.ctx_continue(): continue", span=node)

    def visit_Break(self, node: ast.Break):
        """Gate native `break` behind ``__tb.ctx_break``."""
        node = self.generic_visit(node)
        return quote("if __tb.ctx_break(): break", span=node)

    def _emit_assign_target(self,
                            target: ast.expr,
                            rval: ast.expr,
                            annot: ast.expr = None) -> list[ast.AST]:
        """Emit statements binding *rval* to *target* through __tb hooks.

        Name targets become ``__tb.bind``; subscripts become
        ``__tb.assign_slice``; tuple targets are unpacked into fresh temps
        first, then bound element-wise in source order.
        """
        if isinstance(target, ast.Name):
            if annot is None:
                return quote(
                    f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target)
            else:
                return quote(
                    f'name = __tb.bind("{target.id}", value, annot)',
                    name=target,
                    value=rval,
                    annot=annot,
                    span=target)
        elif isinstance(target, ast.Attribute):
            s = ast.unparse(target)
            raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`')
        elif isinstance(target, ast.Subscript):
            if annot is None:
                return quote(
                    "__tb.assign_slice(lval, slice, value)",
                    lval=target.value,
                    slice=target.slice,
                    value=rval,
                    span=target,
                )
            else:
                return quote(
                    "__tb.assign_slice(lval, slice, value, annot)",
                    lval=target.value,
                    slice=target.slice,
                    value=rval,
                    annot=annot,
                    span=target,
                )
        else:
            unpacked = []

            def _visit_target(target: ast.expr) -> str:
                # Replace each bindable leaf with a fresh temp Name, noting
                # the (temp, original-target) pair for the bind pass below.
                if isinstance(target, (ast.Name, ast.Subscript)):
                    tmp = self.get_tmp()
                    unpacked.append((tmp, target))
                    res = ast.Name(id=tmp, ctx=target.ctx)
                    ast_set_span(res, ast_get_span(target))
                    return res
                elif isinstance(target, ast.Tuple):
                    elts = [_visit_target(elt) for elt in target.elts]
                    res = ast.Tuple(elts=elts, ctx=target.ctx)
                    ast_set_span(res, ast_get_span(target))
                    return res
                else:
                    # Fix: previously fell through and returned None, which
                    # silently produced an invalid AST (e.g. for starred
                    # targets). Fail loudly instead.
                    s = ast.unparse(target)
                    raise NotImplementedError(f'Unsupported target: {s}')

            unpack_stmt = ast.Assign(
                targets=[_visit_target(target)],
                value=quote_expr('__tb.unwrap_value(rval)', rval=rval, span=rval))
            ast_set_span(unpack_stmt, ast_get_span(target))
            stmts = [unpack_stmt]
            bind_lvals = []
            bind_rvals = []

            def flush_binds():
                # Emit the accumulated Name binds as one tuple assignment so
                # their order relative to subscript assignments is preserved.
                if bind_lvals:
                    stmts.append(
                        quote1(f'{", ".join(bind_lvals)}, = {", ".join(bind_rvals)},', span=target))
                bind_lvals.clear()
                bind_rvals.clear()

            for tmp, target in unpacked:
                if isinstance(target, ast.Name):
                    bind_lvals.append(target.id)
                    bind_rvals.append(f'__tb.bind("{target.id}", {tmp})')
                elif isinstance(target, ast.Subscript):
                    flush_binds()
                    stmts.append(
                        quote1(
                            f'__tb.assign_slice(lval, slice, {tmp})',
                            lval=target.value,
                            slice=target.slice,
                            span=target))
                else:
                    s = ast.unparse(target)
                    raise NotImplementedError(f'Unsupported target: {s}')
            flush_binds()
            return stmts

    def visit_Assign(self, node: ast.Assign) -> list[ast.AST]:
        """Lower (possibly chained) assignment statements."""
        node = self.generic_visit(node)
        rval = node.value
        if len(node.targets) == 1:
            return self._emit_assign_target(node.targets[0], rval)
        else:
            # `a = b = rhs`: evaluate rhs once into a temp, then bind each
            # target from the temp.
            tmp_name = self.get_tmp()
            tmp_store = ast.Name(tmp_name, ctx=ast.Store())
            tmp_load = ast.Name(tmp_name, ctx=ast.Load())
            # Fix: ast_set_span expects a span tuple; the original passed
            # the AST node itself (a silent no-op, as fresh Names carry no
            # location attributes either way).
            ast_set_span(tmp_store, ast_get_span(node.targets[0]))
            ast_set_span(tmp_load, ast_get_span(node.targets[0]))
            stmt = self._emit_assign_target(tmp_store, rval)
            for target in node.targets:
                stmt.extend(self._emit_assign_target(target, tmp_load))
            return stmt

    def visit_AugAssign(self, node: ast.AugAssign) -> list[ast.AST]:
        """Lower `<target> <op>= <value>` onto the aug-assign hooks."""
        node = self.generic_visit(node)
        target, rval = node.target, node.value
        op = get_operator_name(node.op)
        if isinstance(target, ast.Name):
            return quote(
                f"name = __tb.aug_assign('{op}', {target.id}, value)",
                name=target,
                value=rval,
                span=node)
        elif isinstance(target, ast.Subscript):
            return quote(
                f"__tb.aug_assign_slice('{op}', lval, slice, value)",
                lval=target.value,
                slice=target.slice,
                value=rval,
                span=node,
            )
        else:
            return node

    def visit_AnnAssign(self, node: ast.AnnAssign):
        """Lower `x: ann = v` (or a bare `x: ann`) to an annotated bind."""
        node = self.generic_visit(node)
        # A bare annotation (no value) binds the __tb.empty sentinel.
        rval = node.value or quote_expr('__tb.empty', span=node)
        return self._emit_assign_target(node.target, rval, annot=node.annotation)

    def visit_While(self, node):
        """Lower `while` onto ``ctx_while`` with a lazily evaluated test."""
        # Fix: recurse first so the loop test/body are themselves mutated;
        # this was the only visitor missing the generic_visit call, so
        # statements inside `while` bodies escaped the transformation.
        node = self.generic_visit(node)
        return quote1(
            "for _ in __tb.ctx_while(lambda: cond):\n pass",
            cond=node.test,
            passes=[node.body],
            span=node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Rebind every parameter through ``__tb.arg`` and re-emit the
        function as a closure taking the builder (``__tb``) as its only
        argument."""
        node = self.generic_visit(node)
        all_args = node.args.posonlyargs + node.args.args
        if node.args.vararg is not None:
            # Fix: `vararg` is a single ast.arg, not a list — the original
            # `all_args += node.args.vararg` raised TypeError for any
            # function taking *args.
            all_args.append(node.args.vararg)
        all_args += node.args.kwonlyargs
        # NOTE(review): `node.args.kwarg` (**kwargs) is not wrapped here —
        # confirm whether builders are expected to observe it.
        stmts = []
        for arg in all_args:
            name = arg.arg
            stmts.append(quote1(f'{name} = __tb.arg("{name}", {name})', span=arg))
            # Annotations are stripped from the generated code; the original
            # signature is recovered elsewhere via inspect on the source
            # function. (The original if/else emitted the identical
            # statement for annotated and unannotated parameters, so the
            # branch was collapsed.)
            arg.annotation = None
        node.body = stmts + node.body
        node.decorator_list.clear()
        return quote1(
            f"def {node.name}(__tb):\n"
            " range = __tb.override('range')\n"
            " pass\n"
            f" return {node.name}",
            passes=[node],
        )

    def visit_BoolOp(self, node: ast.BoolOp):
        """Lower `and`/`or` chains, keeping short-circuit laziness via
        thunks passed to ``__tb.boolop``."""
        node = self.generic_visit(node)
        op_name = get_boolop_name(node.op)
        last = node.values[-1]
        for i in reversed(range(len(node.values) - 1)):
            last = quote_expr(
                expr=f"__tb.boolop('{op_name}', left, lambda: right)",
                left=node.values[i],
                right=last,
                span=node,
            )
        return last

    def visit_Compare(self, node: ast.Compare) -> ast.expr:
        """Split chained comparisons (`a < b < c`) into pairwise compares
        joined with lazy 'And' boolops."""
        node = self.generic_visit(node)
        left = node.left
        split = []
        for op, comp in zip(node.ops, node.comparators):
            cmp = ast.Compare(left=left, ops=[op], comparators=[comp])
            ast_set_span(cmp, ast_get_span(node))
            split.append(cmp)
            left = comp
        last = split[-1]
        for i in reversed(range(len(split) - 1)):
            last = quote_expr(
                "__tb.boolop('And', left, lambda: right)", left=split[i], right=last, span=node)
        return last

    def visit_IfExp(self, node: ast.IfExp) -> ast.Expr:
        """Lower conditional expressions with lazily evaluated branches."""
        node = self.generic_visit(node)
        return quote_expr(
            '__tb.ifexp(cond, lambda: then, lambda: otherwise)',
            cond=node.test,
            then=node.body,
            otherwise=node.orelse,
            span=node)

    def visit_Return(self, node: ast.Return):
        """Route `return` values through ``__tb.ret``."""
        node = self.generic_visit(node)
        return quote("return __tb.ret(value)", value=node.value, span=node)

    def visit_With(self, node: ast.With):
        """Wrap each `with` context expression in ``__tb.ctx_with``."""
        node = self.generic_visit(node)
        for expr in node.items:
            expr.context_expr = quote_expr("__tb.ctx_with(e)", e=expr.context_expr, span=expr)
        return node

    def visit_Assert(self, node: ast.Assert):
        """Route `assert` statements through ``__tb.assert_expr``."""
        node = self.generic_visit(node)
        return quote("__tb.assert_expr(cond, msg)", cond=node.test, msg=node.msg, span=node)

    def visit_Name(self, node: ast.Name):
        """Route every name read through ``__tb.rval``; writes untouched."""
        if isinstance(node.ctx, ast.Load):
            return quote_expr(f"__tb.rval('{node.id}', {node.id})", span=node)
        return node
_P = ParamSpec('_P')
@dataclass
class IRGenerator(Generic[_P, _T]):
    """Result of `mutate`: the compiled builder-driven entry point plus the
    transformed source for debugging."""
    # Given a builder, returns the rewritten user function (same signature
    # as the original, with all syntax routed through the builder).
    gen: Callable[[BaseBuilder], Callable[_P, _T]]
    # Unparsed source of the mutated AST (debugging / inspection aid).
    source: str
def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]:
    """Rewrite *func* through the DSL AST transformer.

    The function's source is parsed, every construct is rerouted through a
    builder object by `DSLMutator`, and the result is recompiled against
    the original function's captured context (closure cells, globals).

    Args:
        func: The Python function to transform; its signature is preserved
            via the generic parameters `_P` / `_T`.

    Returns:
        An `IRGenerator` whose ``gen`` is the compiled builder-entry
        callable and whose ``source`` is the unparsed mutated AST — useful
        for inspecting what the transformation produced.
    """
    original_tree = utils.get_ast(func)
    origin_file = inspect.getsourcefile(func) or inspect.getfile(func)
    mutated_tree = DSLMutator().visit(original_tree)
    entry = utils.get_compiled_object(mutated_tree, func.__name__, origin_file,
                                      utils.inspect_function_capture(func))
    return IRGenerator(gen=entry, source=ast.unparse(mutated_tree))
This diff is collapsed.
from tilelang import tvm
from tvm import ir
import tvm_ffi
import torch
import ctypes
from typing import TYPE_CHECKING
from tvm import tir
import tvm.script.ir_builder.tir._ffi_api as tb_ffi
# Local alias for tvm.DataType (a str subclass — see the str.__new__ /
# str.__eq__ calls below); the monkeypatching later in this module teaches it
# to interoperate with python/torch dtypes.
dtype = tvm.DataType
# Anything this module accepts where a dtype is expected.
AnyDType = ir.Type | str | type | torch.dtype | dtype
# Conversion table, one row per supported scalar type. Columns:
#   (python/torch type, tvm dtype string, ctypes type, C type string,
#    constructor name in `tvm.script.ir_builder.tir._ffi_api`)
# A `None` entry means that particular conversion is unavailable for the row.
_dtype_cvt = [
    # NOTE(review): void* represented via ctypes.c_long; c_long is 32-bit on
    # some platforms (e.g. 64-bit Windows) — confirm this is intended.
    (None, 'handle', ctypes.c_long, 'long', None),  # use long to repr void*
    (bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
    (int, 'int32', ctypes.c_int32, 'int', 'Int32'),
    (float, 'float32', ctypes.c_float, 'float', 'Float32'),
    (torch.short, 'int16', ctypes.c_int16, 'short', 'Int16'),
    (torch.int, 'int32', ctypes.c_int32, 'int', 'Int32'),
    (torch.long, 'int64', ctypes.c_int64, 'long long', 'Int64'),
    (torch.half, 'float16', None, None, 'Float16'),
    (torch.float, 'float32', ctypes.c_float, 'float', 'Float32'),
    (torch.double, 'float64', ctypes.c_double, 'double', 'Float64'),
    (torch.bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
    (torch.int8, 'int8', ctypes.c_int8, 'char', 'Int8'),
    (torch.int16, 'int16', ctypes.c_int16, 'short', 'Int16'),
    (torch.int32, 'int32', ctypes.c_int32, 'int', 'Int32'),
    (torch.int64, 'int64', ctypes.c_int64, 'long long', 'Int64'),
    (torch.uint8, 'uint8', ctypes.c_uint8, 'unsigned char', 'UInt8'),
    (torch.uint16, 'uint16', ctypes.c_uint16, 'unsigned short', 'UInt16'),
    (torch.uint32, 'uint32', ctypes.c_uint32, 'unsigned int', 'UInt32'),
    (torch.uint64, 'uint64', ctypes.c_uint64, 'unsigned long long', 'UInt64'),
    (torch.float16, 'float16', None, None, 'Float16'),
    (torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'),
    (torch.float64, 'float64', ctypes.c_double, 'double', 'Float64'),
    (None, 'float8_e4m3', None, None, 'Float8E4M3'),
    (torch.float8_e4m3fn, 'float8_e4m3fn', None, None, 'Float8E4M3FN'),
    (torch.float8_e4m3fnuz, 'float8_e4m3fnuz', None, None, 'Float8E4M3FNUZ'),
    (torch.float8_e5m2, 'float8_e5m2', None, None, 'Float8E5M2'),
    (torch.float8_e5m2fnuz, 'float8_e5m2fnuz', None, None, 'Float8E5M2FNUZ'),
    (torch.float8_e8m0fnu, 'float8_e8m0fnu', None, None, 'Float8E8M0FNU'),
    (torch.bfloat16, 'bfloat16', None, None, 'BFloat16'),
]
def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x):
    """Build a lookup dict from column *sidx* to column *didx* of `_dtype_cvt`.

    Rows where either column is ``None`` are skipped; *smapper*/*dmapper*
    post-process the key and value respectively (identity by default).
    """
    mapping = {}
    for row in _dtype_cvt:
        src, dst = row[sidx], row[didx]
        if src is not None and dst is not None:
            mapping[smapper(src)] = dmapper(dst)
    return mapping
# python/torch type -> tvm dtype string
_dtype_py2tvmstr = _create_type_mapper(0, 1)
# tvm dtype string  -> ffi constructor (e.g. 'float32' -> tb_ffi.Float32)
_dtype_tvmstr2fficall = _create_type_mapper(1, 4, dmapper=lambda x: getattr(tb_ffi, x))
# dtype instance    -> python/torch type
_dtype_tvm2py = _create_type_mapper(1, 0, lambda x: dtype(x))
# dtype instance    -> ctypes type
_dtype_tvm2ctype = _create_type_mapper(1, 2, lambda x: dtype(x))
# dtype instance    -> C type string (used for cffi, per the name)
_dtype_tvm2cffi = _create_type_mapper(1, 3, lambda x: dtype(x))
def __dtype_eq__(self: dtype, other: AnyDType):
    """Equality against plain strings and known python/torch dtypes.

    Falls back to ``NotImplemented`` for anything outside the alias table so
    Python can try the other operand's reflected comparison.
    """
    if isinstance(other, str):
        return str.__eq__(self, other)
    alias = _dtype_py2tvmstr.get(other)
    if alias is not None:
        return str.__eq__(self, alias)
    return NotImplemented
def __dtype_ne__(self: dtype, other: AnyDType):
    """Inequality counterpart of ``__dtype_eq__`` (same alias handling)."""
    if isinstance(other, str):
        return str.__ne__(self, other)
    alias = _dtype_py2tvmstr.get(other)
    if alias is not None:
        return str.__ne__(self, alias)
    return NotImplemented
def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var:
    """Make dtype instances callable, e.g. ``float32(x)``.

    Dispatches to the matching constructor in
    ``tvm.script.ir_builder.tir._ffi_api`` from the precomputed table; for
    dtypes not in the table the constructor name is derived from the dtype
    string (``uint8`` -> ``UInt8``, ``float8_e4m3`` -> ``Float8E4M3``, ...).

    Args:
        expr: Optional expression/value forwarded to the ffi constructor.
        is_size_var: Forwarded to the ffi constructor.

    Raises:
        TypeError: If the dtype string has an unknown prefix or no matching
            ffi constructor exists.
    """
    if self in _dtype_tvmstr2fficall:
        return _dtype_tvmstr2fficall[self](expr, is_size_var)
    # try to construct the ffi call
    # (bfloat* is checked after float*: 'bfloat16' does not match the
    # 'float' prefix, so the order below is safe)
    if self.startswith('uint'):
        val = 'UInt' + self[4:]
    elif self.startswith('int'):
        val = 'Int' + self[3:]
    elif self.startswith('float'):
        val = 'Float' + self[5:]
    elif self.startswith('bfloat'):
        val = 'BFloat' + self[6:]
    else:
        raise TypeError(f'Invalid type {self}')
    # Sub-format variants: 'Float8_e4m3' -> 'Float8' + 'E4M3' == 'Float8E4M3'.
    if '_' in val:
        first, second = val.split('_', maxsplit=1)
        val = first + second.upper()
    call = getattr(tb_ffi, val, None)
    if call is None:
        raise TypeError(f"Convert to datatype `{self}` is not supported by tvm\n"
                        f"calling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`")
    return call(expr, is_size_var)
def __dtype_new__(cls, value: AnyDType) -> dtype:
    """Construct a dtype from a string or a known python/torch type.

    Args:
        value: A tvm dtype string (used verbatim) or a python/torch type
            from the conversion table (mapped to its tvm dtype string).

    Returns:
        The new dtype instance with its ``__tvm_ffi_dtype__`` attribute
        pre-populated.

    Raises:
        TypeError: If ``value`` is neither a string nor a known type.
    """
    if isinstance(value, str):
        val = str.__new__(cls, value)
    elif value in _dtype_py2tvmstr:
        val = str.__new__(cls, _dtype_py2tvmstr[value])
    else:
        # List the accepted inputs: python/torch types plus tvm dtype strings.
        # (Previously this collected `_dtype_tvmstr2fficall.values()`, i.e.
        # the ffi callables, which made the error message unreadable.)
        expected = set(_dtype_py2tvmstr) | set(_dtype_tvmstr2fficall)
        raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}")
    val.__tvm_ffi_dtype__ = tvm_ffi.core.DataType(val)
    return val
# Monkeypatch tvm.DataType so instances compare against python/torch dtypes
# and strings, are constructible from them, and are callable as ffi
# constructors (see the __dtype_* helpers above).
# NOTE(review): `__req__`/`__rne__` are not Python protocol names — reflected
# equality goes through the other operand's `__eq__`; presumably tvm's
# expression operators consume these — confirm.
dtype.__eq__ = __dtype_eq__
dtype.__req__ = __dtype_eq__
dtype.__ne__ = __dtype_ne__
dtype.__rne__ = __dtype_ne__
dtype.__call__ = __dtype_call__
dtype.__new__ = __dtype_new__
def get_tvm_dtype(value: AnyDType) -> dtype:
    """Coerce *value* to a tvm dtype; dtype and ir.Type pass through as-is."""
    passthrough = isinstance(value, (dtype, ir.Type))
    return value if passthrough else dtype(value)
if TYPE_CHECKING:
    # Static stubs: give type checkers a distinct `dtype` subclass per name
    # so annotations like `x: T.float32` resolve. At runtime (else branch)
    # each of these names is instead a plain `dtype` instance.
    # yapf: disable
    class bool(dtype): ...
    class short(dtype): ...
    class int(dtype): ...
    class long(dtype): ...
    class half(dtype): ...
    class float(dtype): ...
    class double(dtype): ...
    class int8(dtype): ...
    class int16(dtype): ...
    class int32(dtype): ...
    class int64(dtype): ...
    class int8x4(dtype): ...
    class int16x4(dtype): ...
    class int32x4(dtype): ...
    class int64x4(dtype): ...
    class int8x8(dtype): ...
    class int16x8(dtype): ...
    class int32x8(dtype): ...
    class int64x8(dtype): ...
    class int8x16(dtype): ...
    class int16x16(dtype): ...
    class int32x16(dtype): ...
    class int64x16(dtype): ...
    class int8x32(dtype): ...
    class int16x32(dtype): ...
    class int32x32(dtype): ...
    class int64x32(dtype): ...
    class int8x64(dtype): ...
    class int16x64(dtype): ...
    class int32x64(dtype): ...
    class int64x64(dtype): ...
    class uint8(dtype): ...
    class uint16(dtype): ...
    class uint32(dtype): ...
    class uint64(dtype): ...
    class uint8x4(dtype): ...
    class uint16x4(dtype): ...
    class uint32x4(dtype): ...
    class uint64x4(dtype): ...
    class uint8x8(dtype): ...
    class uint16x8(dtype): ...
    class uint32x8(dtype): ...
    class uint64x8(dtype): ...
    class uint8x16(dtype): ...
    class uint16x16(dtype): ...
    class uint32x16(dtype): ...
    class uint64x16(dtype): ...
    class uint8x32(dtype): ...
    class uint16x32(dtype): ...
    class uint32x32(dtype): ...
    class uint64x32(dtype): ...
    class uint8x64(dtype): ...
    class uint16x64(dtype): ...
    class uint32x64(dtype): ...
    class uint64x64(dtype): ...
    class float16(dtype): ...
    class float32(dtype): ...
    class float64(dtype): ...
    class float16x2(dtype): ...
    class float32x2(dtype): ...
    class float64x2(dtype): ...
    class float16x4(dtype): ...
    class float32x4(dtype): ...
    class float64x4(dtype): ...
    class float16x8(dtype): ...
    class float32x8(dtype): ...
    class float64x8(dtype): ...
    class float16x16(dtype): ...
    class float32x16(dtype): ...
    class float64x16(dtype): ...
    class float16x32(dtype): ...
    class float32x32(dtype): ...
    class float64x32(dtype): ...
    class float16x64(dtype): ...
    class float32x64(dtype): ...
    class float64x64(dtype): ...
    class float8_e3m4(dtype): ...
    class float8_e3m4x2(dtype): ...
    class float8_e3m4x4(dtype): ...
    class float8_e3m4x8(dtype): ...
    class float8_e3m4x16(dtype): ...
    class float8_e3m4x32(dtype): ...
    class float8_e3m4x64(dtype): ...
    class float8_e4m3(dtype): ...
    class float8_e4m3x2(dtype): ...
    class float8_e4m3x4(dtype): ...
    class float8_e4m3x8(dtype): ...
    class float8_e4m3x16(dtype): ...
    class float8_e4m3x32(dtype): ...
    class float8_e4m3x64(dtype): ...
    class float8_e4m3b11fnuz(dtype): ...
    class float8_e4m3b11fnuzx2(dtype): ...
    class float8_e4m3b11fnuzx4(dtype): ...
    class float8_e4m3b11fnuzx8(dtype): ...
    class float8_e4m3b11fnuzx16(dtype): ...
    class float8_e4m3b11fnuzx32(dtype): ...
    class float8_e4m3b11fnuzx64(dtype): ...
    class float8_e4m3fn(dtype): ...
    class float8_e4m3fnx2(dtype): ...
    class float8_e4m3fnx4(dtype): ...
    class float8_e4m3fnx8(dtype): ...
    class float8_e4m3fnx16(dtype): ...
    class float8_e4m3fnx32(dtype): ...
    class float8_e4m3fnx64(dtype): ...
    class float8_e4m3fnuz(dtype): ...
    class float8_e4m3fnuzx2(dtype): ...
    class float8_e4m3fnuzx4(dtype): ...
    class float8_e4m3fnuzx8(dtype): ...
    class float8_e4m3fnuzx16(dtype): ...
    class float8_e4m3fnuzx32(dtype): ...
    class float8_e4m3fnuzx64(dtype): ...
    class float8_e5m2(dtype): ...
    class float8_e5m2x2(dtype): ...
    class float8_e5m2x4(dtype): ...
    class float8_e5m2x8(dtype): ...
    class float8_e5m2x16(dtype): ...
    class float8_e5m2x32(dtype): ...
    class float8_e5m2x64(dtype): ...
    class float8_e5m2fnuz(dtype): ...
    class float8_e5m2fnuzx2(dtype): ...
    class float8_e5m2fnuzx4(dtype): ...
    class float8_e5m2fnuzx8(dtype): ...
    class float8_e5m2fnuzx16(dtype): ...
    class float8_e5m2fnuzx32(dtype): ...
    class float8_e5m2fnuzx64(dtype): ...
    class float8_e8m0fnu(dtype): ...
    class float8_e8m0fnux2(dtype): ...
    class float8_e8m0fnux4(dtype): ...
    class float8_e8m0fnux8(dtype): ...
    class float8_e8m0fnux16(dtype): ...
    class float8_e8m0fnux32(dtype): ...
    class float8_e8m0fnux64(dtype): ...
    class float6_e2m3fn(dtype): ...
    class float6_e2m3fnx2(dtype): ...
    class float6_e2m3fnx4(dtype): ...
    class float6_e2m3fnx8(dtype): ...
    class float6_e2m3fnx16(dtype): ...
    class float6_e2m3fnx32(dtype): ...
    class float6_e2m3fnx64(dtype): ...
    class float6_e3m2fn(dtype): ...
    class float6_e3m2fnx2(dtype): ...
    class float6_e3m2fnx4(dtype): ...
    class float6_e3m2fnx8(dtype): ...
    class float6_e3m2fnx16(dtype): ...
    class float6_e3m2fnx32(dtype): ...
    class float6_e3m2fnx64(dtype): ...
    class float4_e2m1fn(dtype): ...
    class float4_e2m1fnx2(dtype): ...
    class float4_e2m1fnx4(dtype): ...
    class float4_e2m1fnx8(dtype): ...
    class float4_e2m1fnx16(dtype): ...
    class float4_e2m1fnx32(dtype): ...
    class float4_e2m1fnx64(dtype): ...
    class bfloat16(dtype): ...
    # yapf: enable
else:
    # Runtime values: one ready-made `dtype` instance per name. The C-style
    # aliases (short/int/long/half/float/double) map to fixed-width tvm
    # dtypes and intentionally shadow python builtins within this module.
    bool = dtype('bool')
    short = dtype('int16')
    int = dtype('int32')
    long = dtype('int64')
    half = dtype('float16')
    float = dtype('float32')
    double = dtype('float64')
    int8 = dtype('int8')
    int16 = dtype('int16')
    int32 = dtype('int32')
    int64 = dtype('int64')
    int8x4 = dtype('int8x4')
    int16x4 = dtype('int16x4')
    int32x4 = dtype('int32x4')
    int64x4 = dtype('int64x4')
    int8x8 = dtype('int8x8')
    int16x8 = dtype('int16x8')
    int32x8 = dtype('int32x8')
    int64x8 = dtype('int64x8')
    int8x16 = dtype('int8x16')
    int16x16 = dtype('int16x16')
    int32x16 = dtype('int32x16')
    int64x16 = dtype('int64x16')
    int8x32 = dtype('int8x32')
    int16x32 = dtype('int16x32')
    int32x32 = dtype('int32x32')
    int64x32 = dtype('int64x32')
    int8x64 = dtype('int8x64')
    int16x64 = dtype('int16x64')
    int32x64 = dtype('int32x64')
    int64x64 = dtype('int64x64')
    uint8 = dtype('uint8')
    uint16 = dtype('uint16')
    uint32 = dtype('uint32')
    uint64 = dtype('uint64')
    uint8x4 = dtype('uint8x4')
    uint16x4 = dtype('uint16x4')
    uint32x4 = dtype('uint32x4')
    uint64x4 = dtype('uint64x4')
    uint8x8 = dtype('uint8x8')
    uint16x8 = dtype('uint16x8')
    uint32x8 = dtype('uint32x8')
    uint64x8 = dtype('uint64x8')
    uint8x16 = dtype('uint8x16')
    uint16x16 = dtype('uint16x16')
    uint32x16 = dtype('uint32x16')
    uint64x16 = dtype('uint64x16')
    uint8x32 = dtype('uint8x32')
    uint16x32 = dtype('uint16x32')
    uint32x32 = dtype('uint32x32')
    uint64x32 = dtype('uint64x32')
    uint8x64 = dtype('uint8x64')
    uint16x64 = dtype('uint16x64')
    uint32x64 = dtype('uint32x64')
    uint64x64 = dtype('uint64x64')
    float16 = dtype('float16')
    float32 = dtype('float32')
    float64 = dtype('float64')
    float16x2 = dtype('float16x2')
    float32x2 = dtype('float32x2')
    float64x2 = dtype('float64x2')
    float16x4 = dtype('float16x4')
    float32x4 = dtype('float32x4')
    float64x4 = dtype('float64x4')
    float16x8 = dtype('float16x8')
    float32x8 = dtype('float32x8')
    float64x8 = dtype('float64x8')
    float16x16 = dtype('float16x16')
    float32x16 = dtype('float32x16')
    float64x16 = dtype('float64x16')
    float16x32 = dtype('float16x32')
    float32x32 = dtype('float32x32')
    float64x32 = dtype('float64x32')
    float16x64 = dtype('float16x64')
    float32x64 = dtype('float32x64')
    float64x64 = dtype('float64x64')
    float8_e3m4 = dtype('float8_e3m4')
    float8_e3m4x2 = dtype('float8_e3m4x2')
    float8_e3m4x4 = dtype('float8_e3m4x4')
    float8_e3m4x8 = dtype('float8_e3m4x8')
    float8_e3m4x16 = dtype('float8_e3m4x16')
    float8_e3m4x32 = dtype('float8_e3m4x32')
    float8_e3m4x64 = dtype('float8_e3m4x64')
    float8_e4m3 = dtype('float8_e4m3')
    float8_e4m3x2 = dtype('float8_e4m3x2')
    float8_e4m3x4 = dtype('float8_e4m3x4')
    float8_e4m3x8 = dtype('float8_e4m3x8')
    float8_e4m3x16 = dtype('float8_e4m3x16')
    float8_e4m3x32 = dtype('float8_e4m3x32')
    float8_e4m3x64 = dtype('float8_e4m3x64')
    float8_e4m3b11fnuz = dtype('float8_e4m3b11fnuz')
    float8_e4m3b11fnuzx2 = dtype('float8_e4m3b11fnuzx2')
    float8_e4m3b11fnuzx4 = dtype('float8_e4m3b11fnuzx4')
    float8_e4m3b11fnuzx8 = dtype('float8_e4m3b11fnuzx8')
    float8_e4m3b11fnuzx16 = dtype('float8_e4m3b11fnuzx16')
    float8_e4m3b11fnuzx32 = dtype('float8_e4m3b11fnuzx32')
    float8_e4m3b11fnuzx64 = dtype('float8_e4m3b11fnuzx64')
    float8_e4m3fn = dtype('float8_e4m3fn')
    float8_e4m3fnx2 = dtype('float8_e4m3fnx2')
    float8_e4m3fnx4 = dtype('float8_e4m3fnx4')
    float8_e4m3fnx8 = dtype('float8_e4m3fnx8')
    float8_e4m3fnx16 = dtype('float8_e4m3fnx16')
    float8_e4m3fnx32 = dtype('float8_e4m3fnx32')
    float8_e4m3fnx64 = dtype('float8_e4m3fnx64')
    float8_e4m3fnuz = dtype('float8_e4m3fnuz')
    float8_e4m3fnuzx2 = dtype('float8_e4m3fnuzx2')
    float8_e4m3fnuzx4 = dtype('float8_e4m3fnuzx4')
    float8_e4m3fnuzx8 = dtype('float8_e4m3fnuzx8')
    float8_e4m3fnuzx16 = dtype('float8_e4m3fnuzx16')
    float8_e4m3fnuzx32 = dtype('float8_e4m3fnuzx32')
    float8_e4m3fnuzx64 = dtype('float8_e4m3fnuzx64')
    float8_e5m2 = dtype('float8_e5m2')
    float8_e5m2x2 = dtype('float8_e5m2x2')
    float8_e5m2x4 = dtype('float8_e5m2x4')
    float8_e5m2x8 = dtype('float8_e5m2x8')
    float8_e5m2x16 = dtype('float8_e5m2x16')
    float8_e5m2x32 = dtype('float8_e5m2x32')
    float8_e5m2x64 = dtype('float8_e5m2x64')
    float8_e5m2fnuz = dtype('float8_e5m2fnuz')
    float8_e5m2fnuzx2 = dtype('float8_e5m2fnuzx2')
    float8_e5m2fnuzx4 = dtype('float8_e5m2fnuzx4')
    float8_e5m2fnuzx8 = dtype('float8_e5m2fnuzx8')
    float8_e5m2fnuzx16 = dtype('float8_e5m2fnuzx16')
    float8_e5m2fnuzx32 = dtype('float8_e5m2fnuzx32')
    float8_e5m2fnuzx64 = dtype('float8_e5m2fnuzx64')
    float8_e8m0fnu = dtype('float8_e8m0fnu')
    float8_e8m0fnux2 = dtype('float8_e8m0fnux2')
    float8_e8m0fnux4 = dtype('float8_e8m0fnux4')
    float8_e8m0fnux8 = dtype('float8_e8m0fnux8')
    float8_e8m0fnux16 = dtype('float8_e8m0fnux16')
    float8_e8m0fnux32 = dtype('float8_e8m0fnux32')
    float8_e8m0fnux64 = dtype('float8_e8m0fnux64')
    float6_e2m3fn = dtype('float6_e2m3fn')
    float6_e2m3fnx2 = dtype('float6_e2m3fnx2')
    float6_e2m3fnx4 = dtype('float6_e2m3fnx4')
    float6_e2m3fnx8 = dtype('float6_e2m3fnx8')
    float6_e2m3fnx16 = dtype('float6_e2m3fnx16')
    float6_e2m3fnx32 = dtype('float6_e2m3fnx32')
    float6_e2m3fnx64 = dtype('float6_e2m3fnx64')
    float6_e3m2fn = dtype('float6_e3m2fn')
    float6_e3m2fnx2 = dtype('float6_e3m2fnx2')
    float6_e3m2fnx4 = dtype('float6_e3m2fnx4')
    float6_e3m2fnx8 = dtype('float6_e3m2fnx8')
    float6_e3m2fnx16 = dtype('float6_e3m2fnx16')
    float6_e3m2fnx32 = dtype('float6_e3m2fnx32')
    float6_e3m2fnx64 = dtype('float6_e3m2fnx64')
    float4_e2m1fn = dtype('float4_e2m1fn')
    float4_e2m1fnx2 = dtype('float4_e2m1fnx2')
    float4_e2m1fnx4 = dtype('float4_e2m1fnx4')
    float4_e2m1fnx8 = dtype('float4_e2m1fnx8')
    float4_e2m1fnx16 = dtype('float4_e2m1fnx16')
    float4_e2m1fnx32 = dtype('float4_e2m1fnx32')
    float4_e2m1fnx64 = dtype('float4_e2m1fnx64')
    bfloat16 = dtype('bfloat16')
_all_dtypes = {
'bool',
'short',
'int',
'long',
'half',
'float',
'double',
'int8',
'int16',
'int32',
'int64',
'int8x4',
'int16x4',
'int32x4',
'int64x4',
'int8x8',
'int16x8',
'int32x8',
'int64x8',
'int8x16',
'int16x16',
'int32x16',
'int64x16',
'int8x32',
'int16x32',
'int32x32',
'int64x32',
'int8x64',
'int16x64',
'int32x64',
'int64x64',
'uint8',
'uint16',
'uint32',
'uint64',
'uint8x4',
'uint16x4',
'uint32x4',
'uint64x4',
'uint8x8',
'uint16x8',
'uint32x8',
'uint64x8',
'uint8x16',
'uint16x16',
'uint32x16',
'uint64x16',
'uint8x32',
'uint16x32',
'uint32x32',
'uint64x32',
'uint8x64',
'uint16x64',
'uint32x64',
'uint64x64',
'float16',
'float32',
'float64',
'float16x2',
'float32x2',
'float64x2',
'float16x4',
'float32x4',
'float64x4',
'float16x8',
'float32x8',
'float64x8',
'float16x16',
'float32x16',
'float64x16',
'float16x32',
'float32x32',
'float64x32',
'float16x64',
'float32x64',
'float64x64',
'float8_e3m4',
'float8_e3m4x2',
'float8_e3m4x4',
'float8_e3m4x8',
'float8_e3m4x16',
'float8_e3m4x32',
'float8_e3m4x64',
'float8_e4m3',
'float8_e4m3x2',
'float8_e4m3x4',
'float8_e4m3x8',
'float8_e4m3x16',
'float8_e4m3x32',
'float8_e4m3x64',
'float8_e4m3b11fnuz',
'float8_e4m3b11fnuzx2',
'float8_e4m3b11fnuzx4',
'float8_e4m3b11fnuzx8',
'float8_e4m3b11fnuzx16',
'float8_e4m3b11fnuzx32',
'float8_e4m3b11fnuzx64',
'float8_e4m3fn',
'float8_e4m3fnx2',
'float8_e4m3fnx4',
'float8_e4m3fnx8',
'float8_e4m3fnx16',
'float8_e4m3fnx32',
'float8_e4m3fnx64',
'float8_e4m3fnuz',
'float8_e4m3fnuzx2',
'float8_e4m3fnuzx4',
'float8_e4m3fnuzx8',
'float8_e4m3fnuzx16',
'float8_e4m3fnuzx32',
'float8_e4m3fnuzx64',
'float8_e5m2',
'float8_e5m2x2',
'float8_e5m2x4',
'float8_e5m2x8',
'float8_e5m2x16',
'float8_e5m2x32',
'float8_e5m2x64',
'float8_e5m2fnuz',
'float8_e5m2fnuzx2',
'float8_e5m2fnuzx4',
'float8_e5m2fnuzx8',
'float8_e5m2fnuzx16',
'float8_e5m2fnuzx32',
'float8_e5m2fnuzx64',
'float8_e8m0fnu',
'float8_e8m0fnux2',
'float8_e8m0fnux4',
'float8_e8m0fnux8',
'float8_e8m0fnux16',
'float8_e8m0fnux32',
'float8_e8m0fnux64',
'float6_e2m3fn',
'float6_e2m3fnx2',
'float6_e2m3fnx4',
'float6_e2m3fnx8',
'float6_e2m3fnx16',
'float6_e2m3fnx32',
'float6_e2m3fnx64',
'float6_e3m2fn',
'float6_e3m2fnx2',
'float6_e3m2fnx4',
'float6_e3m2fnx8',
'float6_e3m2fnx16',
'float6_e3m2fnx32',
'float6_e3m2fnx64',
'float4_e2m1fn',
'float4_e2m1fnx2',
'float4_e2m1fnx4',
'float4_e2m1fnx8',
'float4_e2m1fnx16',
'float4_e2m1fnx32',
'float4_e2m1fnx64',
'bfloat16',
}
# Public API: every dtype name plus the conversion machinery.
__all__ = [*_all_dtypes, 'dtype', 'AnyDType', 'get_tvm_dtype']
from __future__ import annotations
import ast
import inspect
from typing import Any, Callable, Literal
from tilelang import env
from hashlib import sha256
import linecache
def disk_compile(source: str, name: str):
    """Compile *source* to a code object, writing it through the disk cache.

    When ``env.TILELANG_CACHE_DIR`` is set, the source is saved under
    ``<cache>/py-cache/<name>.<hash>.py`` (and registered with ``linecache``)
    so tracebacks into generated code show real lines.

    Args:
        source: Python source text to compile.
        name: Basename used for the cached file.

    Returns:
        The compiled code object (previously this returned ``None`` when the
        cache was disabled, crashing callers that ``exec`` the result).
    """
    cache_dir = env.TILELANG_CACHE_DIR
    if cache_dir is None:
        # Cache disabled: compile in memory under a synthetic filename.
        return compile(source, f"<{name}>", "exec")
    import os
    save_dir = os.path.join(cache_dir, "py-cache")
    os.makedirs(save_dir, exist_ok=True)
    hash_sfx = sha256(source.encode('utf-8')).hexdigest()[:8]
    path = os.path.join(save_dir, f"{name}.{hash_sfx}.py")
    with open(path, 'w') as f:
        f.write(source)
    # linecache expects newline-terminated lines in its cache entries.
    linecache.cache[path] = (len(source), None, source.splitlines(keepends=True), path)
    return compile(source, path, "exec")
def _remove_leading_ident(source: str):
lines = source.splitlines()
if not lines:
return source
ident_size = len(lines[0]) - len(lines[0].lstrip())
return "\n".join([line[ident_size:] if len(line) >= ident_size else line for line in lines])
def get_func_nonlocals(func):
    """Return *func*'s nonlocal variables as a name -> value dict.

    A trimmed variant of `inspect.getclosurevars` that only resolves the
    closure: free variable names come from ``co_freevars`` and values from
    the corresponding closure cells, skipping still-empty cells.

    Raises:
        TypeError: If *func* is not a plain Python function (methods are
            unwrapped first).
    """
    if inspect.ismethod(func):
        func = func.__func__
    if not inspect.isfunction(func):
        raise TypeError(f"{func!r} is not a Python function")
    cells = func.__closure__ or ()
    names = func.__code__.co_freevars
    resolved = {}
    for name, cell in zip(names, cells):
        try:
            resolved[name] = cell.cell_contents
        except ValueError as err:
            # An empty cell raises ValueError; anything else is unexpected.
            if "empty" not in str(err):
                raise
    return resolved
def inspect_function_capture(func: Callable) -> dict[str, Any]:
    """Capture function non-locals and global variables.

    Parameters
    ----------
    func : Callable
        The function to inspect.

    Returns
    -------
    res : Dict[str, Any]
        Mapping of names to the function's global and closure values; on a
        name clash the closure value wins.
    """
    captured = dict(func.__globals__)  # type: ignore
    captured.update(get_func_nonlocals(func))
    return captured
def get_ast(func: Callable):
    """Parse *func*'s source into an AST whose line numbers match its file."""
    filename = inspect.getsourcefile(func) or inspect.getfile(func)
    _, first_line = inspect.getsourcelines(func)
    dedented = _remove_leading_ident(inspect.getsource(func))
    # Pad with blank lines so node line numbers agree with the original file.
    padded = '\n' * (first_line - 1) + dedented
    return ast.parse(padded, filename=filename)
# Compilation strategies: 'direct' = in-memory `compile`, 'disk' =
# write-through the on-disk cache (see `disk_compile`). Presumably consumed
# by callers outside this file — not referenced in this module.
CompileMethod = Literal['direct', 'disk']
def get_compiled_object(source: str | ast.AST,
                        name: str,
                        filename: str | None = None,
                        globals: dict[str, Any] | None = None):
    """Compile and execute *source*, returning the top-level object *name*.

    Args:
        source: Source text or an already-parsed AST. Text goes through
            :func:`disk_compile`; an AST is compiled directly and requires
            ``filename``.
        name: Name of the object (e.g. a function) to extract after exec.
        filename: Filename recorded in the code object when ``source`` is an
            AST; unused for text sources.
        globals: Global namespace for ``exec``; ``None`` falls back to this
            module's globals. (NOTE: parameter shadows the builtin.)

    Raises:
        RuntimeError: If compilation fails; chains the original error and
            embeds the offending source text.
        KeyError: If the executed source does not define ``name``.
    """
    if isinstance(source, ast.AST):
        assert filename is not None, "filename must be provided when source is an AST"
    try:
        if isinstance(source, ast.AST):
            ast.fix_missing_locations(source)
            compiled = compile(source, filename, 'exec')
        else:
            compiled = disk_compile(source, name)
    except Exception as e:
        source_str = source if isinstance(source, str) else ast.unparse(source)
        raise RuntimeError(f'Failed to compile source for {name}, Error: {e}:\n{source_str}') from e
    # Top-level definitions land in `locs` (separate from `globals`).
    locs = {}
    exec(compiled, globals, locs)
    return locs[name]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment