Unverified commit 5f202fe5, authored by Kurisu, committed by GitHub
Browse files

[Language] Initial version of tilelang frontend v2 (#1120)



* tilelang frontend v2

* syntax sugar: defining a local var by annotation

* [Refactor] fix type linting warning like `T.float32`

* Add tl.local_var_init for new tl.float32

* allow passing default argument as function annotation

* allow default arguments as annotation

* fix lint error

* minor fix

* [Refactor] refactor tilelang.jit and tilelang.autotune

* minor fix

* minor fix

* minor fix

* fix metal get function name

* add par_compile impl and tests

* Type consistency on tvm datatype
1. isinstance(tl.float32, tvm.DataType) == True
2. Allow `tl.float32` as function annotations
3. Allow `tl.float32` as argument to be passed to `tl.alloc` or other functions

* fix lint error

* add more warning in frontend

* update tvm version

* Minor fix on tvm_ffi annotations

* add document and examples

* fix lint error

* Simplify index calculations in example_chunk_o_bwd.py

Refactor index calculations for dg_last_fragment assignment.

* minor fix

* lint fix

---------
Co-authored-by: Lei Wang <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent ba390756
...@@ -7,8 +7,6 @@ import tilelang ...@@ -7,8 +7,6 @@ import tilelang
import tilelang.language as T import tilelang.language as T
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
print(tilelang.__file__)
# Add your fla repository path to sys.path # Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
...@@ -256,8 +254,9 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -256,8 +254,9 @@ def tilelang_chunk_o_bwd_dqkwg(
# for i_kv in T.Parallel(block_DK * block_DV): # for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV): for i_kv in T.Parallel(block_DK * block_DV):
i_k, i_v = i_kv // block_DV, i_kv % block_DV dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv %
dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v] block_DV] * dh_shared[i_kv // block_DV,
i_kv % block_DV]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0] dg_last_local[0] += dg_last_fragment_scalar[0]
......
import tilelang.testing
import tilelang
import torch
@tilelang.jit(
    out_idx=-1,  # create the output tensor during runtime
    verbose=True,
)
def matmul_kernel_jit(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A=False,
    trans_B=True,
    in_dtype='float16',
    out_dtype='float32',
    accum_dtype='float32',
    num_stages=2,
    threads=128,
):
    """Build a tiled GEMM kernel computing C = A @ B with optional transposes.

    Args:
        M, N, K: GEMM problem sizes.
        block_M, block_N, block_K: Tile sizes processed per thread block.
        trans_A: If True, A is stored as (K, M) and read transposed.
        trans_B: If True, B is stored as (N, K) and read transposed.
        in_dtype: Element dtype of A and B.
        out_dtype: Element dtype of the output C.
        accum_dtype: Dtype of the local accumulator fragment.
        num_stages: Software-pipelining depth for the K-loop.
        threads: Threads launched per block.

    Returns:
        The traced ``T.prim_func`` that ``tilelang.jit`` will compile; with
        ``out_idx=-1`` the last tensor argument (C) is allocated at call time.
    """
    # Global-memory layouts depend on the transpose flags.
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    # Shared-memory staging tiles mirror the corresponding global layouts.
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    import tilelang.language as T
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        # One thread block per (block_N, block_M) output tile.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # Pipelined reduction over tiles of the K dimension.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Stage the A and B tiles, honoring the transpose layouts.
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            # Write the accumulated tile back to global memory.
            T.copy(C_local, C[by * block_M, bx * block_N])
    return main
def test_par_compile():
    """Compile several GEMM configurations in parallel and validate each
    resulting kernel against a torch reference matmul."""
    configs = [
        (1024, 1024, 1024, 128, 128, 32),
        (2048, 2048, 2048, 256, 256, 64),
        (4096, 4096, 4096, 64, 64, 128),
    ]
    # par_compile builds one JIT kernel per configuration tuple.
    kernels = matmul_kernel_jit.par_compile(configs)
    for config, kernel in zip(configs, kernels):
        M, N, K = config[:3]
        A = torch.randn(M, K, dtype=torch.float16).cuda()
        B = torch.randn(N, K, dtype=torch.float16).cuda()
        # trans_B defaults to True, so the reference is A @ B^T in fp32.
        ref = (A @ B.T).float()
        C = kernel(A, B)
        tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)
# Allow running this test module directly; tilelang.testing.main discovers
# and runs the test_* functions defined above.
if __name__ == "__main__":
    tilelang.testing.main()
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import tvm
def test_argument():
    """Smoke test: every tilelang scalar dtype is accepted as a prim_func
    parameter annotation — C-style aliases (bool/short/int/long/half/float),
    sized signed and unsigned ints, the fp8 variants, and standard floats."""
    @T.prim_func
    def test_argument(
        t_1: T.bool,
        t_2: T.short,
        t_3: T.int,
        t_4: T.long,
        t_5: T.half,
        t_6: T.float,
        t_7: T.long,
        t_8: T.int8,
        t_9: T.int16,
        t_10: T.int32,
        t_11: T.int64,
        t_12: T.uint8,
        t_13: T.uint16,
        t_14: T.uint32,
        t_15: T.uint64,
        t_16: T.float8_e4m3fn,
        t_17: T.float8_e4m3fnuz,
        t_18: T.float8_e5m2,
        t_19: T.float8_e5m2fnuz,
        t_20: T.float8_e8m0fnu,
        t_21: T.float16,
        t_22: T.bfloat16,
        t_23: T.float32,
        t_24: T.float64,
    ):
        # Body intentionally empty: only signature tracing is under test.
        pass
def test_expr():
    """Every dtype exported by T is a tvm.DataType and its constructor forms
    either work or raise TypeError (any other exception is a bug)."""
    from tilelang.language.v2.dtypes import _all_dtypes

    def construct_ok(dtype):
        # TypeError means the constructor form is simply unsupported for this
        # dtype, which is acceptable; any other exception is a real failure.
        try:
            dtype(1.0)
            dtype()
        except TypeError:
            pass
        except Exception:
            return False
        return True

    failed = []
    for name in _all_dtypes:
        dtype = getattr(T, name)
        assert isinstance(dtype, tvm.DataType), f"{dtype} is not tvm.DataType"
        if not construct_ok(dtype):
            failed.append(name)
    assert not failed
# def test_var_decl_sugar():
# @T.prim_func
# def test_var_decl_sugar():
# with T.Kernel(128, 128) as (bx, by):
# var_1: T.bool = 1.0
# var_2: T.short = 1.0
# var_3: T.int = 1.0
# var_4: T.long = 1.0
# var_5: T.half = 1.0
# var_6: T.float = 1.0
# var_7: T.long = 1.0
# var_8: T.int8 = 1.0
# var_9: T.int16 = 1.0
# var_10: T.int32 = 1.0
# var_11: T.int64 = 1.0
# var_12: T.uint8 = 1.0
# var_13: T.uint16 = 1.0
# var_14: T.uint32 = 1.0
# var_15: T.uint64 = 1.0
# var_16: T.float8_e4m3fn = 1.0
# var_17: T.float8_e4m3fnuz = 1.0
# var_18: T.float8_e5m2 = 1.0
# var_19: T.float8_e5m2fnuz = 1.0
# var_20: T.float8_e8m0fnu = 1.0
# var_21: T.float16 = 1.0
# var_22: T.bfloat16 = 1.0
# var_23: T.float32 = 1.0
# var_24: T.float64 = 1.0
# var_1: T.bool = var_1
# var_2: T.short = var_2
# var_3: T.int = var_3
# var_4: T.long = var_4
# var_5: T.half = var_5
# var_6: T.float = var_6
# var_7: T.long = var_7
# var_8: T.int8 = var_8
# var_9: T.int16 = var_9
# var_10: T.int32 = var_10
# var_11: T.int64 = var_11
# var_12: T.uint8 = var_12
# var_13: T.uint16 = var_13
# var_14: T.uint32 = var_14
# var_15: T.uint64 = var_15
# var_16: T.float8_e4m3fn = var_16
# var_17: T.float8_e4m3fnuz = var_17
# var_18: T.float8_e5m2 = var_18
# var_19: T.float8_e5m2fnuz = var_19
# var_20: T.float8_e8m0fnu = var_20
# var_21: T.float16 = var_21
# var_22: T.bfloat16 = var_22
# var_23: T.float32 = var_23
# var_24: T.float64 = var_24
# s = test_var_decl_sugar.script()
# for i in range(1, 25):
# assert f'var_{i}_1' in s
# assert 'tl.local_var_init' in s
def test_dtype_str_repr():
    """Smoke test: every tilelang dtype object is accepted as the ``dtype``
    argument of ``T.alloc_buffer`` (i.e. converts cleanly to a TVM dtype
    string) when tracing a prim_func."""
    @T.prim_func
    def test_str_repr():
        # Buffers are never read; allocation alone exercises the conversion.
        buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared')  # noqa F841
        buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared')  # noqa F841
        buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared')  # noqa F841
        buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared')  # noqa F841
        buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared')  # noqa F841
        buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared')  # noqa F841
        buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared')  # noqa F841
        buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared')  # noqa F841
        buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared')  # noqa F841
        buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared')  # noqa F841
        buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared')  # noqa F841
        buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared')  # noqa F841
        buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared')  # noqa F841
        buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared')  # noqa F841
        buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared')  # noqa F841
        buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared')  # noqa F841
        buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared')  # noqa F841
        buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared')  # noqa F841
        buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared')  # noqa F841
        buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared')  # noqa F841
        buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared')  # noqa F841
        buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared')  # noqa F841
        buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared')  # noqa F841
        buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared')  # noqa F841
def test_torch_eq():
    """Each tilelang dtype compares equal to its torch counterpart, and
    T.dtype() converts the torch dtype back to the tilelang one."""
    # Order mirrors the dtype roster used by the other tests (the C-style
    # aliases first, then sized ints, fp8 variants, and standard floats).
    pairs = [
        (T.bool, torch.bool),
        (T.short, torch.short),
        (T.int, torch.int),
        (T.long, torch.long),
        (T.half, torch.half),
        (T.float, torch.float),
        (T.long, torch.long),
        (T.int8, torch.int8),
        (T.int16, torch.int16),
        (T.int32, torch.int32),
        (T.int64, torch.int64),
        (T.uint8, torch.uint8),
        (T.uint16, torch.uint16),
        (T.uint32, torch.uint32),
        (T.uint64, torch.uint64),
        (T.float8_e4m3fn, torch.float8_e4m3fn),
        (T.float8_e4m3fnuz, torch.float8_e4m3fnuz),
        (T.float8_e5m2, torch.float8_e5m2),
        (T.float8_e5m2fnuz, torch.float8_e5m2fnuz),
        (T.float8_e8m0fnu, torch.float8_e8m0fnu),
        (T.float16, torch.float16),
        (T.bfloat16, torch.bfloat16),
        (T.float32, torch.float32),
        (T.float64, torch.float64),
    ]
    for a, b in pairs:
        assert a == b, f"{a} and {b} are not equal"
        assert T.dtype(b) == a, "dtype conversion error"
def test_var_assign():
    """Annotated local variables create distinct bindings: `b` must capture
    the value of `a` before reassignment and `d` the value after, so the
    kernel output distinguishes the two."""
    @tilelang.jit(out_idx=-1)
    @T.prim_func
    def test_var_assign(A: T.Tensor((2,), T.int32)):
        with T.Kernel(1) as _:
            a: T.int32 = 1
            b: T.int32 = a
            a = 2
            d: T.int32 = a
            A[0] = b
            A[1] = d
    # out_idx=-1: A is allocated by the runtime and returned by the call.
    res = test_var_assign()()
    # b saw the pre-assignment value, d the post-assignment value.
    assert res[0] == 1
    assert res[1] == 2
# NOTE(review): "marco" looks like a typo for "macro"; renaming would change
# the collected test id, so it is left as-is here.
def test_marco_return():
    """Macros may return a Python constant, an allocated var, a computed
    expression, or the result of a passed-in callable; each return kind is
    type-checked at trace time."""
    @T.macro
    def macro_return_constant():
        return 0
    @T.macro
    def macro_return_frame(x):
        # Returns a var initialized to x (asserted below to be a PrimExpr).
        return T.alloc_var(T.float32, init=x)
    @T.macro
    def macro_return_expr(x):
        y = x + 1.0
        return y
    @T.macro
    def macro_apply_func(x, fn):
        return fn(x)
    def check(x, ty):
        # Plain Python helper: runs during tracing, not inside the kernel.
        assert isinstance(x, ty)
    @T.prim_func
    def test_macro_return():
        with T.Kernel(1) as _:
            a = macro_return_constant()
            b = macro_return_frame(3.0)
            c = macro_return_expr(4.0)
            d = macro_apply_func(5.0, lambda x: x * 2.0)
            check(a, (int, float, T.PrimExpr))
            check(b, T.PrimExpr)
            check(c, T.PrimExpr)
            check(d, T.PrimExpr)
def test_prim_func_generator():
    """generator=True keeps the decorated function callable (tracing happens
    on invocation), while a plain @T.prim_func immediately yields a PrimFunc.
    Also exercises T.Tensor defaults as parameter values (hence noqa: B008)."""
    @T.prim_func(generator=True)
    def prim_func_gen(
        A=T.Tensor((128,), T.float32),  # noqa: B008
        B=T.Tensor((128,), T.float32),  # noqa: B008
    ):
        with T.Kernel(128) as (tx,):
            T.copy(A[tx], B[tx])
    # Invoking the generator traces the function with its default arguments.
    prim_func_gen()
    @T.prim_func
    def foo() -> T.Tensor((128,), T.float32):
        pass
    assert isinstance(foo, T.PrimFunc)
# Allow running this test module directly; tilelang.testing.main discovers
# and runs the test_* functions defined above.
if __name__ == '__main__':
    tilelang.testing.main()
