import tilelang
import tilelang.language as T
import tilelang.testing
import torch


# Add the @tilelang.jit decorator if you want this builder to return a torch function.
# @tilelang.jit
def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
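    """Build a copy kernel that writes A shifted down by 10 rows into B.

    Reads above row 0 of A fall out of bounds; T.annotate_padding fills
    them with pad_value instead of leaving them undefined.
    """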

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)

            T.annotate_padding({A_shared: pad_value})
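            # annotate_padding makes out-of-bounds reads into A_shared return
            # pad_value. The source row by * block_M + i - 10 is negative for
            # the first 10 global rows, so those loads hit the padding path
            # instead of reading undefined memory.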
            for i, j in T.Parallel(block_M, block_N):
                A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]

            # Store the staged tile (including padded rows) back to B.
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = A_shared[i, j]

    return main


def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16", pad_value=0):
    program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value)
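    # Compile for CUDA; out_idx=[1] makes the kernel allocate and return B.
    # Warp specialization and TMA lowering are disabled for this test.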
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs={
            "tl.disable_warp_specialized": True,
            "tl.disable_tma_lower": True
        })
    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
    b = kernel(a)
    # Reference: B equals A shifted down by 10 rows, with the first 10 rows
    # filled with pad_value (mirroring the kernel's padded reads).
    ref_b = torch.full_like(a, pad_value)
    ref_b[10:, :] = a[:-10, :]
    torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2)


def test_tilelang_copy():
    run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, pad_value=10)


if __name__ == "__main__":
    tilelang.testing.main()