test_tilelang_laguange_chain_equal.py 1.25 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tilelang
import tilelang.testing
import tilelang.language as T
import torch


# Compile with TMA lowering and warp specialization disabled.
# NOTE(review): presumably this keeps the lowered code simple for a
# front-end (parser) test — confirm the reason if these flags are changed.
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },)
def chain_equal(N, block_size, dtype="float32"):
    """Build a kernel that sets ``A``, ``B`` and ``C`` to 1 via chained assignment.

    The kernel launches ``ceildiv(N, block_size)`` blocks of ``block_size``
    threads. Block ``bx`` handles indices ``bx * block_size + lane`` for
    ``lane`` in ``[0, block_size)``. At each index it executes
    ``A[idx] = B[idx] = C[idx] = 1``, which is the construct this file tests.

    Args:
        N: Length of each of the three 1-D buffers.
        block_size: Indices handled per block; also the thread count per block.
        dtype: Element dtype string shared by ``A``, ``B`` and ``C``.

    Returns:
        The ``T.prim_func`` ``main``, wrapped by ``tilelang.jit``. Callers
        (see ``run_chain_equal``) invoke the result directly on tensors.

    NOTE(review): there is no ``idx < N`` guard. When ``N`` is not a multiple
    of ``block_size``, the last block computes indices ``>= N``.
    """

    @T.prim_func
    def main(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N,), dtype),
            C: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as bx:
            for lane in T.Parallel(block_size):
                # Global element index for this block/lane pair.
                idx = bx * block_size + lane
                # Chained assignment: one value stored into all three buffers.
                # The integer literal 1 is presumably cast to `dtype` — confirm.
                A[idx] = B[idx] = C[idx] = 1

    return main


def run_chain_equal(N=128, block_size=64, dtype="float32"):
    """Compile ``chain_equal`` and check that ``A``, ``B`` and ``C`` all become 1.

    Args:
        N: Number of elements in each buffer. Must be a positive multiple of
            ``block_size``, because the kernel has no bounds guard.
        block_size: Elements (and threads) per kernel block. Must be positive.
        dtype: Element type name used for both the kernel and the torch
            buffers. It must name a ``torch`` dtype attribute, such as
            ``"float32"`` or ``"float16"``.

    Raises:
        ValueError: If ``N``/``block_size`` are not positive, if ``N`` is not a
            multiple of ``block_size``, or if ``dtype`` does not name a torch
            dtype. All checks run before any compilation or CUDA work.
    """
    if block_size <= 0 or N <= 0 or N % block_size != 0:
        # The kernel writes bx * block_size + lane with no idx < N guard, so a
        # partial last block would index past the end of the buffers.
        raise ValueError(
            f"N ({N}) must be a positive multiple of block_size ({block_size})")
    # Allocate tensors in the same dtype the kernel was compiled for. The
    # original code hard-coded float32 here and ignored `dtype`.
    torch_dtype = getattr(torch, dtype, None)
    if not isinstance(torch_dtype, torch.dtype):
        raise ValueError(f"unsupported dtype {dtype!r}")

    kernel = chain_equal(N, block_size, dtype)
    A, B, C = (torch.zeros((N,), dtype=torch_dtype, device="cuda") for _ in range(3))
    kernel(A, B, C)

    ref = torch.ones_like(A)
    for out in (A, B, C):
        torch.testing.assert_close(out, ref)


@tilelang.testing.requires_cuda
def test_chain_equal():
    """Chained assignment in a kernel stores the value into every target buffer."""
    # Spell out the defaults so the tested configuration is visible here.
    run_chain_equal(N=128, block_size=64, dtype="float32")


if __name__ == "__main__":
    # When executed as a script, hand control to tilelang's testing entry point.
    run_tests = tilelang.testing.main
    run_tests()