...@@ -11,7 +11,7 @@ def test_let_vectorize_load(): ...@@ -11,7 +11,7 @@ def test_let_vectorize_load():
for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
b: T.float32x4 = A[0, 0:4] b = A[0, 0:4]
A[0, 4:8] = b A[0, 4:8] = b
mod = tvm.IRModule({"main": main}) mod = tvm.IRModule({"main": main})
......
...@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n") N = tvm.te.var("n")
K = tvm.te.var("k") K = tvm.te.var("k")
@tvm.script.ir.ir_module def before():
class Before:
@T.prim_func @T.prim_func
def main(B: T.Tensor((K, N), dtype),): def main(B: T.Tensor((K, N), dtype),):
...@@ -38,8 +37,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -38,8 +37,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
(block_N // vec_load_b) * (block_N // vec_load_b) + vec], (block_N // vec_load_b) * (block_N // vec_load_b) + vec],
T.float16(0)) T.float16(0))
@tvm.script.ir.ir_module return tvm.IRModule({'main': main})
class After:
def after():
@T.prim_func @T.prim_func
def main(B: T.Tensor((K, N), dtype),): def main(B: T.Tensor((K, N), dtype),):
...@@ -77,11 +77,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -77,11 +77,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
bx * block_N + t % (block_N // vec_load_b) * bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0)) (block_N // vec_load_b) + vec], T.float16(0))
return tvm.IRModule({'main': main})
with tvm.target.Target(auto_target): with tvm.target.Target(auto_target):
mod = tvm.tir.transform.BindTarget(auto_target)(Before) mod = tvm.tir.transform.BindTarget(auto_target)(before())
mod = tl.transform.LayoutInference()(mod) mod = tl.transform.LayoutInference()(mod)
mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.Simplify()(mod)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(After) ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
ref_mod = tvm.tir.transform.Simplify()(ref_mod) ref_mod = tvm.tir.transform.Simplify()(ref_mod)
# Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass # Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass
# This loop is "for vec in T.parallel(1)", # This loop is "for vec in T.parallel(1)",
......
...@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n") N = tvm.te.var("n")
K = tvm.te.var("k") K = tvm.te.var("k")
@tvm.script.ir.ir_module def before():
class Before:
@T.prim_func @T.prim_func
def main(B: T.Tensor((K, N), dtype),): def main(B: T.Tensor((K, N), dtype),):
...@@ -25,8 +24,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -25,8 +24,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(B[k * block_K, bx * block_N], B_shared) T.copy(B[k * block_K, bx * block_N], B_shared)
@tvm.script.ir.ir_module return tvm.IRModule({'main': main})
class After:
def after():
@T.prim_func @T.prim_func
def main(B: T.Tensor((K, N), dtype),): def main(B: T.Tensor((K, N), dtype),):
...@@ -64,11 +64,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -64,11 +64,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
bx * block_N + t % (block_N // vec_load_b) * bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0)) (block_N // vec_load_b) + vec], T.float16(0))
return tvm.IRModule({'main': main})
with tvm.transform.PassContext(): with tvm.transform.PassContext():
mod = tvm.tir.transform.BindTarget(auto_target)(Before) mod = tvm.tir.transform.BindTarget(auto_target)(before())
mod = tl.transform.LowerTileOp()(mod) mod = tl.transform.LowerTileOp()(mod)
mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.Simplify()(mod)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(After) ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
ref_mod = tvm.tir.transform.Simplify()(ref_mod) ref_mod = tvm.tir.transform.Simplify()(ref_mod)
# Note(tzj): The structures are equal except the argument in "T.reads" function. # Note(tzj): The structures are equal except the argument in "T.reads" function.
# The difference is just between the first index and the indices range, which is totally equivalent # The difference is just between the first index and the indices range, which is totally equivalent
......
...@@ -113,7 +113,7 @@ def test_multi_version_buffer_with_let(): ...@@ -113,7 +113,7 @@ def test_multi_version_buffer_with_let():
shared = T.alloc_buffer((8,), "float32", scope="shared.dyn") shared = T.alloc_buffer((8,), "float32", scope="shared.dyn")
accum = T.alloc_buffer((8,), "float32", scope="local") accum = T.alloc_buffer((8,), "float32", scope="local")
for k in T.serial(4, annotations={"num_stages": T.int32(2)}): for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
value: T.float32 = scales[k] value = scales[k]
for i in T.serial(8): for i in T.serial(8):
shared[i] = value shared[i] = value
for i in T.serial(8): for i in T.serial(8):
...@@ -125,7 +125,7 @@ def test_multi_version_buffer_with_let(): ...@@ -125,7 +125,7 @@ def test_multi_version_buffer_with_let():
shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn") shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn")
accum = T.alloc_buffer((8,), "float32", scope="local") accum = T.alloc_buffer((8,), "float32", scope="local")
for k in T.serial(4, annotations={"num_stages": T.int32(2)}): for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
value: T.float32 = scales[k] value = scales[k]
for i in T.serial(8): for i in T.serial(8):
shared[k % 2, i] = value shared[k % 2, i] = value
for i in T.serial(8): for i in T.serial(8):
......
...@@ -4,7 +4,7 @@ import ctypes ...@@ -4,7 +4,7 @@ import ctypes
import logging import logging
import warnings import warnings
from tqdm import tqdm from tqdm.auto import tqdm
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
......
...@@ -4,17 +4,19 @@ This module provides functionality for auto-tuning tilelang programs, including ...@@ -4,17 +4,19 @@ This module provides functionality for auto-tuning tilelang programs, including
and performance optimization through configuration search. and performance optimization through configuration search.
""" """
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
import tilelang import tilelang
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.jit import JITImpl
from tilelang.jit.kernel import JITKernel
from tvm.tir import PrimFunc, Var from tvm.tir import PrimFunc, Var
from tvm.target import Target from tvm.target import Target
import inspect import inspect
from functools import partial from functools import partial
from typing import (Callable, Literal, Any, overload) from typing import (Callable, Generic, Literal, Any, ParamSpec, TypeVar)
from tqdm import tqdm from tqdm.auto import tqdm
import logging import logging
import functools
import concurrent.futures import concurrent.futures
import torch import torch
import os import os
...@@ -30,7 +32,6 @@ from tilelang import env ...@@ -30,7 +32,6 @@ from tilelang import env
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.autotuner.capture import get_autotune_inputs from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang.jit.param import _P, _RProg
from tilelang import __version__ from tilelang import __version__
...@@ -524,12 +525,12 @@ class AutoTuner: ...@@ -524,12 +525,12 @@ class AutoTuner:
# latency, ref_latency = target_fn(jit_kernel) # latency, ref_latency = target_fn(jit_kernel)
latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel) latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
except TimeoutException: except TimeoutException:
logger.info( logger.warning(
f"A timeout occurred while testing config {config}, checkout autotuner.log for more details" f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
) )
continue continue
except Exception: except Exception:
logger.info( logger.warning(
f"An error occurred while testing config {config}, checkout autotuner.log for more details" f"An error occurred while testing config {config}, checkout autotuner.log for more details"
) )
logger.debug(f"Error: {traceback.format_exc()}") logger.debug(f"Error: {traceback.format_exc()}")
...@@ -585,9 +586,13 @@ class AutoTuner: ...@@ -585,9 +586,13 @@ class AutoTuner:
return self.run() return self.run()
class _AutoTunerImplementation: _P = ParamSpec('_P')
# Overload __init__ to help type checkers understand the effect of return_program _T = TypeVar('_T')
# The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
@dataclass
class AutoTuneImpl(Generic[_P, _T]):
jit_impl: JITImpl
warmup: int = 25 warmup: int = 25
rep: int = 100 rep: int = 100
...@@ -603,91 +608,12 @@ class _AutoTunerImplementation: ...@@ -603,91 +608,12 @@ class _AutoTunerImplementation:
manual_check_prog: Callable = None manual_check_prog: Callable = None
cache_input_tensors: bool = False cache_input_tensors: bool = False
def __init__(self, def __post_init__(self):
configs: dict | Callable, self._tuner_cache = {}
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False) -> None:
"""Initialize the AutoTunerImplementation.
Args:
configs: Configuration space to explore during auto-tuning.
warmup: Number of warmup iterations before timing.
rep: Number of repetitions for timing measurements.
timeout: Maximum time (in seconds) allowed for each configuration.
supply_type: Strategy for generating input tensors (random/zeros/etc)
ref_prog: Reference implementation for validation
supply_prog: Custom function to provide input tensors
rtol: Relative tolerance for numerical validation
atol: Absolute tolerance for numerical validation
max_mismatched_ratio: Allowed percentage of mismatched values
skip_check: Bypass validation against reference implementation
manual_check_prog: Custom validation function
cache_input_tensors: Reuse input tensors across trials
"""
# Configuration and benchmarking parameters
self.configs = configs # Search space of tuning configurations
self.warmup = warmup # Warmup iterations for stable measurements
self.rep = rep # Measurement repetitions for statistics
self.timeout = timeout # Per-configuration timeout threshold
# Tensor handling and validation setup
self.supply_type = supply_type # Input tensor generation strategy
self.ref_prog = ref_prog # Ground truth implementation
self.supply_prog = supply_prog # Custom input data provider
self.rtol = rtol # Relative error tolerance
self.atol = atol # Absolute error tolerance
self.max_mismatched_ratio = max_mismatched_ratio # Allowed mismatch
# Validation control flags
self.skip_check = skip_check # Bypass accuracy verification
self.manual_check_prog = manual_check_prog # Custom validation
self.cache_input_tensors = cache_input_tensors # Reuse inputs
# Cache for storing tuned kernel implementations
self._tuner_cache: dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]:
...
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]:
...
# Actual implementation of __call__
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]:
warmup = self.warmup
rep = self.rep
timeout = self.timeout
@functools.wraps(fn)
def wrapper(*args, **kwargs):
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
if key not in self._tuner_cache:
def jit_compile(**config_arg):
return fn(*args, **kwargs, __tune_params=config_arg)
compile_arguments = fn(__return_compile_arguments=True)
def get_tunner(self):
autotuner = AutoTuner( autotuner = AutoTuner(
fn, configs=self.configs).set_profile_args( self.jit_impl.func, configs=self.configs).set_profile_args(
supply_type=self.supply_type, supply_type=self.supply_type,
ref_prog=self.ref_prog, ref_prog=self.ref_prog,
supply_prog=self.supply_prog, supply_prog=self.supply_prog,
...@@ -698,30 +624,35 @@ class _AutoTunerImplementation: ...@@ -698,30 +624,35 @@ class _AutoTunerImplementation:
manual_check_prog=self.manual_check_prog, manual_check_prog=self.manual_check_prog,
cache_input_tensors=self.cache_input_tensors, cache_input_tensors=self.cache_input_tensors,
).set_compile_args( ).set_compile_args(
out_idx=compile_arguments['out_idx'], out_idx=self.jit_impl.out_idx,
execution_backend=compile_arguments['execution_backend'], execution_backend=self.jit_impl.execution_backend,
target=compile_arguments['target'], target=self.jit_impl.target,
target_host=compile_arguments['target_host'], target_host=self.jit_impl.target_host,
verbose=compile_arguments['verbose'], verbose=self.jit_impl.verbose,
pass_configs=compile_arguments['pass_configs'], pass_configs=self.jit_impl.pass_configs,
) )
autotuner.run = partial(autotuner.run, self.warmup, self.rep, self.timeout)
return autotuner
autotuner.jit_compile = jit_compile def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel:
autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters) key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
if key not in self._tuner_cache:
autotuner.run = partial(autotuner.run, warmup, rep, timeout) def jit_compile(**config_arg):
return self.jit_impl(*args, **kwargs, __tune_params=config_arg)
autotuner = self.get_tunner()
autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)
artifact = autotuner.run() artifact = autotuner.run()
self._tuner_cache[key] = artifact.kernel self._tuner_cache[key] = artifact.kernel
return self._tuner_cache[key] return self._tuner_cache[key]
return wrapper
def autotune( # This is the new public interface def autotune( # This is the new public interface
func: Callable[_P, _RProg] | PrimFunc | None = None, func: Callable[_P, _T] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only *, # Indicates subsequent arguments are keyword-only
configs: dict | Callable, configs: dict | Callable,
# profile arguments # profile arguments
...@@ -795,10 +726,13 @@ def autotune( # This is the new public interface ...@@ -795,10 +726,13 @@ def autotune( # This is the new public interface
elif isinstance(func, PrimFunc): elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
else: else:
# Case 2: Used as @autotune(...) to configure, or func_or_out_idx is meant as out_idx.
# Create a _AutoTunerImplementation instance with the provided/defaulted arguments. def decorator(impl):
# This instance is a decorator that will be applied to the function later. assert isinstance(
configured_decorator = _AutoTunerImplementation( impl, JITImpl
), "The @autotune decorator can only be applied to @tilelang.jit decorated instances."
return AutoTuneImpl(
jit_impl=impl,
configs=configs, configs=configs,
warmup=warmup, warmup=warmup,
rep=rep, rep=rep,
...@@ -813,4 +747,5 @@ def autotune( # This is the new public interface ...@@ -813,4 +747,5 @@ def autotune( # This is the new public interface
manual_check_prog=manual_check_prog, manual_check_prog=manual_check_prog,
cache_input_tensors=cache_input_tensors, cache_input_tensors=cache_input_tensors,
) )
return configured_decorator
return decorator
...@@ -5,15 +5,21 @@ kernel adapter using TVM. ...@@ -5,15 +5,21 @@ kernel adapter using TVM.
""" """
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
import inspect
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Generic,
Iterable,
ParamSpec,
TypeVar,
overload, overload,
Literal, Literal,
) )
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.language.v2 import PrimFunc
from tilelang.jit.adapter.utils import is_metal_target from tilelang.jit.adapter.utils import is_metal_target
from tvm.tir import PrimFunc
from tvm.target import Target from tvm.target import Target
from tilelang.jit.kernel import JITKernel from tilelang.jit.kernel import JITKernel
...@@ -21,14 +27,20 @@ from tilelang.utils.target import determine_target ...@@ -21,14 +27,20 @@ from tilelang.utils.target import determine_target
from tilelang.cache import cached from tilelang.cache import cached
from os import path, makedirs from os import path, makedirs
from logging import getLogger from logging import getLogger
import functools from tilelang.jit.param import Kernel
from tilelang.jit.param import Kernel, _P, _RProg import concurrent.futures
from tqdm.auto import tqdm
logger = getLogger(__name__) logger = getLogger(__name__)
_P = ParamSpec('_P')
_KP = ParamSpec('_KP')
_T = TypeVar('_T')
def compile( def compile(
func: PrimFunc = None, func: PrimFunc[_KP, _T] = None,
out_idx: list[int] | int | None = None, out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: str | Target = "auto", target: str | Target = "auto",
...@@ -36,7 +48,7 @@ def compile( ...@@ -36,7 +48,7 @@ def compile(
verbose: bool = False, verbose: bool = False,
pass_configs: dict[str, Any] | None = None, pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | str | None = None, compile_flags: list[str] | str | None = None,
) -> JITKernel: ) -> JITKernel[_KP, _T]:
""" """
Compile the given TileLang PrimFunc with TVM and build a JITKernel. Compile the given TileLang PrimFunc with TVM and build a JITKernel.
Parameters Parameters
...@@ -79,98 +91,159 @@ def compile( ...@@ -79,98 +91,159 @@ def compile(
) )
class _JitImplementation: def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
out_idx: list[int] | int | None = None,
out_idx: list[int] | int | None execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: str | Target
target_host: str | Target
execution_backend: Literal["dlpack", "ctypes", "cython"]
verbose: bool
pass_configs: dict[str, Any] | None
debug_root_path: str | None
compile_flags: list[str] | str | None
def __init__(self,
out_idx: Any = None,
target: str | Target = "auto", target: str | Target = "auto",
target_host: str | Target = None, target_host: str | Target | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False, verbose: bool = False,
pass_configs: dict[str, Any] | None = None, pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None, compile_flags: list[str] | str | None = None,
compile_flags: list[str] | str | None = None): num_workers: int = None,
ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
""" """
Initializes the JIT compiler decorator. Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
Parameters Parameters
---------- ----------
out_idx : Any, optional funcs : Iterable[tvm.tir.PrimFunc]
Index(es) of the output tensors to return from the compiled kernel The TileLang TIR functions to compile and wrap.
(default: None, meaning all outputs are returned or determined by the kernel itself). out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
Execution backend to use for kernel execution (default: "cython").
target : Union[str, Target], optional target : Union[str, Target], optional
Compilation target for TVM. Can be a string (e.g., "cuda", "llvm") Compilation target, either as a string or a TVM Target object (default: "auto").
or a TVM Target object. If "auto", the target is determined automatically
(default: "auto").
target_host : Union[str, Target], optional target_host : Union[str, Target], optional
Target host for cross-compilation, similar to `target` (default: None). Target host for cross-compilation (default: None).
execution_backend : Literal["dlpack", "ctypes", "cython"], optional
The backend used for kernel execution and argument passing.
"dlpack" is generally preferred for zero-copy tensor passing with compatible frameworks.
"ctypes" uses standard C types. "cython" uses Cython for potentially faster execution.
(default: "cython").
verbose : bool, optional verbose : bool, optional
If True, enables verbose logging during compilation (default: False). Whether to enable verbose output (default: False).
pass_configs : Optional[Dict[str, Any]], optional pass_configs : dict, optional
A dictionary of configurations for TVM's pass context. These can fine-tune Additional keyword arguments to pass to the Compiler PassContext.
the compilation process. Examples include "tir.disable_vectorize" Refer to `tilelang.transform.PassConfigKey` for supported options.
(default: None).
debug_root_path : Optional[str], optional
If provided, the compiled kernel's source code will be saved to a file
in this directory. This is useful for debugging the generated code.
If None, no debug information is saved (default: None).
If a relative path is given, it's made absolute relative to the project root
or current working directory.
compile_flags : Optional[Union[List[str], str]], optional
Additional compilation flags to pass to the compiler.
If None, no additional compilation flags are passed (default: None).
""" """
self.out_idx = out_idx with concurrent.futures.ThreadPoolExecutor(num_workers, 'tl-par-comp') as executor:
self.execution_backend = execution_backend futures = []
self.target = target future_map = {}
self.target_host = target_host for i, func in enumerate(funcs):
self.verbose = verbose future = executor.submit(
self.pass_configs = pass_configs compile,
self.compile_flags = compile_flags func=func,
out_idx=out_idx,
# Corrected debug_root_path handling execution_backend=execution_backend,
self.debug_root_path = debug_root_path target=target,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
future_map[future] = i
futures.append(future)
results = [... for _ in futures]
for future in tqdm(
concurrent.futures.as_completed(futures),
total=len(futures),
desc="Parallel Compiling",
):
idx = future_map[future]
if ignore_error:
try:
results[idx] = future.result()
except Exception as e:
logger.warning(f"Error compiling function at index {idx}: {e}")
results[idx] = None
else:
results[idx] = future.result()
return results
return results
@dataclass
class JITImpl(Generic[_P, _KP, _T]):
func: Callable[_P, _T] | PrimFunc[_KP, _T]
out_idx: list[int] | int | None
execution_backend: Literal["dlpack", "ctypes", "cython"]
target: str | Target
target_host: str | Target
verbose: bool
pass_configs: dict[str, Any] | None
debug_root_path: str | None
compile_flags: list[str] | str | None
func_source: str
signature: inspect.Signature
def __post_init__(self):
if self.debug_root_path is not None and not path.isabs(self.debug_root_path): if self.debug_root_path is not None and not path.isabs(self.debug_root_path):
try: try:
base_path = path.dirname(path.dirname(path.dirname(__file__))) base_path = path.dirname(path.dirname(path.dirname(__file__)))
self.debug_root_path = path.join(base_path, self.debug_root_path) self.debug_root_path = path.join(base_path, self.debug_root_path)
except NameError: except NameError:
self.debug_root_path = path.abspath(self.debug_root_path) self.debug_root_path = path.abspath(self.debug_root_path)
self._kernel_cache: dict[tuple, Kernel] = {} self._kernel_cache: dict[tuple, Kernel] = {}
# This tells the type checker what the *wrapper* function will return. def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_KP, _T]:
# this is for linting, please do not remove it. program_result_source = self.func
@overload if isinstance(program_result_source, PrimFunc):
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, Kernel]]: program_result = program_result_source
... elif callable(program_result_source):
program_result = program_result_source(*args, **kwargs)
else:
raise ValueError(f"Invalid function type: {type(program_result_source)}")
return program_result
def par_compile(self,
configs: Iterable[dict[str, Any] | tuple[str, Any]],
num_workers: int = None,
ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
configs = list(configs)
funcs = []
for cfg in tqdm(configs, desc='Elaborating'):
if isinstance(cfg, tuple):
funcs.append(self.get_tir(*cfg))
elif isinstance(cfg, dict):
funcs.append(self.get_tir(**cfg))
else:
raise ValueError(f"Invalid config type: {type(cfg)}, expected tuple or dict.")
return par_compile(
funcs,
out_idx=self.out_idx,
execution_backend=self.execution_backend,
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
pass_configs=self.pass_configs,
compile_flags=self.compile_flags,
num_workers=num_workers,
ignore_error=ignore_error)
@overload def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]:
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]: func = self.get_tir(*args, **kwargs)
... kernel_result = compile(
func,
out_idx=self.out_idx,
execution_backend=self.execution_backend,
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
pass_configs=self.pass_configs,
compile_flags=self.compile_flags,
)
if self.debug_root_path:
if isinstance(self.func, PrimFunc):
func_name = self.func.attrs['global_symbol']
else:
func_name = getattr(self.func, '__name__', 'jit_kernel')
kernel_file = f'tilelang_jit_kernel_{func_name}.c'
program_file = f'tilelang_jit_program_{func_name}.py'
makedirs(self.debug_root_path, exist_ok=True)
with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
print(kernel_result.get_kernel_source(), file=f)
with open(path.join(self.debug_root_path, program_file), 'w') as f:
print(func.script(), file=f)
# Actual implementation of __call__ return kernel_result
def __call__(
self,
func: Callable[_P, _RProg] # func is Union[Callable[_P, _RProg], PrimFunc] in original
) -> Callable[_P, Any]:
@functools.wraps(func) def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]:
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
# Separate out the tuning parameters from the user's kwargs # Separate out the tuning parameters from the user's kwargs
tune_params = kwargs.pop('__tune_params', {}) tune_params = kwargs.pop('__tune_params', {})
# Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
...@@ -193,45 +266,33 @@ class _JitImplementation: ...@@ -193,45 +266,33 @@ class _JitImplementation:
key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple) key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple)
if key not in self._kernel_cache: if key not in self._kernel_cache:
# Ensure 'func' (the original user function) is used correctly self._kernel_cache[key] = self.compile(*args, **kwargs, **tune_params)
program_result_source = func
if isinstance(program_result_source, PrimFunc):
program_result = program_result_source
elif callable(program_result_source):
program_result = program_result_source(*args, **kwargs, **tune_params)
else:
raise ValueError(f"Invalid function type: {type(program_result_source)}")
kernel_result = compile( return self._kernel_cache[key]
program_result,
out_idx=self.out_idx,
execution_backend=self.execution_backend,
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
pass_configs=self.pass_configs,
compile_flags=self.compile_flags,
)
if self.debug_root_path:
func_name = getattr(func, '__name__', 'jit_kernel') # Use func for name
kernel_file = f'tilelang_jit_kernel_{func_name}.c'
program_file = f'tilelang_jit_program_{func_name}.py'
makedirs(self.debug_root_path, exist_ok=True)
with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
print(kernel_result.get_kernel_source(), file=f)
with open(path.join(self.debug_root_path, program_file), 'w') as f:
print(program_result.script(), file=f)
self._kernel_cache[key] = kernel_result @overload
def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T]:
...
return self._kernel_cache[key]
return wrapper @overload
def jit(
*, # Indicates subsequent arguments are keyword-only
out_idx: Any = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None
) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T]]:
...
def jit( # This is the new public interface def jit( # This is the new public interface
func: Callable[_P, _RProg] | PrimFunc | None = None, func: Callable[_P, _T] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only *, # Indicates subsequent arguments are keyword-only
out_idx: Any = None, out_idx: Any = None,
target: str | Target = "auto", target: str | Target = "auto",
...@@ -275,32 +336,26 @@ def jit( # This is the new public interface ...@@ -275,32 +336,26 @@ def jit( # This is the new public interface
if isinstance(compile_flags, str): if isinstance(compile_flags, str):
compile_flags = [compile_flags] compile_flags = [compile_flags]
if callable(func): def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]:
# Case 1: Used as @jit (func_or_out_idx is the function, others are defaults) if isinstance(func, PrimFunc):
# Create a default _JitImplementation instance and apply it to the function. orig_func = func.orig_func
default_decorator = _JitImplementation(
out_idx=out_idx, # Explicitly None for the default case
target=target,
target_host=target_host,
execution_backend=execution_backend,
verbose=verbose,
pass_configs=pass_configs,
debug_root_path=debug_root_path,
compile_flags=compile_flags)
return default_decorator(func)
elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
else: else:
# Case 2: Used as @jit(...) to configure, or func_or_out_idx is meant as out_idx. orig_func = func
# Create a _JitImplementation instance with the provided/defaulted arguments. return JITImpl(
# This instance is a decorator that will be applied to the function later. func,
configured_decorator = _JitImplementation( out_idx=out_idx,
out_idx=out_idx, # Pass along; could be an actual out_idx or None execution_backend=execution_backend,
target=target, target=target,
target_host=target_host, target_host=target_host,
execution_backend=execution_backend,
verbose=verbose, verbose=verbose,
pass_configs=pass_configs, pass_configs=pass_configs,
debug_root_path=debug_root_path, debug_root_path=debug_root_path,
compile_flags=compile_flags) compile_flags=compile_flags,
return configured_decorator func_source=inspect.getsource(orig_func),
signature=inspect.signature(orig_func),
)
if func is not None:
return decorator(func)
else:
return decorator
...@@ -27,7 +27,11 @@ class MetalKernelAdapter(BaseKernelAdapter): ...@@ -27,7 +27,11 @@ class MetalKernelAdapter(BaseKernelAdapter):
# compile_flags: Optional[List[str]] = None # compile_flags: Optional[List[str]] = None
): ):
self.kernel_global_source = kernel_global_source self.kernel_global_source = kernel_global_source
self.kernel_name = func_or_mod.__name__ + '_kernel' if isinstance(func_or_mod, tir.PrimFunc):
func_name = func_or_mod.attrs['global_symbol']
else:
func_name = func_or_mod.__name__
self.kernel_name = func_name + '_kernel'
self.verbose = verbose self.verbose = verbose
self.block_info = [1, 1, 1] self.block_info = [1, 1, 1]
...@@ -43,7 +47,7 @@ class MetalKernelAdapter(BaseKernelAdapter): ...@@ -43,7 +47,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
self.grid_info["xyz".index(tag[-1])] = extent self.grid_info["xyz".index(tag[-1])] = extent
break break
else: else:
raise AssertionError(f'no kernel with name {func_or_mod.__name__}') raise AssertionError(f'no kernel with name {func_name}')
# print(self.block_info, self.grid_info) # print(self.block_info, self.grid_info)
super().__init__(func_or_mod, result_idx=result_idx, params=params) super().__init__(func_or_mod, result_idx=result_idx, params=params)
......
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Literal from typing import Any, Callable, Generic, Literal, ParamSpec, TypeVar
from tilelang.jit.adapter.utils import is_metal_target from tilelang.jit.adapter.utils import is_metal_target
from tvm.target import Target from tvm.target import Target
...@@ -17,8 +17,11 @@ import logging ...@@ -17,8 +17,11 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_P = ParamSpec('_P')
_T = TypeVar('_T')
class JITKernel:
class JITKernel(Generic[_P, _T]):
""" """
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions. A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
...@@ -170,7 +173,7 @@ class JITKernel: ...@@ -170,7 +173,7 @@ class JITKernel:
instance.torch_function = instance.adapter.func instance.torch_function = instance.adapter.func
return instance return instance
def __call__(self, *args: Any, **kwds: Any) -> Any: def __call__(self, *args: _P.args, **kwds: _P.kwargs) -> _T:
""" """
Invokes the compiled function with the given arguments. Invokes the compiled function with the given arguments.
......
...@@ -8,9 +8,9 @@ from __future__ import annotations ...@@ -8,9 +8,9 @@ from __future__ import annotations
# upstream tir script is fully compatible # upstream tir script is fully compatible
from tvm.script.parser.tir import * from tvm.script.parser.tir import *
from . import overrides as _overrides # noqa: F401 from . import overrides as _overrides # noqa: F401
from .tir import (
prim_func, # noqa: F401 # from .tir import prim_func, macro, # noqa: F401
) from .v2 import * # noqa: F401
from .tir.ir import * # noqa: F401 from .tir.ir import * # noqa: F401
from tilelang.layout import Layout, Fragment # noqa: F401 from tilelang.layout import Layout, Fragment # noqa: F401
from .proxy import ( from .proxy import (
......
...@@ -7,7 +7,6 @@ from tilelang.utils import deprecated ...@@ -7,7 +7,6 @@ from tilelang.utils import deprecated
__all__ = ["dynamic", "symbolic"] __all__ = ["dynamic", "symbolic"]
@deprecated("T.dynamic(...)", "tir.Var(...)", "v0.1.9")
def dynamic(name: str, dtype: str = "int32"): def dynamic(name: str, dtype: str = "int32"):
""" """
Create a TIR dynamic symbolic variable. Create a TIR dynamic symbolic variable.
...@@ -22,7 +21,7 @@ def dynamic(name: str, dtype: str = "int32"): ...@@ -22,7 +21,7 @@ def dynamic(name: str, dtype: str = "int32"):
return tir.Var(name, dtype) return tir.Var(name, dtype)
@deprecated("T.symbolic(...)", "T.dynamic(...)") @deprecated("T.symbolic(...)", "T.dynamic(...)", "v0.1.9")
def symbolic(name: str, dtype: str = "int32"): def symbolic(name: str, dtype: str = "int32"):
"""Deprecated alias for `T.dynamic`.""" """Deprecated alias for `T.dynamic`."""
return tir.Var(name, dtype) return tir.Var(name, dtype)
from .builder import prim_func, macro, PrimFunc # noqa: F401
from .dtypes import *
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import Callable, ContextManager, Generic, Iterable, Any, Literal, ParamSpec, TypeVar
import inspect
# from .utils import get_ast, get_compiled_object
from . import utils
# The four location attributes that together make up a node's source span.
_span_attrs = ['lineno', 'col_offset', 'end_lineno', 'end_col_offset']


def ast_has_span(ast: ast.AST) -> bool:
    """Return True iff *ast* carries a complete source span (all four attrs)."""
    for attr in _span_attrs:
        if not hasattr(ast, attr):
            return False
    return True


def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int]:
    """Return the node's span as a 4-tuple, or None when any part is missing."""
    if ast_has_span(ast):
        return (ast.lineno, ast.col_offset, ast.end_lineno, ast.end_col_offset)
    return None


def ast_set_span(ast: ast.AST, span: tuple[int, int, int, int]):
    """Copy *span* onto *ast*.

    Note: by design this only overwrites nodes that already carry a full
    span; nodes created without location attributes are left untouched.
    """
    if ast_has_span(ast):
        for attr, value in zip(_span_attrs, span):
            setattr(ast, attr, value)
class QuoteVisitor(ast.NodeTransformer):
    """Splices caller-supplied AST fragments into a parsed template tree.

    Placeholder ``Name`` nodes are replaced via the *names* mapping, and each
    literal ``pass`` statement is replaced by the next entry popped from
    *passes*.  When *span* is given, it is stamped onto every visited node
    (subject to :func:`ast_set_span`'s rules).  Used by :func:`quote`.
    """

    def __init__(self, names: dict[str, ast.AST], passes: list[Any] | None = None, span=None):
        self.names = names
        self.passes = [] if not passes else passes
        self.span = span

    def generic_visit(self, node: ast.AST):
        if self.span is not None:
            ast_set_span(node, self.span)
        return super().generic_visit(node)

    def visit_Name(self, node: ast.Name) -> Any:
        # Substitute the placeholder if one was registered for this name.
        return self.names.get(node.id, node)

    def visit_Pass(self, node: ast.Pass) -> Any:
        # An empty replacement keeps the original `pass` as the body.
        return self.passes.pop(0) or node
def quote(expr: str, *, passes: list[Any] | None = None, span=None, **kws) -> list[ast.AST]:
    """Parse *expr* as a source template and splice fragments into it.

    Keyword arguments map placeholder names inside *expr* to replacement AST
    nodes; *passes* supplies bodies for literal ``pass`` statements; *span*
    (a 4-tuple, or a node to copy the span from) is stamped onto the
    template.  Returns the resulting list of statements.
    """
    span_tuple = ast_get_span(span) if isinstance(span, ast.AST) else span
    template = ast.parse(expr)
    template = QuoteVisitor(kws, passes, span_tuple).visit(template)
    return template.body
def quote1(expr: str, *, passes: list[Any] | None = None, span=None, **kws) -> ast.AST:
    """Like :func:`quote`, but the template must yield exactly one statement."""
    stmts = quote(expr, passes=passes, span=span, **kws)
    assert len(stmts) == 1
    return stmts[0]
def quote_expr(expr: str, **kws) -> ast.expr:
    """Quote a single-expression template and return the bare expression node."""
    stmt = quote1(expr, **kws)
    assert isinstance(stmt, ast.Expr)
    return stmt.value
# Class names of the ast.operator nodes handled by eval_op / eval_aug_assign.
Operator = Literal['Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow', 'LShift', 'RShift',
                   'BitOr', 'BitXor', 'BitAnd', 'FloorDiv']
# Class names of the ast.boolop nodes (short-circuiting operators).
BoolOp = Literal['And', 'Or']
def get_operator_name(operator: ast.operator) -> Operator:
    """Return the class name of a binary-operator AST node (e.g. 'Add')."""
    return type(operator).__name__
def get_boolop_name(boolop: ast.boolop) -> BoolOp:
    """Return the class name of a boolean-operator AST node ('And' or 'Or')."""
    return type(boolop).__name__
# Generic type variable shared by the builder/mutator helpers below.
_T = TypeVar('_T')
def eval_op(op: Operator, left: Any, right: Any) -> Any:
if op == 'Add':
return left + right
if op == 'Sub':
return left - right
if op == 'Mult':
return left * right
if op == 'MatMult':
return left @ right
if op == 'Div':
return left / right
if op == 'Mod':
return left % right
if op == 'Pow':
return left**right
if op == 'LShift':
return left << right
if op == 'RShift':
return left >> right
if op == 'BitOr':
return left | right
if op == 'BitXor':
return left ^ right
if op == 'BitAnd':
return left & right
if op == 'FloorDiv':
return left // right
raise ValueError(f'Unknown operator: {op}')
def eval_aug_assign(op: Operator, left: Any, sl: slice, right: Any) -> Any:
if op == 'Add':
left[sl] += right
return left
if op == 'Sub':
left[sl] -= right
return left
if op == 'Mult':
left[sl] *= right
return left
if op == 'MatMult':
left[sl] @= right
return left
if op == 'Div':
left[sl] /= right
return left
if op == 'Mod':
left[sl] %= right
return left
if op == 'Pow':
left[sl] **= right
return left
if op == 'LShift':
left[sl] <<= right
return left
if op == 'RShift':
left[sl] >>= right
return left
if op == 'BitOr':
left[sl] |= right
return left
if op == 'BitXor':
left[sl] ^= right
return left
if op == 'BitAnd':
left[sl] &= right
return left
if op == 'FloorDiv':
left[sl] //= right
return left
raise ValueError(f'Unknown operator: {op}')
class _empty:
...
class BaseBuilder:
empty = _empty
def get_parent_locals(self):
return inspect.currentframe().f_back.f_back.f_locals
def ctx_if(self, cond) -> Iterable[_T]:
yield cond
def ctx_then(self, val: _T) -> Iterable[None]:
if val:
yield
def ctx_else(self, val: _T) -> Iterable[None]:
if not val:
yield
def eval(self, val: Any): # noqa: B027
pass
def ctx_for(self, range: Iterable[Any]) -> Iterable[Any]:
return range
def ctx_continue(self) -> bool:
return True
def ctx_break(self) -> bool:
return True
def ctx_while(self, cond: Callable[[], Any]) -> Iterable[None]:
while cond():
yield
def bind(self, name: str, value: Any, annot: Any = empty) -> Any:
return value
def unwrap_value(self, value):
return value
def assign_slice(self, lval: Any, sl: slice, value: Any, annot: Any = empty):
lval[sl] = value
def aug_assign(self, op: Operator, target: Any, aug_value: Any) -> Any:
return eval_op(op, target, aug_value)
def aug_assign_slice(self, op: Operator, target: Any, sl: slice, aug_value: Any):
eval_aug_assign(op, target, sl, aug_value)
def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any]) -> Any:
if op == 'And':
return left and right()
if op == 'Or':
return left or right()
raise ValueError(f'Unknown boolop: {op}')
def ifexp(self, cond: Any, then: Callable[[], Any], otherwise: Callable[[], Any]) -> Any:
return then() if cond else otherwise()
def ret(self, value: Any) -> Any:
return value
def ctx_with(self, ctx: ContextManager[Any]) -> ContextManager[Any]:
return ctx
def assert_expr(self, cond: Any, msg: Any):
assert cond, msg
def rval(self, name: str, value: Any):
return value
def arg(self, name: str, value: Any):
return value
def override(self, name: str):
return globals()[name]
class DSLMutator(ast.NodeTransformer):
    """AST transformer that reroutes Python control flow through a builder.

    Every statement/expression form is rewritten into calls on an object
    named ``__tb`` (see :class:`BaseBuilder`), so the rewritten function can
    either run natively or record IR depending on the builder supplied.
    Replacement code comes from small source templates expanded by
    :func:`quote`/:func:`quote1`/:func:`quote_expr`; the original sub-trees
    are spliced into the templates' placeholder names and ``pass`` slots.
    """

    def __init__(self):
        # Monotonic counter for minting unique temporary names (__0, __1, ...).
        self.tmp_counter = 0

    def get_tmp(self) -> str:
        """Return a fresh temporary identifier, safe from user-name collisions."""
        name = f"__{self.tmp_counter}"
        self.tmp_counter += 1
        return name

    def visit_If(self, node: ast.If):
        """Rewrite `if`/`else` into builder-driven branch contexts."""
        node = self.generic_visit(node)
        br = self.get_tmp()
        if len(node.orelse) == 0:
            return quote(
                f"for {br} in __tb.ctx_if(cond):\n"
                f"    for _ in __tb.ctx_then({br}):\n"
                "        pass\n",
                cond=node.test,
                passes=[node.body],
                span=node,
            )
        return quote(
            f"for {br} in __tb.ctx_if(cond):\n"
            f"    for _ in __tb.ctx_then({br}):\n"
            f"        pass\n"
            f"    for _ in __tb.ctx_else({br}):\n"
            f"        pass\n",
            cond=node.test,
            passes=[node.body, node.orelse],
            span=node,
        )

    def visit_Expr(self, node: ast.Expr):
        """Forward bare expression statements to ``__tb.eval``."""
        node = self.generic_visit(node)
        return quote("__tb.eval(value)", value=node.value, span=node)

    def _parse_names(self, target: ast.expr):
        """Render an assignment target as a (nested) tuple of quoted names.

        NOTE(review): currently unreferenced by the visitors below; kept as-is.
        """
        if isinstance(target, ast.Name):
            return f"'{target.id}'"
        elif isinstance(target, ast.Tuple):
            return ("(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)")
        else:
            s = ast.unparse(target)
            raise NotImplementedError(f"Unsupported for target `{s}`")

    def visit_For(self, node: ast.For):
        """Rewrite `for`: iterate via ``__tb.ctx_for`` and re-bind the loop
        target through the builder at the top of each iteration."""
        node = self.generic_visit(node)
        tmp = self.get_tmp()
        var = ast.Name(tmp, ctx=ast.Load())
        ast_set_span(var, ast_get_span(node.target))
        stmts = self._emit_assign_target(node.target, var)
        return quote(
            f"for {tmp} in __tb.ctx_for(range):\n"
            "    pass\n",
            range=node.iter,
            passes=[stmts + node.body],
            span=node,
        )

    def visit_Continue(self, node: ast.Continue):
        # The builder decides whether the native `continue` should run.
        node = self.generic_visit(node)
        return quote("if __tb.ctx_continue(): continue", span=node)

    def visit_Break(self, node: ast.Break):
        # The builder decides whether the native `break` should run.
        node = self.generic_visit(node)
        return quote("if __tb.ctx_break(): break", span=node)

    def _emit_assign_target(self,
                            target: ast.expr,
                            rval: ast.expr,
                            annot: ast.expr = None) -> list[ast.AST]:
        """Emit statements performing ``target = rval`` through the builder.

        Plain names go through ``__tb.bind``, subscripts through
        ``__tb.assign_slice``; tuple targets are unpacked into temporaries
        first and then bound element by element.
        """
        if isinstance(target, ast.Name):
            if annot is None:
                return quote(
                    f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target)
            else:
                return quote(
                    f'name = __tb.bind("{target.id}", value, annot)',
                    name=target,
                    value=rval,
                    annot=annot,
                    span=target)
        elif isinstance(target, ast.Attribute):
            s = ast.unparse(target)
            raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`')
        elif isinstance(target, ast.Subscript):
            if annot is None:
                return quote(
                    "__tb.assign_slice(lval, slice, value)",
                    lval=target.value,
                    slice=target.slice,
                    value=rval,
                    span=target,
                )
            else:
                return quote(
                    "__tb.assign_slice(lval, slice, value, annot)",
                    lval=target.value,
                    slice=target.slice,
                    value=rval,
                    annot=annot,
                    span=target,
                )
        else:
            # Tuple (or other structured) target: unpack into temporaries,
            # then route each element through the builder.
            unpacked = []

            def _visit_target(target: ast.expr) -> str:
                if isinstance(target, (ast.Name, ast.Subscript)):
                    tmp = self.get_tmp()
                    unpacked.append((tmp, target))
                    res = ast.Name(id=tmp, ctx=target.ctx)
                    ast_set_span(res, ast_get_span(target))
                    return res
                elif isinstance(target, ast.Tuple):
                    elts = [_visit_target(elt) for elt in target.elts]
                    res = ast.Tuple(elts=elts, ctx=target.ctx)
                    ast_set_span(res, ast_get_span(target))
                    return res
                # Fix: previously fell through and returned None, silently
                # producing a broken Assign node (e.g. for starred targets).
                s = ast.unparse(target)
                raise NotImplementedError(f'Unsupported target: {s}')

            unpack_stmt = ast.Assign(
                targets=[_visit_target(target)],
                value=quote_expr('__tb.unwrap_value(rval)', rval=rval, span=rval))
            ast_set_span(unpack_stmt, ast_get_span(target))
            stmts = [unpack_stmt]
            bind_lvals = []
            bind_rvals = []

            def flush_binds():
                # Emit all pending name binds as one parallel tuple assignment.
                if bind_lvals:
                    stmts.append(
                        quote1(f'{", ".join(bind_lvals)}, = {", ".join(bind_rvals)},', span=target))
                    bind_lvals.clear()
                    bind_rvals.clear()

            for tmp, tgt in unpacked:
                if isinstance(tgt, ast.Name):
                    bind_lvals.append(tgt.id)
                    bind_rvals.append(f'__tb.bind("{tgt.id}", {tmp})')
                elif isinstance(tgt, ast.Subscript):
                    # Slice stores must run in order, so flush name binds first.
                    flush_binds()
                    stmts.append(
                        quote1(
                            f'__tb.assign_slice(lval, slice, {tmp})',
                            lval=tgt.value,
                            slice=tgt.slice,
                            span=tgt))
                else:
                    s = ast.unparse(tgt)
                    raise NotImplementedError(f'Unsupported target: {s}')
            flush_binds()
            return stmts

    def visit_Assign(self, node: ast.Assign) -> list[ast.AST]:
        node = self.generic_visit(node)
        rval = node.value
        if len(node.targets) == 1:
            return self._emit_assign_target(node.targets[0], rval)
        # `a = b = rhs`: evaluate rhs once into a temp, then bind each target.
        tmp_name = self.get_tmp()
        tmp_store = ast.Name(tmp_name, ctx=ast.Store())
        tmp_load = ast.Name(tmp_name, ctx=ast.Load())
        # Fix: ast_set_span expects a span tuple, not an AST node.
        span = ast_get_span(node.targets[0])
        ast_set_span(tmp_store, span)
        ast_set_span(tmp_load, span)
        stmt = self._emit_assign_target(tmp_store, rval)
        for target in node.targets:
            stmt.extend(self._emit_assign_target(target, tmp_load))
        return stmt

    def visit_AugAssign(self, node: ast.AugAssign) -> list[ast.AST]:
        node = self.generic_visit(node)
        target, rval = node.target, node.value
        op = get_operator_name(node.op)
        if isinstance(target, ast.Name):
            return quote(
                f"name = __tb.aug_assign('{op}', {target.id}, value)",
                name=target,
                value=rval,
                span=node)
        elif isinstance(target, ast.Subscript):
            return quote(
                f"__tb.aug_assign_slice('{op}', lval, slice, value)",
                lval=target.value,
                slice=target.slice,
                value=rval,
                span=node,
            )
        else:
            # Attribute (and other) targets are passed through unchanged.
            return node

    def visit_AnnAssign(self, node: ast.AnnAssign):
        node = self.generic_visit(node)
        # `x: T` without a value binds the sentinel __tb.empty (declaration only).
        rval = node.value or quote_expr('__tb.empty', span=node)
        return self._emit_assign_target(node.target, rval, annot=node.annotation)

    def visit_While(self, node):
        # Fix: recurse into test/body first; previously skipped, which left
        # nested constructs inside `while` bodies untransformed.
        node = self.generic_visit(node)
        return quote1(
            "for _ in __tb.ctx_while(lambda: cond):\n pass",
            cond=node.test,
            passes=[node.body],
            span=node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Wrap the function in a factory taking ``__tb`` and route every
        argument through ``__tb.arg`` on entry."""
        node = self.generic_visit(node)
        all_args = node.args.posonlyargs + node.args.args
        if node.args.vararg is not None:
            # Fix: vararg is a single ast.arg, not a list; `+=` raised TypeError.
            all_args.append(node.args.vararg)
        all_args += node.args.kwonlyargs
        stmts = []
        for arg in all_args:
            name = arg.arg
            stmts.append(quote1(f'{name} = __tb.arg("{name}", {name})', span=arg))
            # Annotations are consumed by the frontend, not at runtime.
            arg.annotation = None
        node.body = stmts + node.body
        node.decorator_list.clear()
        return quote1(
            f"def {node.name}(__tb):\n"
            " range = __tb.override('range')\n"
            " pass\n"
            f" return {node.name}",
            passes=[node],
        )

    def visit_BoolOp(self, node: ast.BoolOp):
        # Fold `a and b and c` right-to-left into __tb.boolop calls; the RHS
        # becomes a thunk so short-circuiting is preserved.
        node = self.generic_visit(node)
        op_name = get_boolop_name(node.op)
        last = node.values[-1]
        for i in reversed(range(len(node.values) - 1)):
            last = quote_expr(
                expr=f"__tb.boolop('{op_name}', left, lambda: right)",
                left=node.values[i],
                right=last,
                span=node,
            )
        return last

    def visit_Compare(self, node: ast.Compare) -> ast.expr:
        # Split `a < b < c` into `(a < b) and (b < c)` built with __tb.boolop.
        # NOTE(review): the shared middle operand is duplicated, so it may be
        # evaluated twice — unlike native chained comparison.
        node = self.generic_visit(node)
        left = node.left
        split = []
        for op, comp in zip(node.ops, node.comparators):
            cmp = ast.Compare(left=left, ops=[op], comparators=[comp])
            ast_set_span(cmp, ast_get_span(node))
            split.append(cmp)
            left = comp
        last = split[-1]
        for i in reversed(range(len(split) - 1)):
            last = quote_expr(
                "__tb.boolop('And', left, lambda: right)", left=split[i], right=last, span=node)
        return last

    def visit_IfExp(self, node: ast.IfExp) -> ast.Expr:
        # Both branches become thunks so only the chosen one is evaluated.
        node = self.generic_visit(node)
        return quote_expr(
            '__tb.ifexp(cond, lambda: then, lambda: otherwise)',
            cond=node.test,
            then=node.body,
            otherwise=node.orelse,
            span=node)

    def visit_Return(self, node: ast.Return):
        node = self.generic_visit(node)
        return quote("return __tb.ret(value)", value=node.value, span=node)

    def visit_With(self, node: ast.With):
        # Wrap each context expression so the builder can intercept it.
        node = self.generic_visit(node)
        for expr in node.items:
            expr.context_expr = quote_expr("__tb.ctx_with(e)", e=expr.context_expr, span=expr)
        return node

    def visit_Assert(self, node: ast.Assert):
        node = self.generic_visit(node)
        return quote("__tb.assert_expr(cond, msg)", cond=node.test, msg=node.msg, span=node)

    def visit_Name(self, node: ast.Name):
        # Every name *read* goes through __tb.rval; stores are left intact.
        if isinstance(node.ctx, ast.Load):
            return quote_expr(f"__tb.rval('{node.id}', {node.id})", span=node)
        return node
# Parameter-spec variable preserving the decorated function's signature.
_P = ParamSpec('_P')


@dataclass
class IRGenerator(Generic[_P, _T]):
    """Result of :func:`mutate`: the compiled, AST-transformed function."""
    # Factory: given a builder object (bound to `__tb`), returns a callable
    # with the original function's signature.
    gen: Callable[[BaseBuilder], Callable[_P, _T]]
    # Unparsed source text of the transformed AST (for debugging/inspection).
    source: str
def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]:
    """Transform a Python function into an IR generator via AST rewriting.

    The function's source AST is run through :class:`DSLMutator`, recompiled
    against the original file name and captured closure context, and wrapped
    in an :class:`IRGenerator`.  The returned object's ``gen`` attribute is a
    factory that, given a builder, yields a callable with the original
    signature; ``source`` holds the unparsed transformed code.

    Example:
        >>> @mutate
        ... def my_function(x: int) -> int:
        ...     return x * 2

    Note:
        - Closure variables and captured context of *func* are preserved.
        - The transformation happens once, at decoration time.
    """
    original_tree = utils.get_ast(func)
    source_file = inspect.getsourcefile(func) or inspect.getfile(func)
    transformed = DSLMutator().visit(original_tree)
    compiled = utils.get_compiled_object(transformed, func.__name__, source_file,
                                         utils.inspect_function_capture(func))
    return IRGenerator(gen=compiled, source=ast.unparse(transformed))
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
import inspect
from tilelang.language.kernel import KernelLaunchFrame
from tvm_ffi.container import Map
from tvm.ir.base import Span
from .ast import BaseBuilder, IRGenerator, eval_op, mutate
import tvm
from tvm.tir import Buffer
from tvm.script.ir_builder import tir, IRBuilder
from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var
from typing import TYPE_CHECKING, Callable, ContextManager, Any, Generic, ParamSpec, Self, TypeVar, ForwardRef
from . import dtypes as dt
import threading
import logging
logger = logging.getLogger(__name__)
def unwrap_expr(expr) -> PrimExpr | int | float:
    """Unwrap frontend wrapper objects into a plain PrimExpr-like value.

    ``meta_var`` wrappers yield their payload, comparison proxy objects are
    materialized, and a 'local.var' buffer reads its single element; anything
    else passes through untouched.
    """
    if isinstance(expr, tir.meta_var):
        return expr.value
    if isinstance(expr, (EqualOp, NotEqualOp)):
        return expr.asobject()
    if isinstance(expr, Buffer) and expr.scope() == 'local.var':
        return tir.BufferLoad(expr, indices=[0])
    return expr
def unwrap_cond(expr):
    """Unwrap *expr* and coerce it into a condition.

    Immediates fold to Python ``bool``, symbolic expressions are returned
    as-is, raw buffers are rejected, and arbitrary Python objects are treated
    as constants (with a warning).
    """
    expr = unwrap_expr(expr)
    if isinstance(expr, (IntImm, FloatImm, StringImm)):
        # constant TIR immediates fold at trace time
        return bool(expr.value)
    if isinstance(expr, PrimExpr):
        # symbolic condition, kept for the generated IR
        return expr
    if isinstance(expr, Buffer):
        raise TypeError(f"Buffer `{expr}` cannot be used as condition directly.")
    if expr is None or isinstance(expr, (int, bool)):
        return bool(expr)
    logger.warning(
        f"Python expression `{expr}` is used as condition in TileLang, \n"
        "this is treated as a constant expression. ",
        stack_info=True,
        stacklevel=3)
    return bool(expr)
# Per-thread slot holding the currently active Builder (see Builder.current()).
thread_local_storage = threading.local()
class Frame:
    '''
    Frames are virtual context managers used by the frontend only.
    They have no runtime representation in the generated TIR; the Builder
    pushes them onto its frame stack purely to track lexical structure
    (macros, constant ifs, continue/break markers) during tracing.
    '''
    def __enter__(self):
        # No-op: frames only mark positions on the Builder's frame stack.
        ...
    def __exit__(self, exc_type, exc_value, traceback):
        ...
class MacroFrame(Frame):
    # Marks the body of an expanding macro (pushed by Builder.macro).
    ...
class BoolOpFrame(Frame):
    # Marks lazy evaluation inside `and`/`or`/ternary lowering
    # (see Builder.boolop / Builder.ifexp).
    ...
class ConstIfFrame(Frame):
    # Marks an `if` whose condition folded to a Python constant (Builder.ctx_if).
    ...
class BlockFrame(Frame):
    # Marks the branch body of a constant `if` (Builder.ctx_then / ctx_else).
    ...
class ContinueFrame(Frame):
    # Dummy frame pushed after a traced `continue` to detect dead code.
    ...
class BreakFrame(Frame):
    # Dummy frame pushed after a traced `break` to detect dead code.
    ...
# `continue`/`break` sentinel frames share handling in check_continue_break().
ContinueOrBreak = ContinueFrame | BreakFrame
# Any frame the Builder may hold: real tvm IRBuilder frames or frontend-only ones.
AnyFrame = tir.frame.IRBuilderFrame | Frame
# Control-flow frames that `return` must not cross inside a macro (Builder.ret).
TIR_CONTROL_FRAME = (
    tir.frame.WhileFrame,
    tir.frame.ForFrame,
    tir.frame.IfFrame,
    tir.frame.PrimFuncFrame,
)
# Frames that open a new variable scope for shadowing checks (Builder.bind).
TIR_VAR_SCOPE_FRAME = (
    tir.frame.WhileFrame,
    tir.frame.ForFrame,
    tir.frame.IfFrame,
    tir.frame.PrimFuncFrame,
    MacroFrame,
    KernelLaunchFrame,
)
def is_var(v: Any) -> bool:
    # True for buffers allocated with scope 'local.var' (mutable frontend vars).
    return isinstance(v, Buffer) and v.scope() == 'local.var'
class Builder(BaseBuilder):
    """Thread-bound IR builder backing TileLang's frontend-v2 tracing.

    The AST mutator (see ``mutate``) rewrites user code so that bindings,
    control flow and expressions call back into the hooks on this class,
    which forwards them to tvm.script's ``IRBuilder`` while tracking
    frontend-only :class:`Frame` markers for diagnostics (name shadowing,
    dead code after continue/break, returns across control flow).
    """
    def __init__(self):
        # Stack of active frames: real tvm IRBuilder frames mixed with
        # frontend-only Frame markers.
        self.frames: list[AnyFrame] = []
        self.ir_builder = IRBuilder()
        # name -> frame the name was bound in; used for scope diagnostics.
        self.name_inside_frame: dict[str, AnyFrame] = {}
    @classmethod
    def current(cls) -> Self:
        """Return the Builder active on this thread (set by ``prim_func``)."""
        builder = thread_local_storage.builder
        assert builder is not None, "No active Builder found in the current thread."
        return builder
    @contextmanager
    def prim_func(self, name):
        """Open a tvm PrimFunc frame named *name* and make this Builder current."""
        thread_local_storage.builder = self
        with self.ir_builder, self.with_frame(tir.prim_func()):
            tir.func_name(name)
            yield
    @contextmanager
    def macro(self, name=None):
        """Open a macro expansion; names bound inside do not leak outside."""
        if self.find_frame_idx(BoolOpFrame) is not None:
            # Macro expansion emits statements, which cannot appear inside a
            # lazily-lowered short-circuit expression.
            raise RuntimeError(
                f"Macro `{name}` is used inside boolean expressions, "
                "please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs")
        # Give the macro body a fresh name-scope table, restored afterwards.
        save = self.name_inside_frame
        self.name_inside_frame = {}
        with self.with_frame(MacroFrame()):
            yield
        self.name_inside_frame = save
    def get(self):
        """Finalize the underlying IRBuilder and return the built IR."""
        return self.ir_builder.get()
    def find_frame_idx(self, frame: type | tuple[type, ...], start=0) -> int | None:
        """Index of the innermost frame matching *frame* at or above *start*, or None."""
        for idx in reversed(range(start, len(self.frames))):
            f = self.frames[idx]
            if isinstance(f, frame):
                return idx
    def enter_frame(self, frame: ContextManager):
        """Push *frame* on the stack and enter it; it is exited later by
        with_frame's unwind loop."""
        self.frames.append(frame)
        return frame.__enter__()
    def check_continue_break(self):
        """Warn when statements are traced after a `continue`/`break`."""
        idx = self.find_frame_idx(ContinueOrBreak)
        if idx is not None:
            logger.warning(
                'Writing code after continue/break may cause undefined behavior in tilelang.',
                stack_info=True,
                stacklevel=3)
    @contextmanager
    def with_frame(self, frame: ContextManager | None):
        """Enter *frame*; on exit, pop and close every frame opened since
        (including frames entered via enter_frame inside the body)."""
        pop_idx = len(self.frames)
        yield self.enter_frame(frame)
        while len(self.frames) > pop_idx:
            self.frames.pop().__exit__(None, None, None)
    class _has_if_frame:
        # Sentinel yielded by ctx_if when a real TIR If frame was opened.
        ...
    def ctx_if(self, cond):
        """Trace an `if`: symbolic conditions open a TIR If frame (yielding the
        sentinel), constant conditions yield the folded Python truth value."""
        self.check_continue_break()
        cond = unwrap_cond(cond)
        if isinstance(cond, PrimExpr):
            with self.with_frame(tir.If(cond)):
                yield self._has_if_frame
        else:
            with self.with_frame(ConstIfFrame()):
                yield cond
    def ctx_then(self, val):
        """Trace the then-branch; skipped entirely when a constant `if` folded false."""
        if val is self._has_if_frame:
            with self.with_frame(tir.Then()):
                yield
        else:
            with self.with_frame(BlockFrame()):
                if val:
                    yield
    def ctx_else(self, val):
        """Trace the else-branch; skipped entirely when a constant `if` folded true."""
        if val is self._has_if_frame:
            with self.with_frame(tir.Else()):
                yield
        else:
            with self.with_frame(BlockFrame()):
                if not val:
                    yield
    def eval(self, val: Any):
        """Trace a bare expression statement."""
        val = unwrap_expr(val)
        if val is None:
            pass
        elif isinstance(val, tir.frame.IRBuilderFrame):
            if isinstance(val, tir.frame.ForFrame):
                logger.warning(
                    'Evaluating a for frame may cause undefined behavior in tilelang.',
                    stack_info=True,
                    stacklevel=1,
                )
            self.enter_frame(val)
        elif isinstance(val, PrimExpr):
            tir.evaluate(val)
        elif isinstance(val, (int, bool)):
            tir.evaluate(tvm.tir.const(val))
        elif isinstance(val, str):
            # bare string expression (e.g. a stray docstring) is a no-op
            pass
        elif isinstance(val, tvm.tir.stmt.BufferStore):
            tir.buffer_store(val.buffer, val.value, val.indices, val.predicate)
        else:
            raise TypeError(f"Unsupported eval value: {val} of type {type(val)}")
    def ctx_for(self, it):
        """Trace a `for` loop; only TIR ForFrames are legal iterables."""
        self.check_continue_break()
        it = unwrap_expr(it)
        if not isinstance(it, tir.frame.ForFrame):
            raise TypeError(
                f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
                "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding")
        with self.with_frame(it) as v:
            yield v
    def ctx_continue(self):
        """Trace `continue`."""
        self.check_continue_break()
        # add a dummy frame for checking code after continue/break
        self.enter_frame(ContinueFrame())
        tir.evaluate(tir.continue_loop())
    def ctx_break(self):
        """Trace `break`."""
        self.check_continue_break()
        # add a dummy frame for checking code after continue/break
        self.enter_frame(BreakFrame())
        tir.evaluate(tir.break_loop())
    def ctx_while(self, cond):
        """`while` loops are rejected by this builder."""
        self.check_continue_break()
        raise RuntimeError("while loops are not supported in TileLang builder")
    def bind(self, name, value, annot=BaseBuilder.empty):
        """Trace `name = value` (with optional annotation *annot*).

        Stores into an existing 'local.var' buffer mutate it in place;
        everything else falls through to an immutable binding plus
        shadowing/scope bookkeeping.
        """
        self.check_continue_break()
        locals = self.get_parent_locals()
        orig_value = locals.get(name, None)
        # annotation like tl.float32
        # temporarily disable annotation based var declaration, for better pull request separation
        # if callable(annot):
        #     annot_val = annot()
        #     if isinstance(annot_val, tir.Var):
        #         orig_value = tir.alloc_buffer((1,), dtype=annot_val.dtype, scope='local.var')
        #         IRBuilder.name(name, orig_value)
        #         if isinstance(value, EllipsisType) or value is self.empty:
        #             return orig_value
        #         elif isinstance(value, (int, float, IntImm, FloatImm)):
        #             tir.block_attr(
        #                 {'tl.local_var_init': {
        #                     orig_value.data: tvm.runtime.convert(value)
        #                 }})
        #             return orig_value
        # if orig_value is a local.var, we use buffer_store to modify it immutably
        # however, if rvalue is also a local.var, this is a new binding,
        # we should not use buffer_store, and bind it instead
        # ```py
        # a = tl.alloc_var('float32')  # bind var `a`
        # a = tl.alloc_var('float32')  # bind a new var `a_1`
        # b = a  # get value of var `b = a_1[0]``
        # c = tl.alloc_var('float32')  # bind var `c`
        # c = a  # get and assign `c[0] = a_1[0]`
        # ```
        if is_var(orig_value) and not is_var(value):
            tir.buffer_store(orig_value, value, 0)
            return orig_value
        res = self.bind_immutable(name, value)
        if name != '_':
            # Record which frame declared the name so rval() can flag
            # uses outside the defining region.
            frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME)
            assert frame is not None, f"Variable `{name}` is not defined inside any control flow."
            if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames:
                logger.warning(
                    f'Variable `{name}` shadows another declared value, Are you forgetting to allocate it as a var?',
                    stack_info=True,
                    stacklevel=2,
                )
            self.name_inside_frame[name] = self.frames[frame]
        return res
    def unwrap_value(self, value):
        """Unwrap *value*, entering it when it is an IRBuilder frame."""
        value = unwrap_expr(value)
        # handle bx, by = tl.Kernel(128, 128), rval is frame
        if isinstance(value, tir.frame.IRBuilderFrame):
            return self.enter_frame(value)
        else:
            return value
    def bind_immutable(self, name, value):
        """Bind *value* to *name* without mutation: frames are entered,
        vars/buffers are renamed, other values get a TIR Let binding."""
        if isinstance(value, tir.meta_var):
            return value.value
        elif isinstance(value, tir.frame.IRBuilderFrame):
            if isinstance(value, tir.frame.ForFrame):
                logger.warning(
                    'Binding a for frame to variable may cause undefined behavior in tilelang.',
                    stack_info=True,
                    stacklevel=2,
                )
            return self.enter_frame(value)
        elif isinstance(value, (Buffer, tir.IterVar, tir.Var)):
            IRBuilder.name(name, value)
            return value
        elif isinstance(value, (tuple, list, tvm.ffi.Array)):
            return value
        else:
            # Fall back to a LetStmt; non-convertible Python objects are
            # bound as plain Python values.
            try:
                value = tvm.runtime.convert(value)
            except TypeError:
                return value
            frame = tir.LetStmt(value)
            var = frame.var
            IRBuilder.name(name, var)
            return self.enter_frame(frame)
    def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty):
        """Trace `lval[sl] = value`; buffers become BufferStore nodes."""
        self.check_continue_break()
        if annot is not self.empty:
            logger.warning(
                "Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2)
        if isinstance(lval, Buffer):
            tir.buffer_store(lval, value, sl)
        else:
            return super().assign_slice(lval, sl, value)
    def aug_assign(self, op, target, aug_value):
        """Trace `target op= aug_value`; supported in-place only for local vars."""
        self.check_continue_break()
        if is_var(target):
            tir.buffer_store(target, eval_op(op, target[0], aug_value), 0)
        elif isinstance(target, Buffer):
            raise RuntimeError("Augmented assignment is not supported for Buffer")
        else:
            return super().aug_assign(op, target, aug_value)
    def aug_assign_slice(self, op, target, sl, aug_value):
        """Trace `target[sl] op= aug_value`; buffers read-modify-write in place."""
        self.check_continue_break()
        if isinstance(target, Buffer):
            tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl)
        else:
            return super().aug_assign_slice(op, target, sl, aug_value)
    def boolop(self, op, left, right):
        """Trace `left and/or right`; *right* is a thunk evaluated lazily."""
        left = unwrap_cond(left)
        if isinstance(left, PrimExpr):
            with self.with_frame(BoolOpFrame()):
                if op == 'And':
                    return tir.And(left, right())
                if op == 'Or':
                    return tir.Or(left, right())
                raise RuntimeError(f"Unsupported boolean operator: {op}")
        else:
            return super().boolop(op, left, right)
    def ifexp(self, cond, then, otherwise):
        """Trace `then if cond else otherwise`; branches are thunks."""
        cond = unwrap_cond(cond)
        if isinstance(cond, PrimExpr):
            with self.with_frame(BoolOpFrame()):
                return tir.if_then_else(cond, then(), otherwise())
        else:
            return super().ifexp(cond, then, otherwise)
    def ret(self, value):
        """Trace `return value`; rejected inside macro control flow."""
        self.check_continue_break()
        # handle return T.alloc_var()
        value = self.unwrap_value(value)
        last_macro = self.find_frame_idx(MacroFrame)
        if last_macro is not None:
            frame = self.find_frame_idx(TIR_CONTROL_FRAME, start=last_macro)
            if frame is not None:
                raise NotImplementedError(
                    "Return from control flow is not supported yet. \n"
                    "You should allocate a var before the control flow, assign value inside the blocks, \n"
                    "and return the var after the control flow. i.e.\n"
                    "```\n"
                    "@T.macro\n" \
                    "def my_macro(cond):\n"
                    "    a = T.alloc_var(T.float16)\n"
                    "    if cond:\n"
                    "        a = 1.0\n"
                    "    return a\n"
                    "```"
                )
        return value
    def ctx_with(self, ctx):
        """Trace a `with` statement; IRBuilder frames are tracked on the stack."""
        self.check_continue_break()
        if isinstance(ctx, tir.frame.IRBuilderFrame):
            return self.with_frame(ctx)
        else:
            return super().ctx_with(ctx)
    def assert_expr(self, cond, msg):
        """Trace `assert cond, msg`; symbolic conditions emit a TIR Assert."""
        self.check_continue_break()
        cond = unwrap_cond(cond)
        if isinstance(cond, PrimExpr):
            self.enter_frame(tir.Assert(cond, msg))
        elif not cond:
            raise AssertionError(msg)
    def rval(self, name: str, value: Any) -> Any:
        """Trace a name read, checking it is still inside its defining frame."""
        if name in self.name_inside_frame:
            frame = self.name_inside_frame[name]
            if frame not in self.frames:
                raise RuntimeError(
                    f"Use immutable variable `{name}` outside its defining region, did you forget **alloc_var**?\n"
                    f"variable `{name}` is defined in frame: {frame}, current frames: {self.frames}."
                )
        return self.unwrap_value(value)
    def arg(self, name, value):
        """Trace a function argument: macro args bind locally, prim_func args
        must be Buffers or Vars and become TIR function parameters."""
        if self.find_frame_idx(MacroFrame) is not None:
            if isinstance(value, (PrimExpr, int, float)):
                return self.bind(name, value)
            else:
                return value
        if isinstance(value, (Buffer, Var)):
            return tir.arg(name, value)
        elif value is self.empty:
            raise ValueError(f'Argument `{name}` is not annotated')
        # elif isinstance(value, Hashable):
        #     return value
        else:
            raise TypeError(
                f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")
    def override(self, name: str):
        """Builtins replaced during tracing: `range` maps to tir.serial."""
        if name == 'range':
            return tir.serial
        raise ValueError(f'Unknown override: {name}')
# Type parameters threaded through the decorators to preserve user signatures.
_P = ParamSpec('_P')
_T = TypeVar('_T')
if TYPE_CHECKING:
    # Static-analysis-only subclass exposing the extra attributes that
    # prim_func() attaches dynamically to the returned tvm PrimFunc.
    class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc):
        params: list[tvm.tir.Var | tvm.tir.Buffer]
        body: tvm.tir.Stmt
        ret_type: tvm.ir.Type
        buffer_map: Map[tvm.tir.Var, tvm.tir.Buffer]
        attrs: tvm.Attrs | None
        span: Span | None
        ir_gen: IRGenerator[_P, _T] | None
        source: str | None
        orig_func: Callable[_P, _T] | None
