"docs/en_US/Tutorial/Installation.md" did not exist on "a441558c7b79fa0feaf4868b4b8fa1d66b4120c1"
Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import tvm
from tvm.script.ir_builder.base import IRBuilderFrame
from tvm.tir.expr import IntImm, Var
def test_argument():
@T.prim_func
def test_argument(
t_1: T.bool,
t_2: T.short,
t_3: T.int,
t_4: T.long,
t_5: T.half,
t_6: T.float,
t_7: T.long,
t_8: T.int8,
t_9: T.int16,
t_10: T.int32,
t_11: T.int64,
t_12: T.uint8,
t_13: T.uint16,
t_14: T.uint32,
t_15: T.uint64,
t_16: T.float8_e4m3fn,
t_17: T.float8_e4m3fnuz,
t_18: T.float8_e5m2,
t_19: T.float8_e5m2fnuz,
t_20: T.float8_e8m0fnu,
t_21: T.float16,
t_22: T.bfloat16,
t_23: T.float32,
t_24: T.float64,
):
pass
def test_expr():
from tilelang.language.v2.dtypes import _all_dtypes
errors = []
for name in _all_dtypes:
dtype = getattr(T, name)
assert isinstance(dtype, tvm.DataType), f"{dtype} is not tvm.DataType"
try:
dtype(1.0)
dtype()
except TypeError:
pass
except Exception:
errors.append(name)
assert not errors
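# Illustrative sketch (not part of this change; helper name is hypothetical and
# never invoked): test_expr above asserts that every exported dtype is a
# tvm.DataType and tolerates only TypeError when it is called. Under that
# assumption, a dtype object can act as a typed-constant constructor:
def _example_dtype_constant():
    # Assumes T.float32(1.0) yields a float32-typed scalar, as exercised above.
    return T.float32(1.0)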
# def test_var_decl_sugar():
# @T.prim_func
# def test_var_decl_sugar():
# with T.Kernel(128, 128) as (bx, by):
# var_1: T.bool = 1.0
# var_2: T.short = 1.0
# var_3: T.int = 1.0
# var_4: T.long = 1.0
# var_5: T.half = 1.0
# var_6: T.float = 1.0
# var_7: T.long = 1.0
# var_8: T.int8 = 1.0
# var_9: T.int16 = 1.0
# var_10: T.int32 = 1.0
# var_11: T.int64 = 1.0
# var_12: T.uint8 = 1.0
# var_13: T.uint16 = 1.0
# var_14: T.uint32 = 1.0
# var_15: T.uint64 = 1.0
# var_16: T.float8_e4m3fn = 1.0
# var_17: T.float8_e4m3fnuz = 1.0
# var_18: T.float8_e5m2 = 1.0
# var_19: T.float8_e5m2fnuz = 1.0
# var_20: T.float8_e8m0fnu = 1.0
# var_21: T.float16 = 1.0
# var_22: T.bfloat16 = 1.0
# var_23: T.float32 = 1.0
# var_24: T.float64 = 1.0
# var_1: T.bool = var_1
# var_2: T.short = var_2
# var_3: T.int = var_3
# var_4: T.long = var_4
# var_5: T.half = var_5
# var_6: T.float = var_6
# var_7: T.long = var_7
# var_8: T.int8 = var_8
# var_9: T.int16 = var_9
# var_10: T.int32 = var_10
# var_11: T.int64 = var_11
# var_12: T.uint8 = var_12
# var_13: T.uint16 = var_13
# var_14: T.uint32 = var_14
# var_15: T.uint64 = var_15
# var_16: T.float8_e4m3fn = var_16
# var_17: T.float8_e4m3fnuz = var_17
# var_18: T.float8_e5m2 = var_18
# var_19: T.float8_e5m2fnuz = var_19
# var_20: T.float8_e8m0fnu = var_20
# var_21: T.float16 = var_21
# var_22: T.bfloat16 = var_22
# var_23: T.float32 = var_23
# var_24: T.float64 = var_24
# s = test_var_decl_sugar.script()
# for i in range(1, 25):
# assert f'var_{i}_1' in s
# assert 'tl.local_var_init' in s
def test_dtype_str_repr():
@T.prim_func
def test_str_repr():
buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') # noqa F841
buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') # noqa F841
buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') # noqa F841
buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') # noqa F841
buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') # noqa F841
buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') # noqa F841
buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') # noqa F841
buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') # noqa F841
buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') # noqa F841
buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') # noqa F841
buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') # noqa F841
buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') # noqa F841
buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') # noqa F841
buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') # noqa F841
buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') # noqa F841
buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') # noqa F841
buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') # noqa F841
buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') # noqa F841
buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') # noqa F841
buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') # noqa F841
buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') # noqa F841
buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841
def test_torch_eq():
dtypes = [
T.bool,
T.short,
T.int,
T.long,
T.half,
T.float,
T.long,
T.int8,
T.int16,
T.int32,
T.int64,
T.uint8,
T.uint16,
T.uint32,
T.uint64,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.float8_e8m0fnu,
T.float16,
T.bfloat16,
T.float32,
T.float64,
]
torch_dtypes = [
torch.bool,
torch.short,
torch.int,
torch.long,
torch.half,
torch.float,
torch.long,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]
for a, b in zip(dtypes, torch_dtypes):
assert a == b, f"{a} and {b} are not equal"
assert T.dtype(b) == a, "dtype conversion error"
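# Usage sketch grounded in the assertions above (illustrative helper, not
# called by the tests): T.dtype(...) converts a torch dtype into the matching
# tilelang dtype, and the two compare equal directly.
def _example_torch_dtype_roundtrip():
    tl_half = T.dtype(torch.float16)  # conversion checked by test_torch_eq
    assert tl_half == T.float16
    assert T.float16 == torch.float16  # cross-library equality also holds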
def test_var_assign():
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_var_assign(A: T.Tensor((2,), T.int32)):
with T.Kernel(1) as _:
a: T.int32 = 1
b: T.int32 = a
a = 2
d: T.int32 = a
A[0] = b
A[1] = d
res = test_var_assign()()
assert res[0] == 1
assert res[1] == 2
def test_macro_return():
@T.macro
def macro_return_constant():
return 0
@T.macro
def macro_return_frame(x):
return T.alloc_var(T.float32, init=x)
@T.macro
def macro_return_expr(x):
y = x + 1.0
return y
@T.macro
def macro_apply_func(x, fn):
return fn(x)
def check(x, ty):
assert isinstance(x, ty)
@T.prim_func
def test_macro_return():
with T.Kernel(1) as _:
a = macro_return_constant()
b = macro_return_frame(3.0)
c = macro_return_expr(4.0)
d = macro_apply_func(5.0, lambda x: x * 2.0)
check(a, (int, float, T.PrimExpr))
check(b, T.PrimExpr)
check(c, T.PrimExpr)
check(d, T.PrimExpr)
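# The checks above pin down the macro return conventions: a Python constant
# returned from a T.macro may stay a plain int/float, while frames (alloc_var)
# and computed expressions are materialised as PrimExprs once the macro is
# inlined into the enclosing prim_func.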
def test_prim_func_generator():
@T.prim_func(generator=True)
def prim_func_gen(
A=T.Tensor((128,), T.float32), # noqa: B008
B=T.Tensor((128,), T.float32), # noqa: B008
):
with T.Kernel(128) as (tx,):
T.copy(A[tx], B[tx])
prim_func_gen()
@T.prim_func
def foo() -> T.Tensor((128,), T.float32):
pass
assert isinstance(foo, T.PrimFunc)
def test_serial_for_with_step():
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_stepped_serial(A: T.Tensor((10,), T.int32)):
with T.Kernel(1) as _:
for i in range(0, 10, 2):
T.device_assert(0 <= i < 10 and i % 2 == 0, "i out of range")
A[i] = 1.0
for i in range(1, 10, 2):
T.device_assert(1 <= i < 10 and i % 2 == 1, "i out of range")
A[i] = 2.0
ker = test_stepped_serial()
res = ker()
ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda')
assert torch.all(res == ref), f"Expected {ref}, but got {res}"
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_serial_step_neg(A: T.Tensor((10,), T.int32)):
with T.Kernel(1) as _:
for i in range(10, 0, -1):
T.device_assert(0 < i <= 10, "i out of range")
A[10 - i] = i
ker = test_serial_step_neg()
res = ker()
ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda')
assert torch.all(res == ref), f"Expected {ref}, but got {res}"
assert isinstance(T.serial(1, 10, 1), IRBuilderFrame)
assert isinstance(T.serial(1, 10, IntImm('int32', 1)), IRBuilderFrame)
assert not isinstance(T.serial(1, 10, Var('tmp', 'int32')), IRBuilderFrame)
assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame)
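# Summary of the frame checks above: T.serial builds an IRBuilderFrame only
# when the step is a positive compile-time constant (a Python int or IntImm);
# a symbolic Var step or a negative step falls back to a non-frame lowering.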
def test_swap_logic():
@tilelang.jit
@T.prim_func
def swap_var(A: T.Tensor[(2,), T.float32]):
with T.Kernel(1, threads=1) as _:
a = T.alloc_var(T.float32, A[0])
b = T.alloc_var(T.float32, A[1])
a, b = b, a
A[0], A[1] = a, b
@tilelang.jit
@T.prim_func
def swap_idx(A: T.Tensor[(2,), T.float32]):
with T.Kernel(1, threads=1) as _:
A[0], A[1] = A[1], A[0]
k_swap_var = swap_var()
data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
k_swap_var(data)
ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
torch.testing.assert_close(data, ref)
k_swap_idx = swap_idx()
data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
k_swap_idx(data)
ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
torch.testing.assert_close(data, ref)
def test_while_loop():
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_while_loop(A: T.Tensor((1,), T.int32)):
with T.Kernel(1) as _:
i = T.alloc_var(T.int32, 0)
sum = T.alloc_var(T.int32)
while i < 10:
sum += i
i += 1
A[0] = sum
ker = test_while_loop()
A = ker()
assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}"
if __name__ == '__main__':
tilelang.testing.main()
......@@ -209,4 +209,3 @@ def test_shuffle_elect_block_leader():
if __name__ == "__main__":
tilelang.testing.main()
# run_get_lane_id()
import torch
import tilelang
import tilelang.language as T
@tilelang.jit(out_idx=-1)
def get_inf_kernel(dtype: str):
@T.prim_func
def main(A: T.Tensor((32,), dtype)):
with T.Kernel(1, threads=32):
T.fill(A, T.infinity(dtype))
return main
def _test_infinity(dtype: str):
kernel = get_inf_kernel(dtype)
output = kernel()
assert torch.all(output == torch.inf), f'check failed for {dtype=}'
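# Grounded in the kernel above: T.fill(A, T.infinity(dtype)) writes +inf of the
# requested dtype into every element, so the host-side comparison against
# torch.inf holds for each supported floating-point type.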
@tilelang.testing.requires_cuda
def test_infinity():
_test_infinity("float16")
_test_infinity("bfloat16")
_test_infinity("float32")
_test_infinity("float64")
if __name__ == "__main__":
tilelang.testing.main()
import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T
def test_let_vectorize_load():
@T.prim_func
def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
b = A[0, 0:4]
A[0, 4:8] = b
mod = tvm.IRModule({"main": main})
mod = tvm.compile(mod, target="cuda")
assert "float4 b" in mod.mod.imports[0].inspect_source()
if __name__ == "__main__":
tilelang.testing.main()
from tilelang import tvm
import tilelang as tl
import tilelang.testing
from tvm.script import tir as T
@T.prim_func
def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
T.func_attr({"tir.noalias": True})
B[0] = A[T.int32(-1)]
@T.prim_func
def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
T.func_attr({"tir.noalias": True})
B[0] = A[T.int32(15)]
@T.prim_func
def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
T.func_attr({"tir.noalias": True})
for i in T.serial(4):
B[i] = A[-i - 1]
@T.prim_func
def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
T.func_attr({"tir.noalias": True})
for i in T.serial(4):
B[i] = A[15 - i]
@T.prim_func
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"),
B: T.Buffer((16,), "float32")):
T.func_attr({"tir.noalias": True})
for i in T.serial(16):
B[i] = A[shift + i]
def test_legalize_negative_index_scalar():
mod = tvm.IRModule({"main": negative_index_before})
transformed = tl.transform.LegalizeNegativeIndex()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_expected.body)
def test_legalize_negative_index_affine_expr():
mod = tvm.IRModule({"main": negative_index_loop_before})
transformed = tl.transform.LegalizeNegativeIndex()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_loop_expected.body)
def test_legalize_negative_index_symbolic_passthrough():
mod = tvm.IRModule({"main": negative_index_symbolic_before})
transformed = tl.transform.LegalizeNegativeIndex()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_symbolic_before.body)
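# Taken together, the three cases above document LegalizeNegativeIndex:
# constant negative indices (A[-1] -> A[15]) and affine negative indices
# (A[-i - 1] -> A[15 - i]) are rewritten into their in-bounds equivalents,
# while indices that are merely symbolic (A[shift + i]) pass through unchanged.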
if __name__ == "__main__":
tilelang.testing.main()
......@@ -116,7 +116,6 @@ def test_reduce_sum():
def test_reduce_sum_shared():
run_reduce_sum(64, 64, mode="ss")
run_reduce_sum(32, 96, mode="ss")
def test_reduce_max():
......@@ -127,7 +126,6 @@ def test_reduce_max():
def test_reduce_max_shared():
run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, "float32")
run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 96, 48, "float32")
def test_reduce_min_shared():
......
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import torch
def reshape_test(N, M, dtype):
......@@ -129,5 +130,137 @@ def test_reshape_smem_2d_2_1d():
run_reshape_smem_2d_2_1d(2048, 64, "float16")
def reshape_fragment_test(N, M, dtype):
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor((N // M, M), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
A_local = T.alloc_fragment((N // M, M), dtype)
B_shared = T.alloc_shared((N,), dtype, scope="shared")
T.copy(A, A_shared)
T.copy(A_shared, A_local)
A_local_reshape = T.reshape(A_local, [N])
T.copy(A_local_reshape, B_shared)
T.copy(B_shared, B)
return main
def run_reshape_fragment(N, M, dtype):
program = reshape_fragment_test(N, M, dtype)
jit_kernel = tl.compile(
program,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = jit_kernel.get_profiler()
def ref_program(A):
return A.reshape(N)
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reshape_fragment():
run_reshape_fragment(1024, 32, "float32")
run_reshape_fragment(2048, 64, "float16")
def reshape_layout_transform_shared(N, M, dtype):
import tilelang.language as T
from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout
@T.prim_func
def main(
A: T.Tensor((N // M, M), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
T.annotate_layout({
A_shared: make_mma_swizzle_layout(A_shared),
})
T.copy(A, A_shared)
A_shared_reshape = T.reshape(A_shared, [N])
T.copy(A_shared_reshape, B)
return main
def run_reshape_layout_transform_shared(N, M, dtype):
program = reshape_layout_transform_shared(N, M, dtype)
jit_kernel = tl.compile(
program,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = jit_kernel.get_profiler()
def ref_program(A):
return A.reshape(N)
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reshape_layout_transform_shared():
run_reshape_layout_transform_shared(1024, 32, "float32")
run_reshape_layout_transform_shared(2048, 64, "float16")
def reduce_after_reshape_test(N, M, dtype):
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_shared = T.alloc_shared((N,), dtype, scope="shared")
A_local = T.alloc_fragment((N,), dtype)
B_local = T.alloc_fragment((N // M,), dtype)
T.copy(A, A_shared)
T.copy(A_shared, A_local)
A_local_reshape = T.reshape(A_local, [N // M, M])
T.reduce_max(A_local_reshape, B_local, dim=1)
T.copy(B_local, B)
return main
def run_reduce_after_reshape(N, M, dtype):
program = reduce_after_reshape_test(N, M, dtype)
jit_kernel = tl.compile(
program,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = jit_kernel.get_profiler()
def ref_program(A):
return torch.max(A.reshape(N // M, M), dim=1).values
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reduce_after_reshape():
run_reduce_after_reshape(1024, 32, "float32")
run_reduce_after_reshape(2048, 64, "float16")
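# Reference sketch for the reduce-after-reshape path above (hypothetical helper
# mirroring ref_program, assuming plain PyTorch semantics for the host check):
def _example_reduce_after_reshape_reference(A, M):
    # Reshape the flat vector into (N // M, M) rows and take the row-wise max,
    # which is what T.reshape + T.reduce_max(dim=1) computes on device.
    return torch.max(A.reshape(-1, M), dim=1).values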
if __name__ == "__main__":
tilelang.testing.main()
import argparse
import torch
import tilelang
import tilelang.language as T
import torch
def ref_program(x, y):
......@@ -30,23 +29,29 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
return elem_add
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=128)
parser.add_argument("--n", type=int, default=128)
args, _ = parser.parse_known_args()
M, N = args.m, args.n
def run_elementwise_add(M, N):
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
b = torch.randn(M, N, dtype=torch.float32, device="cuda")
# Default config
config = {"block_M": 128, "block_N": 128, "threads": 128}
block_M, block_N = 128, 128
config = {"block_M": block_M, "block_N": block_N, "threads": 128}
kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
print("All passed!")
code = kernel.get_kernel_source()
if block_N == N:
assert "tma_load" in code and "CUtensorMap" not in code
else:
assert "tma_load" in code and "CUtensorMap" in code
def main():
run_elementwise_add(128, 128)
run_elementwise_add(256, 128)
run_elementwise_add(256, 256)
if __name__ == "__main__":
......
import tilelang
import tilelang.language as T
import tilelang.testing
def test_var_assign() -> None:
@tilelang.jit(out_idx=-1)
def jit_kernel():
@T.prim_func
def test_var_assign(A: T.Tensor((2,), 'int32')):
with T.Kernel(1) as _:
a = T.alloc_var('int32', init=1)
b = T.alloc_var('int32', init=a) # b gets value of a
a = 2
d = T.alloc_var('int32', init=a) # d gets the new value of a
A[0] = b
A[1] = d
print(test_var_assign)
return test_var_assign
kernel = jit_kernel()
print(kernel.get_kernel_source())
res = kernel()
assert res[0] == 1
assert res[1] == 2
if __name__ == '__main__':
tilelang.testing.main()
......@@ -17,8 +17,8 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
@T.prim_func
def main(
A: T.Tensor[(M), dtype_A], # noqa: F821
B: T.Tensor[(M), dtype_B], # noqa: F821
A: T.Tensor[(M,), dtype_A], # noqa: F821
B: T.Tensor[(M,), dtype_B], # noqa: F821
):
with T.Kernel(1, threads=128):
T.copy(A, B)
......@@ -26,6 +26,27 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
return main
@tilelang.jit
def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
assert M % 256 == 0
@T.prim_func
def main(
A: T.Tensor[(M,), dtype_A], # noqa: F821
B: T.Tensor[(M,), dtype_B], # noqa: F821
):
with T.Kernel(1, threads=128):
A_local = T.alloc_fragment((M,), dtype_A)
B_local = T.alloc_fragment((M,), dtype_B)
T.copy(A, A_local)
for i in T.Parallel(M):
B_local[i] = A_local[i]
T.copy(B_local, B)
return main
def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2):
"""Run the vectorized cast kernel and check the correctness.
Args:
......@@ -37,17 +58,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
M = 128 * lanes
kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
C = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
kernel(A, B)
kernel_parallel(A, C)
torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), C)
code = kernel.get_kernel_source()
code_parallel = kernel_parallel.get_kernel_source()
assert check_str in code, \
assert check_str in code and check_str in code_parallel, \
f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
......
import pytest
import torch
import tilelang
import tilelang.testing
import tilelang.language as T
tilelang.testing.set_random_seed()
VEC_SIZE = 32
@tilelang.jit
def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
@T.prim_func
def main(
a: T.Buffer((B, M, N), "bfloat16"),
a_out: T.Buffer((B, M, N), "float32"),
):
with T.Kernel(
T.ceildiv(M, BLOCK_MN),
T.ceildiv(N, BLOCK_K),
B,
threads=128,
) as (pid_m, pid_n, pid_b):
a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")
offs_m = pid_m * BLOCK_MN
offs_n = pid_n * BLOCK_K
for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
idx = i * BLOCK_K + j
a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]
return main
def _require_cuda_tensor(shape, dtype):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
try:
return torch.randn(*shape, device="cuda", dtype=dtype)
except RuntimeError as err:
pytest.skip(f"CUDA runtime unavailable: {err}")
def test_layout_infer_compiles_and_runs():
B, M, N = 1, 32, 64
BLOCK_MN, BLOCK_K = 32, 64
kernel = fused_index_kernel(B, M, N, BLOCK_MN, BLOCK_K)
a = _require_cuda_tensor((B, M, N), torch.bfloat16)
a_out = torch.empty((B, M, N), dtype=torch.float32, device=a.device)
# Ensure kernel compiles and executes without layout inversion failure
kernel(a, a_out)
assert a_out.shape == a.shape
assert a_out.dtype == torch.float32
if __name__ == "__main__":
tilelang.testing.main()
......@@ -397,6 +397,7 @@ def test_gemm_sr():
run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
# float32 tests
# TODO(lei): fix in future
run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
......
......@@ -159,8 +159,8 @@ def test_wgmma_marked_async():
def before():
with T.Kernel(1):
A_shared = T.decl_buffer((1,), "float16", scope="shared")
desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor")
desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor")
desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
C_local = T.decl_buffer((32,), "float16", scope="local")
A_shared[0] = T.float16(0)
T.warpgroup_arrive()
......
......@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n")
K = tvm.te.var("k")
@tvm.script.ir.ir_module
class Before:
def before():
@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
......@@ -38,8 +37,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
(block_N // vec_load_b) * (block_N // vec_load_b) + vec],
T.float16(0))
@tvm.script.ir.ir_module
class After:
return tvm.IRModule({'main': main})
def after():
@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
......@@ -77,11 +77,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0))
return tvm.IRModule({'main': main})
with tvm.target.Target(auto_target):
mod = tvm.tir.transform.BindTarget(auto_target)(Before)
mod = tvm.tir.transform.BindTarget(auto_target)(before())
mod = tl.transform.LayoutInference()(mod)
mod = tvm.tir.transform.Simplify()(mod)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
ref_mod = tvm.tir.transform.Simplify()(ref_mod)
# Note(tzj): The structures are equal except for one extra "for" loop after the LayoutInference pass
# This loop is "for vec in T.parallel(1)",
......
......@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n")
K = tvm.te.var("k")
@tvm.script.ir.ir_module
class Before:
def before():
@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
......@@ -25,8 +24,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(B[k * block_K, bx * block_N], B_shared)
@tvm.script.ir.ir_module
class After:
return tvm.IRModule({'main': main})
def after():
@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
......@@ -64,11 +64,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0))
return tvm.IRModule({'main': main})
with tvm.transform.PassContext():
mod = tvm.tir.transform.BindTarget(auto_target)(Before)
mod = tvm.tir.transform.BindTarget(auto_target)(before())
mod = tl.transform.LowerTileOp()(mod)
mod = tvm.tir.transform.Simplify()(mod)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
ref_mod = tvm.tir.transform.Simplify()(ref_mod)
# Note(tzj): The structures are equal except the argument in "T.reads" function.
# The difference is only that one form uses the first index and the other the index range, which are equivalent
......
......@@ -113,7 +113,7 @@ def test_multi_version_buffer_with_let():
shared = T.alloc_buffer((8,), "float32", scope="shared.dyn")
accum = T.alloc_buffer((8,), "float32", scope="local")
for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
value: T.float32 = scales[k]
value = scales[k]
for i in T.serial(8):
shared[i] = value
for i in T.serial(8):
......@@ -125,7 +125,7 @@ def test_multi_version_buffer_with_let():
shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn")
accum = T.alloc_buffer((8,), "float32", scope="local")
for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
value: T.float32 = scales[k]
value = scales[k]
for i in T.serial(8):
shared[k % 2, i] = value
for i in T.serial(8):
......
......@@ -188,5 +188,41 @@ def test_sync_let_stmt():
tvm.ir.assert_structural_equal(mod["main"], expected)
@tilelang.testing.requires_cuda
def test_sync_shared_dyn_stmatrix_loop_hoist():
@T.prim_func
def func():
buf_dyn_shmem = T.alloc_buffer((98304,), "uint8", scope="shared.dyn")
tx = T.launch_thread("threadIdx.x", 384)
for i in T.unroll(8):
off = (
i // 4 * 8192 + tx // 32 * 1024 + tx % 16 * 64 +
(tx % 8 // 4 + i % 4 // 2) % 2 * 32 + (tx % 4 // 2 + i % 2) % 2 * 16 +
(tx % 32 // 16 + tx % 2) % 2 * 8)
T.evaluate(
T.call_intrin(
"handle",
tvm.tir.op.Op.get("tl.ptx_stmatrix"),
T.int32(0),
T.int32(4),
T.tvm_access_ptr(
T.type_annotation("uint8"),
buf_dyn_shmem.data,
off,
98304 - off,
2,
),
T.int32(2),
))
mod = tvm.IRModule({"main": func})
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
s = str(mod)
assert 'T.tvm_storage_sync("shared.dyn")' in s
# Ensure the sync appears before the unrolled loop
assert s.index('T.tvm_storage_sync("shared.dyn")') < s.index("for i in T.unroll(8)")
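# The assertions above require ThreadSync("shared.dyn") to hoist the barrier:
# a single tvm_storage_sync must be emitted before the unrolled stmatrix loop
# rather than being re-issued inside each unrolled iteration.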
if __name__ == "__main__":
tilelang.testing.main()
......@@ -4,24 +4,46 @@ import ctypes
import logging
import warnings
from tqdm import tqdm
from pathlib import Path
from tqdm.auto import tqdm
from importlib.metadata import PackageNotFoundError, version
try:
__version__ = version('tilelang')
except PackageNotFoundError:
def _compute_version() -> str:
"""Return the package version without being polluted by unrelated installs.
Preference order:
1) If running from a source checkout (VERSION file present at repo root),
use the dynamic version from version_provider (falls back to plain VERSION).
2) Otherwise, use importlib.metadata for the installed distribution.
3) As a last resort, return a dev sentinel.
"""
try:
from version_provider import dynamic_metadata
repo_root = Path(__file__).resolve().parent.parent
version_file = repo_root / "VERSION"
if version_file.is_file():
try:
from version_provider import dynamic_metadata # type: ignore
return dynamic_metadata("version")
except Exception:
# Fall back to the raw VERSION file if provider isn't available.
return version_file.read_text().strip()
except Exception:
# If any of the above fails, fall through to installed metadata.
pass
__version__ = dynamic_metadata('version')
try:
from importlib.metadata import version as _dist_version # py3.8+
return _dist_version("tilelang")
except Exception as exc:
warnings.warn(
f"tilelang version metadata unavailable ({exc!r}); using development version.",
RuntimeWarning,
stacklevel=2,
)
__version__ = "0.0.dev0"
return "0.0.dev0"
__version__ = _compute_version()
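# Note: __version__ is resolved once at import time. In a source checkout it
# tracks the repo's VERSION file (via version_provider when available); for an
# installed wheel it comes from the distribution metadata, and "0.0.dev0" is
# used only when neither source is available.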
class TqdmLoggingHandler(logging.Handler):
......
"""FFI APIs for tilelang"""
import tvm.ffi
import tvm_ffi
# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
tvm.ffi._init_api("tl", __name__) # pylint: disable=protected-access
tvm_ffi.init_ffi_api("tl", __name__)