"...resnet50_tensorflow.git" did not exist on "e80b385a20bee6a21af1683424575bea7c8457bc"
test_tilelang_language_reduce_sum.py 2.57 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

# Set the random seed once at import time, using the helper's default seed
# (no argument is passed). NOTE(review): presumably this makes any randomly
# generated test inputs reproducible — confirm which RNGs the helper seeds.
tilelang.testing.set_random_seed()


def reduce_sum_test(M, N, dtype="float16"):
    """Build a TileLang prim_func that sums an ``(M, N)`` tensor along dim 1.

    The returned function takes an input ``A`` of shape ``(M, N)`` and an
    output ``B`` of shape ``(M,)``, both of element type ``dtype``. Inside the
    kernel, ``A`` is copied into a fragment buffer, reduced with
    ``T.reduce_sum(..., dim=1)`` into a second fragment, and the result is
    copied out to ``B``.

    Args:
        M: Size of the first dimension of ``A`` (and the length of ``B``).
        N: Size of the second dimension of ``A``, i.e. the reduced dimension.
        dtype: Element type name used for both tensors and both fragments.

    Returns:
        The ``@T.prim_func``-decorated ``main`` function.
    """
    # Imported locally so the DSL namespace ``T`` is only bound where it is used.
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        # NOTE(review): T.Kernel(1) presumably launches a single block with the
        # library's default thread count — confirm against the TileLang docs.
        with T.Kernel(1) as _:
            # Staging buffers matching the shapes of A and B respectively.
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)

            # Copy input to local
            T.copy(A, A_local)
            # Perform reduce_sum operation (reduces dim 1, leaving one value per row)
            T.reduce_sum(A_local, B_local, dim=1)
            # Copy result back
            T.copy(B_local, B)

    return main


def run_reduce_sum(M, N, dtype="float16"):
    """Compile the reduce-sum program and validate it against a reference.

    The program produced by ``reduce_sum_test(M, N, dtype)`` is compiled with
    ``out_idx=-1`` and checked through the kernel profiler's
    ``assert_allclose`` against a row-wise sum, using ``atol=rtol=1e-2``.

    Args:
        M: Size of the first dimension of the input.
        N: Size of the second (reduced) dimension of the input.
        dtype: Element type name forwarded to ``reduce_sum_test``.
    """
    # NOTE(review): out_idx=-1 presumably marks the last parameter (B) as the
    # kernel's output — confirm against the tilelang.compile documentation.
    kernel = tl.compile(reduce_sum_test(M, N, dtype), out_idx=-1)
    checker = kernel.get_profiler()

    def expected_row_sums(inp):
        # Reference result: sum over dim 1 of the input tensor.
        return inp.sum(dim=1)

    checker.assert_allclose(expected_row_sums, atol=1e-2, rtol=1e-2)


def test_reduce_sum():
    """Exercise ``run_reduce_sum`` over several shapes and element types."""
    # Shape sweep, relying on run_reduce_sum's default dtype.
    for rows, cols in ((256, 256), (512, 128), (128, 512)):
        run_reduce_sum(rows, cols)

    # Dtype sweep at a fixed square shape.
    for elem_type in ("float32", "float16"):
        run_reduce_sum(256, 256, elem_type)


def reduce_sum_test_clear(M, N, dtype="float16"):
    """Build a prim_func that runs ``T.reduce_sum`` with ``clear=False``.

    The returned function takes ``A`` of shape ``(M, N)`` and ``B`` of shape
    ``(M,)``. The output fragment is filled with 1 before the reduction over
    dim 1 is issued with ``clear=False``; the companion runner in this file
    expects the result to equal ``A.sum(dim=1) + 1``, i.e. the reduction is
    expected to add onto the pre-filled value instead of overwriting it.

    Args:
        M: Size of the first dimension of ``A`` (and the length of ``B``).
        N: Size of the second dimension of ``A``, i.e. the reduced dimension.
        dtype: Element type name used for both tensors and both fragments.

    Returns:
        The ``@T.prim_func``-decorated ``main`` function.
    """
    # Imported locally so the DSL namespace ``T`` is only bound where it is used.
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        # NOTE(review): presumably one block of 32 threads — confirm the meaning
        # of T.Kernel's positional extent and ``threads`` in the TileLang docs.
        with T.Kernel(1, threads=32) as _:
            # Staging buffers matching the shapes of A and B respectively.
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)

            # Bring the input into the local fragment.
            T.copy(A, A_local)
            # Seed the output with 1 so a non-clearing reduction is observable.
            T.fill(B_local, 1)
            # Reduce along dim 1 without clearing B_local first.
            T.reduce_sum(A_local, B_local, dim=1, clear=False)
            # Write the result out to B.
            T.copy(B_local, B)

    return main


def run_reduce_sum_clear(M, N, dtype="float16"):
    """Compile the ``clear=False`` reduce-sum kernel and compare it with torch.

    The program from ``reduce_sum_test_clear`` is compiled with the pass
    configs ``tl.disable_tma_lower`` and ``tl.disable_warp_specialized`` set
    to ``True``, and its generated source is printed. The kernel is then run
    on a random ``(M, N)`` tensor moved to CUDA, and its output must match
    ``A.sum(dim=1) + 1`` within ``atol=rtol=1e-2``. Both tensors are printed
    before the comparison.

    Args:
        M: Size of the first dimension of the input.
        N: Size of the second (reduced) dimension of the input.
        dtype: Name of a ``torch`` dtype attribute, e.g. ``"float32"``.

    Raises:
        AssertionError: If the kernel output and the reference differ beyond
            the tolerances.
    """
    import torch

    compile_flags = {
        "tl.disable_tma_lower": True,
        "tl.disable_warp_specialized": True,
    }
    kernel = tl.compile(
        reduce_sum_test_clear(M, N, dtype),
        out_idx=-1,
        pass_configs=compile_flags,
    )
    # Dump the generated kernel source to stdout for inspection.
    print(kernel.get_kernel_source())

    torch_dtype = getattr(torch, dtype)
    a_in = torch.randn((M, N), dtype=torch_dtype).cuda()

    # Reference: row sums plus the 1 that the kernel pre-fills into its output.
    expected = a_in.sum(dim=1) + 1
    actual = kernel(a_in)

    print(actual)
    print(expected)
    torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)


def test_reduce_sum_clear():
    """Exercise ``run_reduce_sum_clear`` on three float32 shapes."""
    for rows, cols in ((256, 256), (512, 128), (128, 512)):
        run_reduce_sum_clear(rows, cols, "float32")


if __name__ == "__main__":
    # When run directly as a script, hand control to tilelang's test entry point.
    tilelang.testing.main()