else:
    # At runtime PrimFunc is just tvm's class; see prim_func() for the extras.
    PrimFunc = tvm.tir.PrimFunc
@dataclass
class Macro(Generic[_P, _T]):
    """A traced TileLang macro, callable inside prim_func or other macros."""
    # Display name used for diagnostics (the decorated function's __name__).
    name: str
    # The undecorated Python function.
    orig_func: Callable[_P, _T]
    # AST-mutated generator producing the macro body at trace time.
    ir_gen: IRGenerator[_P, _T]
    @property
    def source(self) -> str:
        """Source text of the mutated (IR-generating) function."""
        return self.ir_gen.source
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
        # Expand the macro body into the currently active thread-local Builder.
        builder = Builder.current()
        with builder.macro(self.name):
            res = self.ir_gen.gen(builder)(*args, **kwargs)
        return res
def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]:
    """Decorator turning a Python function into a TileLang macro.

    A macro is very similar to a PrimFunc: it can be invoked from inside a
    ``prim_func`` or from another macro, and its body is expanded through
    the active Builder at trace time.

    Parameters
    ----------
    func : Callable[_P, _T], optional
        Function to convert; analyzed and transformed into an IR-generation
        function. When omitted, the decorator itself is returned.

    Returns
    -------
    Macro[_P, _T]
        Wrapper preserving the original signature (_P, _T) while adding
        metaprogramming capabilities.

    Example
    -------
    >>> @macro
    ... def my_macro(x: T.int32) -> T.int32:
    ...     return x ** 2
    >>> @prim_func
    ... def my_func(A: T.Tensor((10,), T.int32), B: T.Tensor((10,), T.int32)):
    ...     with T.Kernel(1) as _:
    ...         for i in T.serial(10):
    ...             B[i] = my_macro(A[i])

    See Also
    --------
    Macro : The class that wraps macro functions
    mutate : The function that transforms Python code into IR generators
    """
    def wrap(fn: Callable[_P, _T]) -> Macro[_P, _T]:
        # Build the AST-level IR generator once, at decoration time.
        return Macro(name=fn.__name__, ir_gen=mutate(fn), orig_func=fn)
    if func is None:
        return wrap
    return wrap(func)
