Unverified commit bbbf4207, authored by guchaoyang and committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import tvm
from tvm.script.ir_builder.base import IRBuilderFrame
from tvm.tir.expr import IntImm, Var
def test_argument():
@T.prim_func
def test_argument(
t_1: T.bool,
t_2: T.short,
t_3: T.int,
t_4: T.long,
t_5: T.half,
t_6: T.float,
t_7: T.long,
t_8: T.int8,
t_9: T.int16,
t_10: T.int32,
t_11: T.int64,
t_12: T.uint8,
t_13: T.uint16,
t_14: T.uint32,
t_15: T.uint64,
t_16: T.float8_e4m3fn,
t_17: T.float8_e4m3fnuz,
t_18: T.float8_e5m2,
t_19: T.float8_e5m2fnuz,
t_20: T.float8_e8m0fnu,
t_21: T.float16,
t_22: T.bfloat16,
t_23: T.float32,
t_24: T.float64,
):
pass
def test_expr():
from tilelang.language.v2.dtypes import _all_dtypes
errors = []
for name in _all_dtypes:
dtype = getattr(T, name)
assert isinstance(dtype, tvm.DataType), f"{dtype} is not tvm.DataType"
try:
dtype(1.0)
dtype()
except TypeError:
pass
except Exception:
errors.append(name)
assert not errors
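    # The loop above only checks that every name in `_all_dtypes` resolves to a
    # tvm.DataType and that calling a dtype either succeeds or raises TypeError;
    # any other exception is recorded in `errors` and fails the assertion above.
    # Hypothetical sketch of the call forms being exercised (the exact return
    # values are an assumption, not something this test asserts):
    #   T.float32(1.0)   # build a float32-typed value from a Python number
    #   T.int32()        # zero-argument form; may raise TypeError for some dtypes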
# def test_var_decl_sugar():
# @T.prim_func
# def test_var_decl_sugar():
# with T.Kernel(128, 128) as (bx, by):
# var_1: T.bool = 1.0
# var_2: T.short = 1.0
# var_3: T.int = 1.0
# var_4: T.long = 1.0
# var_5: T.half = 1.0
# var_6: T.float = 1.0
# var_7: T.long = 1.0
# var_8: T.int8 = 1.0
# var_9: T.int16 = 1.0
# var_10: T.int32 = 1.0
# var_11: T.int64 = 1.0
# var_12: T.uint8 = 1.0
# var_13: T.uint16 = 1.0
# var_14: T.uint32 = 1.0
# var_15: T.uint64 = 1.0
# var_16: T.float8_e4m3fn = 1.0
# var_17: T.float8_e4m3fnuz = 1.0
# var_18: T.float8_e5m2 = 1.0
# var_19: T.float8_e5m2fnuz = 1.0
# var_20: T.float8_e8m0fnu = 1.0
# var_21: T.float16 = 1.0
# var_22: T.bfloat16 = 1.0
# var_23: T.float32 = 1.0
# var_24: T.float64 = 1.0
# var_1: T.bool = var_1
# var_2: T.short = var_2
# var_3: T.int = var_3
# var_4: T.long = var_4
# var_5: T.half = var_5
# var_6: T.float = var_6
# var_7: T.long = var_7
# var_8: T.int8 = var_8
# var_9: T.int16 = var_9
# var_10: T.int32 = var_10
# var_11: T.int64 = var_11
# var_12: T.uint8 = var_12
# var_13: T.uint16 = var_13
# var_14: T.uint32 = var_14
# var_15: T.uint64 = var_15
# var_16: T.float8_e4m3fn = var_16
# var_17: T.float8_e4m3fnuz = var_17
# var_18: T.float8_e5m2 = var_18
# var_19: T.float8_e5m2fnuz = var_19
# var_20: T.float8_e8m0fnu = var_20
# var_21: T.float16 = var_21
# var_22: T.bfloat16 = var_22
# var_23: T.float32 = var_23
# var_24: T.float64 = var_24
# s = test_var_decl_sugar.script()
# for i in range(1, 25):
# assert f'var_{i}_1' in s
# assert 'tl.local_var_init' in s
def test_dtype_str_repr():
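    # There are no assertions here: the test passes if the prim_func below traces,
    # i.e. every dtype object is accepted where alloc_buffer expects a dtype string.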
@T.prim_func
def test_str_repr():
buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') # noqa F841
buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') # noqa F841
buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') # noqa F841
buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') # noqa F841
buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') # noqa F841
buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') # noqa F841
buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') # noqa F841
buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') # noqa F841
buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') # noqa F841
buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') # noqa F841
buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') # noqa F841
buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') # noqa F841
buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') # noqa F841
buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') # noqa F841
buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') # noqa F841
buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') # noqa F841
buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') # noqa F841
buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') # noqa F841
buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') # noqa F841
buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') # noqa F841
buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') # noqa F841
buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841
def test_torch_eq():
dtypes = [
T.bool,
T.short,
T.int,
T.long,
T.half,
T.float,
T.long,
T.int8,
T.int16,
T.int32,
T.int64,
T.uint8,
T.uint16,
T.uint32,
T.uint64,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.float8_e8m0fnu,
T.float16,
T.bfloat16,
T.float32,
T.float64,
]
torch_dtypes = [
torch.bool,
torch.short,
torch.int,
torch.long,
torch.half,
torch.float,
torch.long,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]
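    # The two lists are positionally paired; the loop below checks both directions
    # of interop: each TileLang dtype compares equal to its torch counterpart, and
    # T.dtype() maps the torch dtype back to the same TileLang dtype.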
for a, b in zip(dtypes, torch_dtypes):
assert a == b, f"{a} and {b} are not equal"
assert T.dtype(b) == a, "dtype conversion error"
def test_var_assign():
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_var_assign(A: T.Tensor((2,), T.int32)):
with T.Kernel(1) as _:
a: T.int32 = 1
b: T.int32 = a
a = 2
d: T.int32 = a
A[0] = b
A[1] = d
res = test_var_assign()()
assert res[0] == 1
assert res[1] == 2
def test_macro_return():
@T.macro
def macro_return_constant():
return 0
@T.macro
def macro_return_frame(x):
return T.alloc_var(T.float32, init=x)
@T.macro
def macro_return_expr(x):
y = x + 1.0
return y
@T.macro
def macro_apply_func(x, fn):
return fn(x)
def check(x, ty):
assert isinstance(x, ty)
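    # The prim_func below exercises the supported macro return forms: a plain
    # constant, an alloc_var frame, a computed expression, and a value produced
    # by a passed-in callable; the check() calls assert what each call site receives.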
@T.prim_func
def test_macro_return():
with T.Kernel(1) as _:
a = macro_return_constant()
b = macro_return_frame(3.0)
c = macro_return_expr(4.0)
d = macro_apply_func(5.0, lambda x: x * 2.0)
check(a, (int, float, T.PrimExpr))
check(b, T.PrimExpr)
check(c, T.PrimExpr)
check(d, T.PrimExpr)
def test_prim_func_generator():
@T.prim_func(generator=True)
def prim_func_gen(
A=T.Tensor((128,), T.float32), # noqa: B008
B=T.Tensor((128,), T.float32), # noqa: B008
):
with T.Kernel(128) as (tx,):
T.copy(A[tx], B[tx])
prim_func_gen()
@T.prim_func
def foo() -> T.Tensor((128,), T.float32):
pass
assert isinstance(foo, T.PrimFunc)
def test_serial_for_with_step():
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_stepped_serial(A: T.Tensor((10,), T.int32)):
with T.Kernel(1) as _:
for i in range(0, 10, 2):
T.device_assert(0 <= i < 10 and i % 2 == 0, "i out of range")
A[i] = 1.0
for i in range(1, 10, 2):
T.device_assert(1 <= i < 10 and i % 2 == 1, "i out of range")
A[i] = 2.0
ker = test_stepped_serial()
res = ker()
ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda')
assert torch.all(res == ref), f"Expected {ref}, but got {res}"
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_serial_step_neg(A: T.Tensor((10,), T.int32)):
with T.Kernel(1) as _:
for i in range(10, 0, -1):
T.device_assert(0 < i <= 10, "i out of range")
A[10 - i] = i
ker = test_serial_step_neg()
res = ker()
ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda')
assert torch.all(res == ref), f"Expected {ref}, but got {res}"
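    # The assertions below pin down when T.serial yields an IRBuilderFrame: the
    # step must be a positive compile-time constant (Python int or IntImm); a
    # symbolic Var step or a negative step is checked to return something else.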
assert isinstance(T.serial(1, 10, 1), IRBuilderFrame)
assert isinstance(T.serial(1, 10, IntImm('int32', 1)), IRBuilderFrame)
assert not isinstance(T.serial(1, 10, Var('tmp', 'int32')), IRBuilderFrame)
assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame)
def test_swap_logic():
@tilelang.jit
@T.prim_func
def swap_var(A: T.Tensor[(2,), T.float32]):
with T.Kernel(1, threads=1) as _:
a = T.alloc_var(T.float32, A[0])
b = T.alloc_var(T.float32, A[1])
a, b = b, a
A[0], A[1] = a, b
@tilelang.jit
@T.prim_func
def swap_idx(A: T.Tensor[(2,), T.float32]):
with T.Kernel(1, threads=1) as _:
A[0], A[1] = A[1], A[0]
k_swap_var = swap_var()
data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
k_swap_var(data)
ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
torch.testing.assert_close(data, ref)
k_swap_idx = swap_idx()
data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
k_swap_idx(data)
ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
torch.testing.assert_close(data, ref)
def test_while_loop():
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_while_loop(A: T.Tensor((1,), T.int32)):
with T.Kernel(1) as _:
i = T.alloc_var(T.int32, 0)
sum = T.alloc_var(T.int32)
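            # No explicit init here: the final assertion relies on the var starting at 0.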
while i < 10:
sum += i
i += 1
A[0] = sum
ker = test_while_loop()
A = ker()
assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}"
if __name__ == '__main__':
tilelang.testing.main()
@@ -209,4 +209,3 @@ def test_shuffle_elect_block_leader():
 if __name__ == "__main__":
     tilelang.testing.main()
-# run_get_lane_id()
import torch
import tilelang
import tilelang.language as T
@tilelang.jit(out_idx=-1)
def get_inf_kernel(dtype: str):
@T.prim_func
def main(A: T.Tensor((32,), dtype)):
with T.Kernel(1, threads=32):
T.fill(A, T.infinity(dtype))
return main
def _test_infinity(dtype: str):
kernel = get_inf_kernel(dtype)
output = kernel()
assert torch.all(output == torch.inf), f'check failed for {dtype=}'
@tilelang.testing.requires_cuda
def test_infinity():
_test_infinity("float16")
_test_infinity("bfloat16")
_test_infinity("float32")
_test_infinity("float64")
if __name__ == "__main__":
tilelang.testing.main()
import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T
def test_let_vectorize_load():
@T.prim_func
def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
b = A[0, 0:4]
A[0, 4:8] = b
mod = tvm.IRModule({"main": main})
mod = tvm.compile(mod, target="cuda")
assert "float4 b" in mod.mod.imports[0].inspect_source()
if __name__ == "__main__":
tilelang.testing.main()
from tilelang import tvm
import tilelang as tl
import tilelang.testing
from tvm.script import tir as T
@T.prim_func
def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
T.func_attr({"tir.noalias": True})
B[0] = A[T.int32(-1)]
@T.prim_func
def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
T.func_attr({"tir.noalias": True})
B[0] = A[T.int32(15)]
@T.prim_func
def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
T.func_attr({"tir.noalias": True})
for i in T.serial(4):
B[i] = A[-i - 1]
@T.prim_func
def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
T.func_attr({"tir.noalias": True})
for i in T.serial(4):
B[i] = A[15 - i]
@T.prim_func
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"),
B: T.Buffer((16,), "float32")):
T.func_attr({"tir.noalias": True})
for i in T.serial(16):
B[i] = A[shift + i]
def test_legalize_negative_index_scalar():
mod = tvm.IRModule({"main": negative_index_before})
transformed = tl.transform.LegalizeNegativeIndex()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_expected.body)
def test_legalize_negative_index_affine_expr():
mod = tvm.IRModule({"main": negative_index_loop_before})
transformed = tl.transform.LegalizeNegativeIndex()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_loop_expected.body)
def test_legalize_negative_index_symbolic_passthrough():
mod = tvm.IRModule({"main": negative_index_symbolic_before})
transformed = tl.transform.LegalizeNegativeIndex()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_symbolic_before.body)
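# Taken together, the three tests above document the pass contract: negative
# constant indices are folded to in-bounds ones (A[-1] -> A[15]), affine
# expressions with a provably negative range are rewritten (A[-i - 1] -> A[15 - i]),
# and indices the pass cannot bound (the symbolic `shift + i`) pass through untouched.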
if __name__ == "__main__":
tilelang.testing.main()
@@ -116,7 +116,6 @@ def test_reduce_sum():
 def test_reduce_sum_shared():
     run_reduce_sum(64, 64, mode="ss")
-    run_reduce_sum(32, 96, mode="ss")
 def test_reduce_max():
@@ -127,7 +126,6 @@ def test_reduce_max():
 def test_reduce_max_shared():
     run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, "float32")
-    run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 96, 48, "float32")
 def test_reduce_min_shared():
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import torch
def reshape_test(N, M, dtype):
@@ -129,5 +130,137 @@ def test_reshape_smem_2d_2_1d():
    run_reshape_smem_2d_2_1d(2048, 64, "float16")
def reshape_fragment_test(N, M, dtype):
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor((N // M, M), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
A_local = T.alloc_fragment((N // M, M), dtype)
B_shared = T.alloc_shared((N,), dtype, scope="shared")
T.copy(A, A_shared)
T.copy(A_shared, A_local)
A_local_reshape = T.reshape(A_local, [N])
T.copy(A_local_reshape, B_shared)
T.copy(B_shared, B)
return main
def run_reshape_fragment(N, M, dtype):
program = reshape_fragment_test(N, M, dtype)
jit_kernel = tl.compile(
program,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = jit_kernel.get_profiler()
def ref_program(A):
return A.reshape(N)
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reshape_fragment():
run_reshape_fragment(1024, 32, "float32")
run_reshape_fragment(2048, 64, "float16")
def reshape_layout_transform_shared(N, M, dtype):
import tilelang.language as T
from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout
@T.prim_func
def main(
A: T.Tensor((N // M, M), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
T.annotate_layout({
A_shared: make_mma_swizzle_layout(A_shared),
})
T.copy(A, A_shared)
A_shared_reshape = T.reshape(A_shared, [N])
T.copy(A_shared_reshape, B)
return main
def run_reshape_layout_transform_shared(N, M, dtype):
program = reshape_layout_transform_shared(N, M, dtype)
jit_kernel = tl.compile(
program,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = jit_kernel.get_profiler()
def ref_program(A):
return A.reshape(N)
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reshape_layout_transform_shared():
run_reshape_layout_transform_shared(1024, 32, "float32")
run_reshape_layout_transform_shared(2048, 64, "float16")
def reduce_after_reshape_test(N, M, dtype):
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_shared = T.alloc_shared((N,), dtype, scope="shared")
A_local = T.alloc_fragment((N,), dtype)
B_local = T.alloc_fragment((N // M,), dtype)
T.copy(A, A_shared)
T.copy(A_shared, A_local)
A_local_reshape = T.reshape(A_local, [N // M, M])
T.reduce_max(A_local_reshape, B_local, dim=1)
T.copy(B_local, B)
return main
def run_reduce_after_reshape(N, M, dtype):
program = reduce_after_reshape_test(N, M, dtype)
jit_kernel = tl.compile(
program,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = jit_kernel.get_profiler()
def ref_program(A):
return torch.max(A.reshape(N // M, M), dim=1).values
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reduce_after_reshape():
run_reduce_after_reshape(1024, 32, "float32")
run_reduce_after_reshape(2048, 64, "float16")
if __name__ == "__main__":
    tilelang.testing.main()
-import argparse
+import torch
 import tilelang
 import tilelang.language as T
-import torch
 def ref_program(x, y):
@@ -30,23 +29,29 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
     return elem_add
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--m", type=int, default=128)
-    parser.add_argument("--n", type=int, default=128)
-    args, _ = parser.parse_known_args()
-    M, N = args.m, args.n
+def run_elementwise_add(M, N):
     a = torch.randn(M, N, dtype=torch.float32, device="cuda")
     b = torch.randn(M, N, dtype=torch.float32, device="cuda")
     # Default config
-    config = {"block_M": 128, "block_N": 128, "threads": 128}
+    block_M, block_N = 128, 128
+    config = {"block_M": block_M, "block_N": block_N, "threads": 128}
     kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
     out = kernel(a, b)
     torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
-    print("All passed!")
+    code = kernel.get_kernel_source()
+    if block_N == N:
+        assert "tma_load" in code and "CUtensorMap" not in code
+    else:
+        assert "tma_load" in code and "CUtensorMap" in code
+def main():
+    run_elementwise_add(128, 128)
+    run_elementwise_add(256, 128)
+    run_elementwise_add(256, 256)
 if __name__ == "__main__":
import tilelang
import tilelang.language as T
import tilelang.testing
def test_var_assign() -> None:
@tilelang.jit(out_idx=-1)
def jit_kernel():
@T.prim_func
def test_var_assign(A: T.Tensor((2,), 'int32')):
with T.Kernel(1) as _:
a = T.alloc_var('int32', init=1)
b = T.alloc_var('int32', init=a) # b gets value of a
a = 2
d = T.alloc_var('int32', init=a)  # d gets the new value of a
A[0] = b
A[1] = d
print(test_var_assign)
return test_var_assign
kernel = jit_kernel()
print(kernel.get_kernel_source())
res = kernel()
assert res[0] == 1
assert res[1] == 2
if __name__ == '__main__':
tilelang.testing.main()
@@ -17,8 +17,8 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
     @T.prim_func
     def main(
-            A: T.Tensor[(M), dtype_A],  # noqa: F821
-            B: T.Tensor[(M), dtype_B],  # noqa: F821
+            A: T.Tensor[(M,), dtype_A],  # noqa: F821
+            B: T.Tensor[(M,), dtype_B],  # noqa: F821
     ):
         with T.Kernel(1, threads=128):
             T.copy(A, B)
@@ -26,6 +26,27 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
     return main
+@tilelang.jit
+def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
+    assert M % 256 == 0
+
+    @T.prim_func
+    def main(
+            A: T.Tensor[(M,), dtype_A],  # noqa: F821
+            B: T.Tensor[(M,), dtype_B],  # noqa: F821
+    ):
+        with T.Kernel(1, threads=128):
+            A_local = T.alloc_fragment((M,), dtype_A)
+            B_local = T.alloc_fragment((M,), dtype_B)
+            T.copy(A, A_local)
+            for i in T.Parallel(M):
+                B_local[i] = A_local[i]
+            T.copy(B_local, B)
+
+    return main
 def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2):
     """Run the vectorized cast kernel and check the correctness.
     Args:
@@ -37,17 +58,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
     M = 128 * lanes
     kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
+    kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
     A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
     B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
+    C = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
     kernel(A, B)
+    kernel_parallel(A, C)
     torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
+    torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), C)
     code = kernel.get_kernel_source()
+    code_parallel = kernel_parallel.get_kernel_source()
-    assert check_str in code, \
+    assert check_str in code and check_str in code_parallel, \
         f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
import pytest
import torch
import tilelang
import tilelang.testing
import tilelang.language as T
tilelang.testing.set_random_seed()
VEC_SIZE = 32
@tilelang.jit
def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
@T.prim_func
def main(
a: T.Buffer((B, M, N), "bfloat16"),
a_out: T.Buffer((B, M, N), "float32"),
):
with T.Kernel(
T.ceildiv(M, BLOCK_MN),
T.ceildiv(N, BLOCK_K),
B,
threads=128,
) as (pid_m, pid_n, pid_b):
a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")
offs_m = pid_m * BLOCK_MN
offs_n = pid_n * BLOCK_K
for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
idx = i * BLOCK_K + j
a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[idx // VEC_SIZE, idx % VEC_SIZE]
return main
def _require_cuda_tensor(shape, dtype):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
try:
return torch.randn(*shape, device="cuda", dtype=dtype)
except RuntimeError as err:
pytest.skip(f"CUDA runtime unavailable: {err}")
def test_layout_infer_compiles_and_runs():
B, M, N = 1, 32, 64
BLOCK_MN, BLOCK_K = 32, 64
kernel = fused_index_kernel(B, M, N, BLOCK_MN, BLOCK_K)
a = _require_cuda_tensor((B, M, N), torch.bfloat16)
a_out = torch.empty((B, M, N), dtype=torch.float32, device=a.device)
# Ensure kernel compiles and executes without layout inversion failure
kernel(a, a_out)
assert a_out.shape == a.shape
assert a_out.dtype == torch.float32
if __name__ == "__main__":
tilelang.testing.main()
@@ -397,6 +397,7 @@ def test_gemm_sr():
     run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
     # float32 tests
+    # TODO(lei): fix in future
     run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
     run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
     run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
@@ -159,8 +159,8 @@ def test_wgmma_marked_async():
     def before():
         with T.Kernel(1):
             A_shared = T.decl_buffer((1,), "float16", scope="shared")
-            desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor")
-            desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor")
+            desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
+            desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
             C_local = T.decl_buffer((32,), "float16", scope="local")
             A_shared[0] = T.float16(0)
             T.warpgroup_arrive()
@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
     N = tvm.te.var("n")
     K = tvm.te.var("k")
-    @tvm.script.ir.ir_module
-    class Before:
+    def before():
         @T.prim_func
         def main(B: T.Tensor((K, N), dtype),):
@@ -38,8 +37,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
                            (block_N // vec_load_b) * (block_N // vec_load_b) + vec],
                        T.float16(0))
-    @tvm.script.ir.ir_module
-    class After:
+        return tvm.IRModule({'main': main})
+
+    def after():
         @T.prim_func
         def main(B: T.Tensor((K, N), dtype),):
@@ -77,11 +77,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
                        bx * block_N + t % (block_N // vec_load_b) *
                        (block_N // vec_load_b) + vec], T.float16(0))
+        return tvm.IRModule({'main': main})
+
     with tvm.target.Target(auto_target):
-        mod = tvm.tir.transform.BindTarget(auto_target)(Before)
+        mod = tvm.tir.transform.BindTarget(auto_target)(before())
         mod = tl.transform.LayoutInference()(mod)
         mod = tvm.tir.transform.Simplify()(mod)
-        ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
+        ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
         ref_mod = tvm.tir.transform.Simplify()(ref_mod)
         # Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass
         # This loop is "for vec in T.parallel(1)",
@@ -15,8 +15,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
     N = tvm.te.var("n")
     K = tvm.te.var("k")
-    @tvm.script.ir.ir_module
-    class Before:
+    def before():
         @T.prim_func
         def main(B: T.Tensor((K, N), dtype),):
@@ -25,8 +24,9 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                 T.copy(B[k * block_K, bx * block_N], B_shared)
-    @tvm.script.ir.ir_module
-    class After:
+        return tvm.IRModule({'main': main})
+
+    def after():
         @T.prim_func
         def main(B: T.Tensor((K, N), dtype),):
@@ -64,11 +64,13 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
                        bx * block_N + t % (block_N // vec_load_b) *
                        (block_N // vec_load_b) + vec], T.float16(0))
+        return tvm.IRModule({'main': main})
+
     with tvm.transform.PassContext():
-        mod = tvm.tir.transform.BindTarget(auto_target)(Before)
+        mod = tvm.tir.transform.BindTarget(auto_target)(before())
         mod = tl.transform.LowerTileOp()(mod)
         mod = tvm.tir.transform.Simplify()(mod)
-        ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
+        ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
         ref_mod = tvm.tir.transform.Simplify()(ref_mod)
         # Note(tzj): The structures are equal except the argument in "T.reads" function.
         # The difference is just between the first index and the indices range, which is totally equivalent
@@ -113,7 +113,7 @@ def test_multi_version_buffer_with_let():
         shared = T.alloc_buffer((8,), "float32", scope="shared.dyn")
         accum = T.alloc_buffer((8,), "float32", scope="local")
         for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
-            value: T.float32 = scales[k]
+            value = scales[k]
             for i in T.serial(8):
                 shared[i] = value
             for i in T.serial(8):
@@ -125,7 +125,7 @@ def test_multi_version_buffer_with_let():
         shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn")
         accum = T.alloc_buffer((8,), "float32", scope="local")
         for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
-            value: T.float32 = scales[k]
+            value = scales[k]
             for i in T.serial(8):
                 shared[k % 2, i] = value
             for i in T.serial(8):
@@ -188,5 +188,41 @@ def test_sync_let_stmt():
    tvm.ir.assert_structural_equal(mod["main"], expected)
@tilelang.testing.requires_cuda
def test_sync_shared_dyn_stmatrix_loop_hoist():
@T.prim_func
def func():
buf_dyn_shmem = T.alloc_buffer((98304,), "uint8", scope="shared.dyn")
tx = T.launch_thread("threadIdx.x", 384)
for i in T.unroll(8):
off = (
i // 4 * 8192 + tx // 32 * 1024 + tx % 16 * 64 +
(tx % 8 // 4 + i % 4 // 2) % 2 * 32 + (tx % 4 // 2 + i % 2) % 2 * 16 +
(tx % 32 // 16 + tx % 2) % 2 * 8)
T.evaluate(
T.call_intrin(
"handle",
tvm.tir.op.Op.get("tl.ptx_stmatrix"),
T.int32(0),
T.int32(4),
T.tvm_access_ptr(
T.type_annotation("uint8"),
buf_dyn_shmem.data,
off,
98304 - off,
2,
),
T.int32(2),
))
mod = tvm.IRModule({"main": func})
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
s = str(mod)
assert 'T.tvm_storage_sync("shared.dyn")' in s
# Ensure the sync appears before the unrolled loop
assert s.index('T.tvm_storage_sync("shared.dyn")') < s.index("for i in T.unroll(8)")
if __name__ == "__main__":
    tilelang.testing.main()
@@ -4,24 +4,46 @@ import ctypes
 import logging
 import warnings
-from tqdm import tqdm
-from importlib.metadata import PackageNotFoundError, version
+from pathlib import Path
+from tqdm.auto import tqdm
-try:
-    __version__ = version('tilelang')
-except PackageNotFoundError:
-    try:
-        from version_provider import dynamic_metadata
-        __version__ = dynamic_metadata('version')
-    except Exception as exc:
-        warnings.warn(
-            f"tilelang version metadata unavailable ({exc!r}); using development version.",
-            RuntimeWarning,
-            stacklevel=2,
-        )
-        __version__ = "0.0.dev0"
+def _compute_version() -> str:
+    """Return the package version without being polluted by unrelated installs.
+
+    Preference order:
+    1) If running from a source checkout (VERSION file present at repo root),
+       use the dynamic version from version_provider (falls back to plain VERSION).
+    2) Otherwise, use importlib.metadata for the installed distribution.
+    3) As a last resort, return a dev sentinel.
+    """
+    try:
+        repo_root = Path(__file__).resolve().parent.parent
+        version_file = repo_root / "VERSION"
+        if version_file.is_file():
+            try:
+                from version_provider import dynamic_metadata  # type: ignore
+                return dynamic_metadata("version")
+            except Exception:
+                # Fall back to the raw VERSION file if provider isn't available.
+                return version_file.read_text().strip()
+    except Exception:
+        # If any of the above fails, fall through to installed metadata.
+        pass
+
+    try:
+        from importlib.metadata import version as _dist_version  # py3.8+
+        return _dist_version("tilelang")
+    except Exception as exc:
+        warnings.warn(
+            f"tilelang version metadata unavailable ({exc!r}); using development version.",
+            RuntimeWarning,
+            stacklevel=2,
+        )
+        return "0.0.dev0"
+
+
+__version__ = _compute_version()
+
 class TqdmLoggingHandler(logging.Handler):
"""FFI APIs for tilelang""" """FFI APIs for tilelang"""
import tvm.ffi import tvm_ffi
# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func); # TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
tvm.ffi._init_api("tl", __name__) # pylint: disable=protected-access tvm_ffi.init_ffi_api("tl", __name__)