# test_tilelang_language_view.py
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl


def view_test(N, M, dtype, new_dtype=None):
    """Build a TileLang prim_func that views a 1-D tensor as 2-D and copies it out.

    Args:
        N: Number of elements in the 1-D input tensor ``A``.
        M: Length of the last axis of the view before any dtype adjustment;
            the leading axis is ``N // M``.
        dtype: Element dtype string of ``A`` (e.g. ``"float32"``).
        new_dtype: Optional dtype string for the view. When given, the last
            axis is rescaled by ``bits(dtype) / bits(new_dtype)`` so that a
            row covers the same number of bits under the new element type.

    Returns:
        The ``T.prim_func`` ``main(A, B)`` that writes ``T.view(A, ...)``
        into ``B``.
    """
    import tilelang.language as T

    new_shape = [N // M, M]
    if new_dtype:
        from tvm import DataType

        src_bits = DataType(dtype).bits
        dst_bits = DataType(new_dtype).bits
        # Pure integer arithmetic: multiply before dividing so the result is
        # exact and never passes through a float intermediate.
        new_shape[-1] = M * src_bits // dst_bits

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
    ):
        with T.Kernel(1) as _:
            # dtype=None (no new_dtype requested) is forwarded unchanged.
            A_viewed = T.view(A, new_shape, dtype=new_dtype)
            T.copy(A_viewed, B)

    return main


def run_view(N, M, dtype, new_dtype=None):
    """Compile the view kernel and compare it with a torch reference.

    Args:
        N: Element count of the 1-D input.
        M: Last-axis length used for the 2-D view (before dtype rescaling).
        dtype: Source dtype string.
        new_dtype: Optional target dtype string for a reinterpreting view.
    """
    # NOTE(review): out_idx=-1 presumably marks the last parameter (B) as the
    # kernel output -- confirm against tilelang.compile.
    kernel = tl.compile(view_test(N, M, dtype, new_dtype), out_idx=-1)

    def ref_program(A):
        # Reference: reshape to 2-D first, then reinterpret if requested.
        reshaped = A.view(N // M, M)
        if not new_dtype:
            return reshaped
        from tilelang.utils.tensor import map_torch_type
        return reshaped.view(dtype=map_torch_type(new_dtype))

    kernel.get_profiler().assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reshape_view():
    """Exercise the view kernel with and without a dtype change."""
    cases = (
        # Test view with same dtype
        (1024, 32, "float32", None),
        (2048, 64, "float16", None),
        # Test view with dtype conversion
        (1024, 32, "float32", "float16"),
        (2048, 64, "float16", "float32"),
    )
    for n_elems, last_dim, src_dtype, dst_dtype in cases:
        run_view(n_elems, last_dim, src_dtype, dst_dtype)


# Script entry point: hand control to tilelang's test runner when this file
# is executed directly rather than imported.
if __name__ == "__main__":
    tilelang.testing.main()