from typing import _eval_type
def get_type_hints(func):
    """Resolve *func*'s parameter annotations, with TileLang dtype handling.

    Unlike :func:`typing.get_type_hints`, concrete ``tvm.DataType``
    annotations are kept verbatim, and string annotations of the form
    ``"T.float32"`` (produced under ``from __future__ import annotations``)
    are evaluated directly instead of going through ForwardRef resolution,
    which would fail on the dtype attribute lookup.

    Parameters
    ----------
    func : Callable
        Function whose ``__annotations__`` are inspected.

    Returns
    -------
    dict
        Mapping of parameter name to resolved annotation; the ``return``
        annotation is skipped.

    Raises
    ------
    TypeError
        If *func* has no ``__annotations__`` attribute.
    """
    annot = getattr(func, '__annotations__', None)
    if annot is None:
        raise TypeError(f'Failed to get function type hints, {func} is not a function')
    hints = {}
    type_params = getattr(func, "__type_params__", ())
    globalns = getattr(func, '__globals__', {})
    localns = globalns
    for name, value in annot.items():
        if name == 'return':
            continue
        if isinstance(value, tvm.DataType):
            # Already a concrete TileLang dtype annotation; keep it as-is.
            hints[name] = value
            continue
        if value is None:
            value = type(None)
        if isinstance(value, str):
            # this branch handles T.float32 style annotation
            # since they are string, directly evaluating them usually causes NameError
            # so we need to split and evaluate them separately
            # BUG FIX: a dotless string annotation (e.g. "int") used to crash
            # the 2-tuple unpacking of split(); partition() handles both forms.
            _, dot, suffix = value.partition('.')
            if dot and suffix in dt._all_dtypes:
                try:
                    hints[name] = eval(value, globalns, localns)
                    continue
                except Exception:
                    # fall through to the generic ForwardRef path
                    pass
            value = ForwardRef(value, is_argument=True, is_class=False)
        hints[name] = _eval_type(value, globalns=globalns, localns=localns, type_params=type_params)
    return hints
