# test_tilelang_issue_1008.py
# Tests for tilelang issue 1008: T.fill over a statically-bounded region and
# over a region whose bounds are runtime variables.
import torch
import tilelang
import tilelang.testing
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def _fill_with_static_region_kernel():
    """Build a JIT kernel that zero-fills a statically-bounded slice of ``x``.

    The slice ``x[0:128]`` has compile-time constant bounds. The kernel is
    compiled with warp specialization and TMA lowering disabled via
    ``pass_configs``.

    Returns:
        The ``T.prim_func`` ``buggy_kernel``, wrapped by ``tilelang.jit``.
    """
    # Symbolic length: the 1-D int64 tensor's size is bound at call time.
    num_tokens = T.symbolic("num_tokens")

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]):  # noqa: F821
        # Launch extent is num_tokens with 128 threads. Every launch unit
        # writes the same constant slice.
        with T.Kernel(num_tokens, threads=128) as _:
            T.fill(x[0:128], 0)

    return buggy_kernel


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def _fill_with_dynamic_region_kernel():
    """Build a JIT kernel that zero-fills ``x[a:b]`` with runtime-variable bounds.

    Unlike the static variant, the slice bounds are ``T.alloc_var`` values
    rather than compile-time constants. The kernel is compiled with warp
    specialization and TMA lowering disabled via ``pass_configs``.

    Returns:
        The ``T.prim_func`` ``buggy_kernel``, wrapped by ``tilelang.jit``.
    """
    # Symbolic length: the 1-D int64 tensor's size is bound at call time.
    num_tokens = T.symbolic("num_tokens")

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]):  # noqa: F821
        with T.Kernel(num_tokens, threads=128) as _:
            # NOTE(review): a and b are never assigned. This kernel appears to
            # exercise compiling/lowering a fill over a runtime-bounded region
            # rather than a specific fill result. Confirm that this is intended.
            a, b = T.alloc_var("int"), T.alloc_var("int")
            T.fill(x[a:b], 0)

    return buggy_kernel


def test_fill_with_static_region_kernel():
    """Compile the static-region fill kernel and run it on a 256-element tensor.

    Passes if compilation and the launch do not raise. Requires a CUDA device,
    because the input tensor is allocated with ``device="cuda"``.
    """
    kernel = _fill_with_static_region_kernel()
    x = torch.zeros((256,), dtype=torch.int64, device="cuda")
    kernel(x)


def test_fill_with_dynamic_region_kernel():
    """Compile the dynamic-region fill kernel and run it on a 256-element tensor.

    Passes if compilation and the launch do not raise. Requires a CUDA device,
    because the input tensor is allocated with ``device="cuda"``.
    """
    kernel = _fill_with_dynamic_region_kernel()
    x = torch.zeros((256,), dtype=torch.int64, device="cuda")
    kernel(x)


# Allow running this file directly; delegates to tilelang's test runner.
if __name__ == "__main__":
    tilelang.testing.main()