# test_tilelang_language_view.py
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import pytest


def view_test(N, M, dtype, new_dtype=None):
    """Build a prim_func that views a flat tensor as 2-D and copies it out.

    ``A`` is a 1-D tensor of ``N`` elements of ``dtype``. Inside the kernel it
    is reinterpreted with ``T.view`` as ``[N // M, M]`` (optionally as
    ``new_dtype``), and the view is copied into the output ``B``.

    Args:
        N: Number of elements in the input tensor ``A``.
        M: Length of the last viewed dimension, counted in ``dtype`` elements.
        dtype: Element dtype of ``A``.
        new_dtype: Optional dtype to reinterpret ``A`` as. When given, the
            last dimension is rescaled by ``bits(dtype) / bits(new_dtype)``
            so the view spans the same number of bits as ``A``. When not
            given, ``dtype=None`` is passed straight through to ``T.view``.

    Returns:
        The ``T.prim_func`` kernel ``main(A, B)``.
    """
    import tilelang.language as T

    new_shape = [N // M, M]
    if new_dtype:
        from tvm import DataType
        src_bits = DataType(dtype).bits
        dst_bits = DataType(new_dtype).bits
        # Exact integer arithmetic instead of a float scale factor,
        # e.g. float32 -> float16 doubles the last dimension.
        new_shape[-1] = M * src_bits // dst_bits

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
    ):
        with T.Kernel(1) as _:
            A_viewed = T.view(A, new_shape, dtype=new_dtype)
            T.copy(A_viewed, B)

    return main


def run_view(N, M, dtype, new_dtype=None):
    """Compile the ``view_test`` kernel and check it against a reference.

    The reference reshapes the input to ``(N // M, M)`` and, when
    ``new_dtype`` is set, reinterprets it via ``.view(dtype=...)`` using the
    torch dtype mapped from ``new_dtype``. The profiler presumably supplies
    torch tensors as inputs (TODO confirm). Output is compared with
    ``atol=rtol=1e-2``.
    """
    kernel = tl.compile(view_test(N, M, dtype, new_dtype), out_idx=-1)
    profiler = kernel.get_profiler()
    rows = N // M

    def ref_program(A):
        reshaped = A.view(rows, M)
        if not new_dtype:
            return reshaped
        from tilelang.utils.tensor import map_torch_type
        return reshaped.view(dtype=map_torch_type(new_dtype))

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_view():
    """Exercise ``T.view`` as a pure reshape and as a dtype reinterpretation."""
    cases = (
        # Same dtype: reshape only.
        (1024, 32, "float32", None),
        (2048, 64, "float16", None),
        # Dtype conversion: last dimension rescaled by the bit-width ratio.
        (1024, 32, "float32", "float16"),
        (2048, 64, "float16", "float32"),
    )
    for n, m, src_dtype, dst_dtype in cases:
        run_view(n, m, src_dtype, dst_dtype)


def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
    """Build a prim_func whose ``T.view`` target size deliberately mismatches ``A``.

    The viewed shape is ``[N // M, M + 1]``: one extra element per row, so the
    view cannot cover exactly the ``N`` elements of ``A``. When ``new_dtype``
    is given, the last dimension is first rescaled by
    ``bits(dtype) / bits(new_dtype)`` and the ``+ 1`` is kept. Previously the
    rescale overwrote it and produced a valid, non-mismatching shape.

    Args:
        N: Number of elements in the input tensor ``A``.
        M: Base length of the last viewed dimension, in ``dtype`` elements.
        dtype: Element dtype of ``A``.
        new_dtype: Optional dtype to reinterpret ``A`` as.

    Returns:
        The ``T.prim_func`` kernel ``main(A, B)``. Its callers expect
        constructing it to raise ``AssertionError``.
    """
    import tilelang.language as T

    new_shape = [N // M, M + 1]
    if new_dtype:
        from tvm import DataType
        src_bits = DataType(dtype).bits
        dst_bits = DataType(new_dtype).bits
        # Rescale for the dtype change, then keep the extra element so the
        # total bit count still disagrees with A's.
        new_shape[-1] = M * src_bits // dst_bits + 1

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
    ):
        with T.Kernel(1) as _:
            A_viewed = T.view(A, new_shape, dtype=new_dtype)
            T.copy(A_viewed, B)

    return main


def test_view_shape_mismatch():
    """Building a kernel with a size-mismatched ``T.view`` raises ``AssertionError``."""
    with pytest.raises(AssertionError):
        view_shape_mismatch_test(N=1024, M=32, dtype="float32")


# Allow running this test module directly via tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()