def _is_static_annot(annot: Any) -> bool:
    # Annotations that fully specify an argument (dtype scalar, Buffer, Var),
    # allowing prim_func() to construct the function immediately.
    return isinstance(annot, (dt.dtype, Buffer, Var))
def prim_func(func: Callable[_P, _T] = None,
              *,
              generator: bool = False) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]:
    """
    Decorator to create a primitive function (PrimFunc) for TileLang IR generation.
    This decorator transforms a Python function into a TileLang primitive function by analyzing
    its type annotations and generating intermediate representation (IR) code. It supports both
    immediate construction (when all parameters are statically annotated) and generator mode
    (for dynamic construction).
    Parameters
    ----------
    func : Callable[_P, _T], optional
        The function to be decorated. Can be None when using decorator with arguments.
    generator : bool, default=False
        If True, returns a generator function that creates PrimFunc instances on demand.
        If False, attempts to create a PrimFunc immediately using type annotations.
    Returns
    -------
    PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]
        - If `generator=False` and all parameters are statically annotated: returns a PrimFunc instance
        - If `generator=True`: returns a callable that generates PrimFunc instances when invoked
        - If used without parentheses: returns the decorator implementation function
    Examples
    --------
    Static annotation mode (immediate construction):
    >>> @prim_func
    ... def add_kernel(A: T.Buffer((128,), T.float32),
    ...                B: T.Buffer((128,), T.float32)):
    ...     for i in T.grid(128):
    ...         B[i] = A[i] + 1.0
    Generator mode (dynamic construction):
    >>> @prim_func(generator=True)
    ... def dynamic_kernel(A=T.Tensor((128,), T.float32)):
    ...     # function body
    ...     pass
    >>> kernel_instance = dynamic_kernel()
    With custom parameters:
    >>> @prim_func(generator=True)
    ... def parameterized_kernel(size: int = 128):
    ...     # function body using size parameter
    ...     pass
    >>> kernel = parameterized_kernel(size=256)
    See Also
    --------
    Builder : The IR builder class used for constructing primitive functions
    mutate : Function used to generate IR from the decorated function
    """
    def impl(func: Callable[_P, _T]) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]:
        sig = inspect.signature(func)
        annot = get_type_hints(func)
        # Callable annotations (e.g. a dtype like T.float32) are invoked to
        # produce the concrete default argument value.
        for k in annot:
            if callable(annot[k]):
                annot[k] = annot[k]()
        # check whether all arguments are annotated
        all_arg_annotated = all([x in annot for x in sig.parameters])
        # check whether all annotations are Buffer/Var/dtype
        all_annot_are_static = all([_is_static_annot(x) for x in annot.values()])
        ir_gen = mutate(func)
        def prim_func_generator(*args, **kwargs):
            # Trace the mutated function into a fresh Builder and attach
            # provenance metadata to the resulting PrimFunc.
            builder = Builder()
            with builder.prim_func(func.__name__):
                ir_gen.gen(builder)(*args, **kwargs)
            res = builder.get()
            res.ir_gen = ir_gen
            res.source = ir_gen.source
            res.orig_func = func
            return res
        # Mirror the metadata on the generator itself for introspection.
        prim_func_generator.ir_gen = ir_gen
        prim_func_generator.source = ir_gen.source
        prim_func_generator.orig_func = func
        if generator:
            return prim_func_generator
        if all_arg_annotated and all_annot_are_static:
            # All arguments are fully specified; build the PrimFunc eagerly.
            return prim_func_generator(**annot)
        else:
            raise ValueError(
                "Some arguments are not supported or statically annotated, \n"
                "please check the annotations or set generator=True to get a prim_func generator.\n"
                f"Argument Annotations: {annot}\n"
                "Example usage of generator:\n"
                "```py\n"
                "@prim_func(generator=True)\n"
                "def my_func(a=T.Tensor((128,), T.float32)): ...\n"
                "return my_func()\n"
                "```")
    return impl(func) if func is not None else impl
