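"""Tests for thread-masked copy operations in tilelang.

Each kernel launches a 256-thread block and uses the thread index from
T.get_thread_binding(0) as a mask so that only a subset of threads (the lower
half, tx < 128, or the upper half, 128 <= tx < 256) loads a tile of A into
shared memory, either with an element-wise T.Parallel loop or a tile-level
T.copy. The shared tile is then copied to B and checked against the input.
"""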
import tilelang
import tilelang.language as T
import torch


# Add the @tilelang.jit decorator if you want this to return a torch-callable kernel directly.
# @tilelang.jit
def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype=T.float16):
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)

            tx = T.get_thread_binding(0)
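            # tx masks which threads participate: only the first 128 of the
            # 256 threads perform the element-wise load into shared memory.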

            if tx < 128:
                for i, k in T.Parallel(block_M, block_N):
                    A_shared[i, k] = A[by * block_M + i, bx * block_N + k]

            T.copy(A_shared, B[by * block_M, bx * block_N])

    return main


def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
    program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype)
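    # Compile for CUDA; warp specialization and TMA lowering are disabled for
    # these masked-copy kernels.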
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs={
            "tl.disable_warp_specialized": True,
            "tl.disable_tma_lower": True,
        },
    )
    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
    b = kernel(a)
    torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)


def test_tilelang_copy_mask_parallel():
    run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128)


# Add the @tilelang.jit decorator if you want this to return a torch-callable kernel directly.
# @tilelang.jit
def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype=T.float16):
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)

            tx = T.get_thread_binding(0)
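            # tx masks which threads participate: only the first 128 of the
            # 256 threads issue the tile-level T.copy into shared memory.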

            if tx < 128:
                T.copy(A[by * block_M, bx * block_N], A_shared)

            T.copy(A_shared, B[by * block_M, bx * block_N])

    return main


def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
    program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs={
            "tl.disable_warp_specialized": True,
            "tl.disable_tma_lower": True,
        },
    )
    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
    b = kernel(a)
    torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)


def test_tilelang_copy_mask_copy():
    run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128)


# Add the @tilelang.jit decorator if you want this to return a torch-callable kernel directly.
# @tilelang.jit
def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype=T.float16):
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)

            tx = T.get_thread_binding(0)
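            # tx masks which threads participate: only threads 128-255
            # perform the element-wise load into shared memory.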

            if tx >= 128 and tx < 256:
                for i, k in T.Parallel(block_M, block_N):
                    A_shared[i, k] = A[by * block_M + i, bx * block_N + k]

            T.copy(A_shared, B[by * block_M, bx * block_N])

    return main


def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
    program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs={
            "tl.disable_warp_specialized": True,
            "tl.disable_tma_lower": True,
        },
    )
    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
    b = kernel(a)
    torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)


def test_tilelang_copy_mask_parallel_range():
    run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128)


# Add the @tilelang.jit decorator if you want this to return a torch-callable kernel directly.
# @tilelang.jit
def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype=T.float16):
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)

            tx = T.get_thread_binding(0)
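            # tx masks which threads participate: only threads 128-255 issue
            # the tile-level T.copy into shared memory.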

            if tx >= 128 and tx < 256:
                T.copy(A[by * block_M, bx * block_N], A_shared)

            T.copy(A_shared, B[by * block_M, bx * block_N])

    return main


def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
    program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs={
            "tl.disable_warp_specialized": True,
            "tl.disable_tma_lower": True,
        },
    )
    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
    b = kernel(a)
    torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)


def test_tilelang_copy_mask_copy_range():
    run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128)


if __name__ == "__main__":
    test_tilelang_copy_mask_copy_range()