from tilelang import tvm
from tvm import ir
import tvm_ffi
import torch
import ctypes
from typing import TYPE_CHECKING
from tvm import tir
import tvm.script.ir_builder.tir._ffi_api as tb_ffi
# TileLang reuses tvm's DataType (a str subclass) as its dtype object;
# the methods monkey-patched below extend it.
dtype = tvm.DataType
# Anything accepted where a dtype is expected.
AnyDType = ir.Type | str | type | torch.dtype | dtype
# Dtype conversion table, one row per dtype:
# (python/torch type, tvm dtype str, ctypes type, C type str, tvm ffi ctor name)
# A None entry means "no representation in that column"; rows with None in a
# used column are skipped by _create_type_mapper.
_dtype_cvt = [
    (None, 'handle', ctypes.c_long, 'long', None),  # use long to repr void*
    (bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
    (int, 'int32', ctypes.c_int32, 'int', 'Int32'),
    (float, 'float32', ctypes.c_float, 'float', 'Float32'),
    (torch.short, 'int16', ctypes.c_int16, 'short', 'Int16'),
    (torch.int, 'int32', ctypes.c_int32, 'int', 'Int32'),
    (torch.long, 'int64', ctypes.c_int64, 'long long', 'Int64'),
    (torch.half, 'float16', None, None, 'Float16'),
    (torch.float, 'float32', ctypes.c_float, 'float', 'Float32'),
    (torch.double, 'float64', ctypes.c_double, 'double', 'Float64'),
    # (pytype, 'tvm dtype str', 'ctypes dtype', 'cffi dtype')
    (torch.bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
    (torch.int8, 'int8', ctypes.c_int8, 'char', 'Int8'),
    (torch.int16, 'int16', ctypes.c_int16, 'short', 'Int16'),
    (torch.int32, 'int32', ctypes.c_int32, 'int', 'Int32'),
    (torch.int64, 'int64', ctypes.c_int64, 'long long', 'Int64'),
    (torch.uint8, 'uint8', ctypes.c_uint8, 'unsigned char', 'UInt8'),
    (torch.uint16, 'uint16', ctypes.c_uint16, 'unsigned short', 'UInt16'),
    (torch.uint32, 'uint32', ctypes.c_uint32, 'unsigned int', 'UInt32'),
    (torch.uint64, 'uint64', ctypes.c_uint64, 'unsigned long long', 'UInt64'),
    (torch.float16, 'float16', None, None, 'Float16'),
    (torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'),
    (torch.float64, 'float64', ctypes.c_double, 'double', 'Float64'),
    (None, 'float8_e4m3', None, None, 'Float8E4M3'),
    (torch.float8_e4m3fn, 'float8_e4m3fn', None, None, 'Float8E4M3FN'),
    (torch.float8_e4m3fnuz, 'float8_e4m3fnuz', None, None, 'Float8E4M3FNUZ'),
    (torch.float8_e5m2, 'float8_e5m2', None, None, 'Float8E5M2'),
    (torch.float8_e5m2fnuz, 'float8_e5m2fnuz', None, None, 'Float8E5M2FNUZ'),
    (torch.float8_e8m0fnu, 'float8_e8m0fnu', None, None, 'Float8E8M0FNU'),
    (torch.bfloat16, 'bfloat16', None, None, 'BFloat16'),
]
def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x):
return {
smapper(item[sidx]): dmapper(item[didx])
for item in _dtype_cvt
if item[didx] is not None and item[sidx] is not None
}
# Derived lookup tables (source column -> destination column of _dtype_cvt).
_dtype_py2tvmstr = _create_type_mapper(0, 1)  # python/torch type -> tvm dtype str
_dtype_tvmstr2fficall = _create_type_mapper(1, 4, dmapper=lambda x: getattr(tb_ffi, x))  # tvm str -> ffi ctor
_dtype_tvm2py = _create_type_mapper(1, 0, lambda x: dtype(x))  # dtype -> python/torch type
_dtype_tvm2ctype = _create_type_mapper(1, 2, lambda x: dtype(x))  # dtype -> ctypes type
_dtype_tvm2cffi = _create_type_mapper(1, 3, lambda x: dtype(x))  # dtype -> C type string
def __dtype_eq__(self: dtype, other: AnyDType):
    """Equality that also accepts Python/torch dtypes via the conversion table."""
    if isinstance(other, str):
        return str.__eq__(self, other)
    try:
        canonical = _dtype_py2tvmstr[other]
    except KeyError:
        return NotImplemented
    return str.__eq__(self, canonical)
def __dtype_ne__(self: dtype, other: AnyDType):
    """Inequality counterpart of __dtype_eq__."""
    if isinstance(other, str):
        return str.__ne__(self, other)
    try:
        canonical = _dtype_py2tvmstr[other]
    except KeyError:
        return NotImplemented
    return str.__ne__(self, canonical)
def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var:
    """Calling a dtype (e.g. ``T.float32(x)``) creates/casts a TIR value of
    that type by dispatching to the matching tvm ir_builder FFI constructor."""
    if self in _dtype_tvmstr2fficall:
        return _dtype_tvmstr2fficall[self](expr, is_size_var)
    # try to construct the ffi call
    # derive the FFI constructor name from the dtype string, e.g.
    # 'uint8' -> 'UInt8', 'float8_e4m3' -> 'Float8' + 'E4M3'
    if self.startswith('uint'):
        val = 'UInt' + self[4:]
    elif self.startswith('int'):
        val = 'Int' + self[3:]
    elif self.startswith('float'):
        val = 'Float' + self[5:]
    elif self.startswith('bfloat'):
        val = 'BFloat' + self[6:]
    else:
        raise TypeError(f'Invalid type {self}')
    if '_' in val:
        # suffix after '_' is upper-cased, e.g. 'Float8_e4m3' -> 'Float8E4M3'
        first, second = val.split('_', maxsplit=1)
        val = first + second.upper()
    call = getattr(tb_ffi, val, None)
    if call is None:
        raise TypeError(f"Convert to datatype `{self}` is not supported by tvm\n"
                        f"calling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`")
    return call(expr, is_size_var)
def __dtype_new__(cls, value: AnyDType) -> dtype:
    """Construct a tvm DataType from a string, Python type or torch dtype.

    Strings are taken verbatim; Python/torch dtypes are translated through
    the conversion table. The resulting object caches its tvm_ffi dtype
    handle in ``__tvm_ffi_dtype__``.

    Raises
    ------
    TypeError
        If *value* is neither a string nor a known convertible dtype.
    """
    if isinstance(value, str):
        val = str.__new__(cls, value)
    elif value in _dtype_py2tvmstr:
        val = str.__new__(cls, _dtype_py2tvmstr[value])
    else:
        # BUG FIX: the "expected" hint previously listed the FFI *callables*
        # (_dtype_tvmstr2fficall.values()); list the accepted dtype names
        # (the keys) instead so the error message shows usable inputs.
        expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.keys()))
        raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}")
    val.__tvm_ffi_dtype__ = tvm_ffi.core.DataType(val)
    return val
# Monkey-patch tvm.DataType so dtype objects compare against python/torch
# dtypes and are callable as cast/var constructors.
dtype.__eq__ = __dtype_eq__
# NOTE(review): __req__/__rne__ are not standard Python dunders (reflected
# equality goes through __eq__); presumably kept for a tvm dispatch hook —
# confirm they are actually consulted.
dtype.__req__ = __dtype_eq__
dtype.__ne__ = __dtype_ne__
dtype.__rne__ = __dtype_ne__
dtype.__call__ = __dtype_call__
dtype.__new__ = __dtype_new__
def get_tvm_dtype(value: AnyDType) -> dtype:
    """Coerce *value* to a tvm dtype, passing through values that already are one."""
    return value if isinstance(value, (dtype, ir.Type)) else dtype(value)
if TYPE_CHECKING:
# yapf: disable
class bool(dtype): ...
class short(dtype): ...
class int(dtype): ...
class long(dtype): ...
class half(dtype): ...
class float(dtype): ...
class double(dtype): ...
class int8(dtype): ...
class int16(dtype): ...
class int32(dtype): ...
class int64(dtype): ...
class int8x4(dtype): ...
class int16x4(dtype): ...
class int32x4(dtype): ...
class int64x4(dtype): ...
class int8x8(dtype): ...
class int16x8(dtype): ...
class int32x8(dtype): ...
class int64x8(dtype): ...
class int8x16(dtype): ...
class int16x16(dtype): ...
class int32x16(dtype): ...
class int64x16(dtype): ...
class int8x32(dtype): ...
class int16x32(dtype): ...
class int32x32(dtype): ...
class int64x32(dtype): ...
class int8x64(dtype): ...
class int16x64(dtype): ...
class int32x64(dtype): ...
class int64x64(dtype): ...
class uint8(dtype): ...
class uint16(dtype): ...
class uint32(dtype): ...
class uint64(dtype): ...
class uint8x4(dtype): ...
class uint16x4(dtype): ...
class uint32x4(dtype): ...
class uint64x4(dtype): ...
class uint8x8(dtype): ...
class uint16x8(dtype): ...
class uint32x8(dtype): ...
class uint64x8(dtype): ...
class uint8x16(dtype): ...
class uint16x16(dtype): ...
class uint32x16(dtype): ...
class uint64x16(dtype): ...
class uint8x32(dtype): ...
class uint16x32(dtype): ...
class uint32x32(dtype): ...
class uint64x32(dtype): ...
class uint8x64(dtype): ...
class uint16x64(dtype): ...
class uint32x64(dtype): ...
class uint64x64(dtype): ...
class float16(dtype): ...
class float32(dtype): ...
class float64(dtype): ...
class float16x2(dtype): ...
class float32x2(dtype): ...
class float64x2(dtype): ...
class float16x4(dtype): ...
class float32x4(dtype): ...
class float64x4(dtype): ...
class float16x8(dtype): ...
class float32x8(dtype): ...
class float64x8(dtype): ...
class float16x16(dtype): ...
class float32x16(dtype): ...
class float64x16(dtype): ...
class float16x32(dtype): ...
class float32x32(dtype): ...
class float64x32(dtype): ...
class float16x64(dtype): ...
class float32x64(dtype): ...
class float64x64(dtype): ...
class float8_e3m4(dtype): ...
class float8_e3m4x2(dtype): ...
class float8_e3m4x4(dtype): ...
class float8_e3m4x8(dtype): ...
class float8_e3m4x16(dtype): ...
class float8_e3m4x32(dtype): ...
class float8_e3m4x64(dtype): ...
class float8_e4m3(dtype): ...
class float8_e4m3x2(dtype): ...
class float8_e4m3x4(dtype): ...
class float8_e4m3x8(dtype): ...
class float8_e4m3x16(dtype): ...
class float8_e4m3x32(dtype): ...
class float8_e4m3x64(dtype): ...
class float8_e4m3b11fnuz(dtype): ...
class float8_e4m3b11fnuzx2(dtype): ...
class float8_e4m3b11fnuzx4(dtype): ...
class float8_e4m3b11fnuzx8(dtype): ...
class float8_e4m3b11fnuzx16(dtype): ...
class float8_e4m3b11fnuzx32(dtype): ...
class float8_e4m3b11fnuzx64(dtype): ...
class float8_e4m3fn(dtype): ...
class float8_e4m3fnx2(dtype): ...
class float8_e4m3fnx4(dtype): ...
class float8_e4m3fnx8(dtype): ...
class float8_e4m3fnx16(dtype): ...
class float8_e4m3fnx32(dtype): ...
class float8_e4m3fnx64(dtype): ...
class float8_e4m3fnuz(dtype): ...
class float8_e4m3fnuzx2(dtype): ...
class float8_e4m3fnuzx4(dtype): ...
class float8_e4m3fnuzx8(dtype): ...
class float8_e4m3fnuzx16(dtype): ...
class float8_e4m3fnuzx32(dtype): ...
class float8_e4m3fnuzx64(dtype): ...
class float8_e5m2(dtype): ...
class float8_e5m2x2(dtype): ...
class float8_e5m2x4(dtype): ...
class float8_e5m2x8(dtype): ...
class float8_e5m2x16(dtype): ...
class float8_e5m2x32(dtype): ...
class float8_e5m2x64(dtype): ...
class float8_e5m2fnuz(dtype): ...
class float8_e5m2fnuzx2(dtype): ...
class float8_e5m2fnuzx4(dtype): ...
class float8_e5m2fnuzx8(dtype): ...
class float8_e5m2fnuzx16(dtype): ...
class float8_e5m2fnuzx32(dtype): ...
class float8_e5m2fnuzx64(dtype): ...
class float8_e8m0fnu(dtype): ...
class float8_e8m0fnux2(dtype): ...
class float8_e8m0fnux4(dtype): ...
class float8_e8m0fnux8(dtype): ...
class float8_e8m0fnux16(dtype): ...
class float8_e8m0fnux32(dtype): ...
class float8_e8m0fnux64(dtype): ...
class float6_e2m3fn(dtype): ...
class float6_e2m3fnx2(dtype): ...
class float6_e2m3fnx4(dtype): ...
class float6_e2m3fnx8(dtype): ...
class float6_e2m3fnx16(dtype): ...
class float6_e2m3fnx32(dtype): ...
class float6_e2m3fnx64(dtype): ...
class float6_e3m2fn(dtype): ...
class float6_e3m2fnx2(dtype): ...
class float6_e3m2fnx4(dtype): ...
class float6_e3m2fnx8(dtype): ...
class float6_e3m2fnx16(dtype): ...
class float6_e3m2fnx32(dtype): ...
class float6_e3m2fnx64(dtype): ...
class float4_e2m1fn(dtype): ...
class float4_e2m1fnx2(dtype): ...
class float4_e2m1fnx4(dtype): ...
class float4_e2m1fnx8(dtype): ...
class float4_e2m1fnx16(dtype): ...
class float4_e2m1fnx32(dtype): ...
class float4_e2m1fnx64(dtype): ...
class bfloat16(dtype): ...
# yapf: enable
else:
bool = dtype('bool')
short = dtype('int16')
int = dtype('int32')
long = dtype('int64')
half = dtype('float16')
float = dtype('float32')
double = dtype('float64')
int8 = dtype('int8')
int16 = dtype('int16')
int32 = dtype('int32')
int64 = dtype('int64')
int8x4 = dtype('int8x4')
int16x4 = dtype('int16x4')
int32x4 = dtype('int32x4')
int64x4 = dtype('int64x4')
int8x8 = dtype('int8x8')
int16x8 = dtype('int16x8')
int32x8 = dtype('int32x8')
int64x8 = dtype('int64x8')
int8x16 = dtype('int8x16')
int16x16 = dtype('int16x16')
int32x16 = dtype('int32x16')
int64x16 = dtype('int64x16')
int8x32 = dtype('int8x32')
int16x32 = dtype('int16x32')
int32x32 = dtype('int32x32')
int64x32 = dtype('int64x32')
int8x64 = dtype('int8x64')
int16x64 = dtype('int16x64')
int32x64 = dtype('int32x64')
int64x64 = dtype('int64x64')
uint8 = dtype('uint8')
uint16 = dtype('uint16')
uint32 = dtype('uint32')
uint64 = dtype('uint64')
uint8x4 = dtype('uint8x4')
uint16x4 = dtype('uint16x4')
uint32x4 = dtype('uint32x4')
uint64x4 = dtype('uint64x4')
uint8x8 = dtype('uint8x8')
uint16x8 = dtype('uint16x8')
uint32x8 = dtype('uint32x8')
uint64x8 = dtype('uint64x8')
uint8x16 = dtype('uint8x16')
uint16x16 = dtype('uint16x16')
uint32x16 = dtype('uint32x16')
uint64x16 = dtype('uint64x16')
uint8x32 = dtype('uint8x32')
uint16x32 = dtype('uint16x32')
uint32x32 = dtype('uint32x32')
uint64x32 = dtype('uint64x32')
uint8x64 = dtype('uint8x64')
uint16x64 = dtype('uint16x64')
uint32x64 = dtype('uint32x64')
uint64x64 = dtype('uint64x64')
float16 = dtype('float16')
float32 = dtype('float32')
float64 = dtype('float64')
float16x2 = dtype('float16x2')
float32x2 = dtype('float32x2')
float64x2 = dtype('float64x2')
float16x4 = dtype('float16x4')
float32x4 = dtype('float32x4')
float64x4 = dtype('float64x4')
float16x8 = dtype('float16x8')
float32x8 = dtype('float32x8')
float64x8 = dtype('float64x8')
float16x16 = dtype('float16x16')
float32x16 = dtype('float32x16')
float64x16 = dtype('float64x16')
float16x32 = dtype('float16x32')
float32x32 = dtype('float32x32')
float64x32 = dtype('float64x32')
float16x64 = dtype('float16x64')
float32x64 = dtype('float32x64')
float64x64 = dtype('float64x64')
float8_e3m4 = dtype('float8_e3m4')
float8_e3m4x2 = dtype('float8_e3m4x2')
float8_e3m4x4 = dtype('float8_e3m4x4')
float8_e3m4x8 = dtype('float8_e3m4x8')
float8_e3m4x16 = dtype('float8_e3m4x16')
float8_e3m4x32 = dtype('float8_e3m4x32')
float8_e3m4x64 = dtype('float8_e3m4x64')
float8_e4m3 = dtype('float8_e4m3')
float8_e4m3x2 = dtype('float8_e4m3x2')
float8_e4m3x4 = dtype('float8_e4m3x4')
float8_e4m3x8 = dtype('float8_e4m3x8')
float8_e4m3x16 = dtype('float8_e4m3x16')
float8_e4m3x32 = dtype('float8_e4m3x32')
float8_e4m3x64 = dtype('float8_e4m3x64')
float8_e4m3b11fnuz = dtype('float8_e4m3b11fnuz')
float8_e4m3b11fnuzx2 = dtype('float8_e4m3b11fnuzx2')
float8_e4m3b11fnuzx4 = dtype('float8_e4m3b11fnuzx4')
float8_e4m3b11fnuzx8 = dtype('float8_e4m3b11fnuzx8')
float8_e4m3b11fnuzx16 = dtype('float8_e4m3b11fnuzx16')
float8_e4m3b11fnuzx32 = dtype('float8_e4m3b11fnuzx32')
float8_e4m3b11fnuzx64 = dtype('float8_e4m3b11fnuzx64')
float8_e4m3fn = dtype('float8_e4m3fn')
float8_e4m3fnx2 = dtype('float8_e4m3fnx2')
float8_e4m3fnx4 = dtype('float8_e4m3fnx4')
float8_e4m3fnx8 = dtype('float8_e4m3fnx8')
float8_e4m3fnx16 = dtype('float8_e4m3fnx16')
float8_e4m3fnx32 = dtype('float8_e4m3fnx32')
float8_e4m3fnx64 = dtype('float8_e4m3fnx64')
float8_e4m3fnuz = dtype('float8_e4m3fnuz')
float8_e4m3fnuzx2 = dtype('float8_e4m3fnuzx2')
float8_e4m3fnuzx4 = dtype('float8_e4m3fnuzx4')
float8_e4m3fnuzx8 = dtype('float8_e4m3fnuzx8')
float8_e4m3fnuzx16 = dtype('float8_e4m3fnuzx16')
float8_e4m3fnuzx32 = dtype('float8_e4m3fnuzx32')
float8_e4m3fnuzx64 = dtype('float8_e4m3fnuzx64')
float8_e5m2 = dtype('float8_e5m2')
float8_e5m2x2 = dtype('float8_e5m2x2')
float8_e5m2x4 = dtype('float8_e5m2x4')
float8_e5m2x8 = dtype('float8_e5m2x8')
float8_e5m2x16 = dtype('float8_e5m2x16')
float8_e5m2x32 = dtype('float8_e5m2x32')
float8_e5m2x64 = dtype('float8_e5m2x64')
float8_e5m2fnuz = dtype('float8_e5m2fnuz')
float8_e5m2fnuzx2 = dtype('float8_e5m2fnuzx2')
float8_e5m2fnuzx4 = dtype('float8_e5m2fnuzx4')
float8_e5m2fnuzx8 = dtype('float8_e5m2fnuzx8')
float8_e5m2fnuzx16 = dtype('float8_e5m2fnuzx16')
float8_e5m2fnuzx32 = dtype('float8_e5m2fnuzx32')
float8_e5m2fnuzx64 = dtype('float8_e5m2fnuzx64')
float8_e8m0fnu = dtype('float8_e8m0fnu')
float8_e8m0fnux2 = dtype('float8_e8m0fnux2')
float8_e8m0fnux4 = dtype('float8_e8m0fnux4')
float8_e8m0fnux8 = dtype('float8_e8m0fnux8')
float8_e8m0fnux16 = dtype('float8_e8m0fnux16')
float8_e8m0fnux32 = dtype('float8_e8m0fnux32')
float8_e8m0fnux64 = dtype('float8_e8m0fnux64')
float6_e2m3fn = dtype('float6_e2m3fn')
float6_e2m3fnx2 = dtype('float6_e2m3fnx2')
float6_e2m3fnx4 = dtype('float6_e2m3fnx4')
float6_e2m3fnx8 = dtype('float6_e2m3fnx8')
float6_e2m3fnx16 = dtype('float6_e2m3fnx16')
float6_e2m3fnx32 = dtype('float6_e2m3fnx32')
float6_e2m3fnx64 = dtype('float6_e2m3fnx64')
float6_e3m2fn = dtype('float6_e3m2fn')
float6_e3m2fnx2 = dtype('float6_e3m2fnx2')
float6_e3m2fnx4 = dtype('float6_e3m2fnx4')
float6_e3m2fnx8 = dtype('float6_e3m2fnx8')
float6_e3m2fnx16 = dtype('float6_e3m2fnx16')
float6_e3m2fnx32 = dtype('float6_e3m2fnx32')
float6_e3m2fnx64 = dtype('float6_e3m2fnx64')
float4_e2m1fn = dtype('float4_e2m1fn')
float4_e2m1fnx2 = dtype('float4_e2m1fnx2')
float4_e2m1fnx4 = dtype('float4_e2m1fnx4')
float4_e2m1fnx8 = dtype('float4_e2m1fnx8')
float4_e2m1fnx16 = dtype('float4_e2m1fnx16')
float4_e2m1fnx32 = dtype('float4_e2m1fnx32')
float4_e2m1fnx64 = dtype('float4_e2m1fnx64')
bfloat16 = dtype('bfloat16')
_all_dtypes = {
'bool',
'short',
'int',
'long',
'half',
'float',
'double',
'int8',
'int16',
'int32',
'int64',
'int8x4',
'int16x4',
'int32x4',
'int64x4',
'int8x8',
'int16x8',
'int32x8',
'int64x8',
'int8x16',
'int16x16',
'int32x16',
'int64x16',
'int8x32',
'int16x32',
'int32x32',
'int64x32',
'int8x64',
'int16x64',
'int32x64',
'int64x64',
'uint8',
'uint16',
'uint32',
'uint64',
'uint8x4',
'uint16x4',
'uint32x4',
'uint64x4',
'uint8x8',
'uint16x8',
'uint32x8',
'uint64x8',
'uint8x16',
'uint16x16',
'uint32x16',
'uint64x16',
'uint8x32',
'uint16x32',
'uint32x32',
'uint64x32',
'uint8x64',
'uint16x64',
'uint32x64',
'uint64x64',
'float16',
'float32',
'float64',
'float16x2',
'float32x2',
'float64x2',
'float16x4',
'float32x4',
'float64x4',
'float16x8',
'float32x8',
'float64x8',
'float16x16',
'float32x16',
'float64x16',
'float16x32',
'float32x32',
'float64x32',
'float16x64',
'float32x64',
'float64x64',
'float8_e3m4',
'float8_e3m4x2',
'float8_e3m4x4',
'float8_e3m4x8',
'float8_e3m4x16',
'float8_e3m4x32',
'float8_e3m4x64',
'float8_e4m3',
'float8_e4m3x2',
'float8_e4m3x4',
'float8_e4m3x8',
'float8_e4m3x16',
'float8_e4m3x32',
'float8_e4m3x64',
'float8_e4m3b11fnuz',
'float8_e4m3b11fnuzx2',
'float8_e4m3b11fnuzx4',
'float8_e4m3b11fnuzx8',
'float8_e4m3b11fnuzx16',
'float8_e4m3b11fnuzx32',
'float8_e4m3b11fnuzx64',
'float8_e4m3fn',
'float8_e4m3fnx2',
'float8_e4m3fnx4',
'float8_e4m3fnx8',
'float8_e4m3fnx16',
'float8_e4m3fnx32',
'float8_e4m3fnx64',
'float8_e4m3fnuz',
'float8_e4m3fnuzx2',
'float8_e4m3fnuzx4',
'float8_e4m3fnuzx8',
'float8_e4m3fnuzx16',
'float8_e4m3fnuzx32',
'float8_e4m3fnuzx64',
'float8_e5m2',
'float8_e5m2x2',
'float8_e5m2x4',
'float8_e5m2x8',
'float8_e5m2x16',
'float8_e5m2x32',
'float8_e5m2x64',
'float8_e5m2fnuz',
'float8_e5m2fnuzx2',
'float8_e5m2fnuzx4',
'float8_e5m2fnuzx8',
'float8_e5m2fnuzx16',
'float8_e5m2fnuzx32',
'float8_e5m2fnuzx64',
'float8_e8m0fnu',
'float8_e8m0fnux2',
'float8_e8m0fnux4',
'float8_e8m0fnux8',
'float8_e8m0fnux16',
'float8_e8m0fnux32',
'float8_e8m0fnux64',
'float6_e2m3fn',
'float6_e2m3fnx2',
'float6_e2m3fnx4',
'float6_e2m3fnx8',
'float6_e2m3fnx16',
'float6_e2m3fnx32',
'float6_e2m3fnx64',
'float6_e3m2fn',
'float6_e3m2fnx2',
'float6_e3m2fnx4',
'float6_e3m2fnx8',
'float6_e3m2fnx16',
'float6_e3m2fnx32',
'float6_e3m2fnx64',
'float4_e2m1fn',
'float4_e2m1fnx2',
'float4_e2m1fnx4',
'float4_e2m1fnx8',
'float4_e2m1fnx16',
'float4_e2m1fnx32',
'float4_e2m1fnx64',
'bfloat16',
}
__all__ = list(_all_dtypes) + [
'dtype',
'AnyDType',
'get_tvm_dtype',
]
from __future__ import annotations
import ast
import inspect
from typing import Any, Callable, Literal
from tilelang import env
from hashlib import sha256
import linecache
def disk_compile(source, name):
    """Compile ``source`` to a code object, persisting it to the tilelang
    cache so tracebacks and ``linecache`` can resolve its lines.

    Parameters
    ----------
    source : str
        Python source code to compile (``exec`` mode).
    name : str
        Logical name used to build the cached file name.

    Returns
    -------
    code
        The compiled code object.
    """
    import os

    cache_dir = env.TILELANG_CACHE_DIR
    if cache_dir is None:
        # No on-disk cache configured: compile in memory with a synthetic
        # filename. (Previously this path fell through and returned None,
        # which made callers crash on `exec(None, ...)`.)
        return compile(source, f"<tilelang-{name}>", "exec")
    save_dir = os.path.join(cache_dir, "py-cache")
    os.makedirs(save_dir, exist_ok=True)
    # Content hash keeps distinct sources for the same logical name apart.
    hash_sfx = sha256(source.encode('utf-8')).hexdigest()[:8]
    path = os.path.join(save_dir, f"{name}.{hash_sfx}.py")
    # Write with the same encoding the hash was computed over.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(source)
    # Seed linecache so inspect/traceback can show the source immediately.
    # linecache stores lines WITH their trailing newlines (readlines format),
    # so keepends=True is required for getline()/getsource() to work right.
    linecache.cache[path] = (len(source), None, source.splitlines(keepends=True), path)
    return compile(source, path, "exec")
def _remove_leading_ident(source: str):
lines = source.splitlines()
if not lines:
return source
ident_size = len(lines[0]) - len(lines[0].lstrip())
return "\n".join([line[ident_size:] if len(line) >= ident_size else line for line in lines])
def get_func_nonlocals(func):
    """Return the nonlocal (closure) variables of *func* as a dict.

    A trimmed-down variant of ``inspect.getclosurevars`` that reports only
    the nonlocal bindings; empty closure cells are silently skipped.

    Raises
    ------
    TypeError
        If *func* is not a plain Python function (bound methods are
        unwrapped to their underlying function first).
    """
    if inspect.ismethod(func):
        func = func.__func__
    if not inspect.isfunction(func):
        raise TypeError(f"{func!r} is not a Python function")
    # Free variable names in co_freevars pair positionally with the cells
    # in __closure__ (which is None for closure-free functions).
    cells = func.__closure__ or ()
    result = {}
    for var_name, cell in zip(func.__code__.co_freevars, cells):
        try:
            result[var_name] = cell.cell_contents
        except ValueError as err:
            # An unfilled cell raises ValueError("... is empty"); anything
            # else is unexpected and must propagate.
            if "empty" not in str(err):
                raise
    return result
def inspect_function_capture(func: Callable) -> dict[str, Any]:
    """Capture every name *func* can resolve from globals and its closure.

    Parameters
    ----------
    func : Callable
        The function to inspect.

    Returns
    -------
    dict[str, Any]
        The function's global bindings overlaid with its nonlocal bindings
        (nonlocals win on name clashes).
    """
    captured = dict(func.__globals__)  # type: ignore
    captured.update(get_func_nonlocals(func))
    return captured
def get_ast(func: Callable):
    """Parse *func*'s source into an AST whose line numbers match its file.

    The source is dedented (so nested/method functions parse as top-level
    code) and padded with blank lines so node ``lineno`` values line up
    with the function's position in the original file.
    """
    _, first_lineno = inspect.getsourcelines(func)
    filename = inspect.getsourcefile(func) or inspect.getfile(func)
    dedented = _remove_leading_ident(inspect.getsource(func))
    # Blank-line padding keeps error messages/debuggers pointing at the
    # right lines of the original file.
    padded = '\n' * (first_lineno - 1) + dedented
    return ast.parse(padded, filename=filename)
# Compilation strategy selector. Presumably 'direct' means in-memory
# compile() and 'disk' means the cached disk_compile() path — the alias is
# not consumed in this file's visible portion; confirm against callers.
CompileMethod = Literal['direct', 'disk']
def get_compiled_object(source: str | ast.AST,
                        name: str,
                        filename: str | None = None,
                        globals: dict[str, Any] | None = None):
    """Compile *source*, execute it, and return the object bound to *name*.

    Parameters
    ----------
    source : str | ast.AST
        Source text or an already-built AST. AST input is compiled directly
        (after ``fix_missing_locations``); text goes through
        ``disk_compile`` so it is cached on disk.
    name : str
        Name of the definition to extract after executing the code.
    filename : str | None
        Required when *source* is an AST (used as the code object's
        filename); unused for text input.
    globals : dict[str, Any] | None
        Global namespace for the ``exec``. NOTE: this parameter shadows the
        ``globals`` builtin; kept as-is for backward compatibility.

    Raises
    ------
    RuntimeError
        If compilation fails — chained from the original error, with the
        offending source included in the message.
    KeyError
        If executing the compiled code does not define *name*.
    """
    # Evaluate the AST-vs-text decision once (the original checked twice).
    is_tree = isinstance(source, ast.AST)
    if is_tree:
        assert filename is not None, "filename must be provided when source is an AST"
    try:
        if is_tree:
            ast.fix_missing_locations(source)
            compiled = compile(source, filename, 'exec')
        else:
            compiled = disk_compile(source, name)
    except Exception as e:
        source_str = source if isinstance(source, str) else ast.unparse(source)
        raise RuntimeError(f'Failed to compile source for {name}, Error: {e}:\n{source_str}') from e
    # Definitions land in a separate locals mapping so the caller-supplied
    # globals dict is not polluted with every top-level name.
    locs = {}
    exec(compiled, globals, locs)
    return locs[name]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment