"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "bb3974807e3d9ab847de71a58a8f2985810a8166"
test_tilelang_issue_1008.py 1.41 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import tilelang
import tilelang.testing
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        # Both flags are set to True; going by their names they turn off warp
        # specialization and TMA lowering (confirm against the tilelang
        # PassConfigKey documentation).
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },)
def _fill_with_static_region_kernel():
    """Build a kernel that calls ``T.fill`` on a constant slice of its input.

    The inner prim_func takes a single 1-D ``int64`` tensor whose length is
    the symbolic ``num_tokens`` and fills ``x[0:128]`` with 0.

    NOTE(review): the inner function is named ``buggy_kernel``, which suggests
    this is a bug reproducer; keep the DSL body exactly as written unless the
    original report is known.

    Returns:
        The ``buggy_kernel`` prim_func defined below. What a caller actually
        receives is determined by the ``tilelang.jit`` decorator.
    """
    # The same symbolic size is used for the tensor shape and the T.Kernel
    # extent.
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
        with T.Kernel(num_tokens, threads=128) as _:
            # The slice bounds here are literal constants (0 and 128).
            T.fill(x[0:128], 0)

    return buggy_kernel


@tilelang.jit(
    pass_configs={
        # Both flags are set to True; going by their names they turn off warp
        # specialization and TMA lowering (confirm against the tilelang
        # PassConfigKey documentation).
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },)
def _fill_with_dynamic_region_kernel():
    """Build a kernel that calls ``T.fill`` on a slice with variable bounds.

    The inner prim_func takes a single 1-D ``int64`` tensor whose length is
    the symbolic ``num_tokens`` and fills ``x[a:b]`` with 0, where ``a`` and
    ``b`` come from ``T.alloc_var('int')``.

    NOTE(review): the inner function is named ``buggy_kernel``, which suggests
    this is a bug reproducer; keep the DSL body exactly as written unless the
    original report is known.

    Returns:
        The ``buggy_kernel`` prim_func defined below. What a caller actually
        receives is determined by the ``tilelang.jit`` decorator.
    """
    # The same symbolic size is used for the tensor shape and the T.Kernel
    # extent.
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
        with T.Kernel(num_tokens, threads=128) as _:
            # NOTE(review): a and b are never assigned a value in this body, so
            # the slice bounds are whatever T.alloc_var yields by default.
            # Presumably only the non-constant slice form matters here --
            # confirm before relying on the filled range.
            a, b = T.alloc_var('int'), T.alloc_var('int')
            T.fill(x[a:b], 0)

    return buggy_kernel


def test_fill_with_static_region_kernel():
    """Build the constant-slice fill kernel and call it once on a CUDA tensor.

    The test makes no assertion on the tensor contents; it passes when
    building the kernel and calling it raise no exception.
    """
    fill_kernel = _fill_with_static_region_kernel()
    # 256 int64 zeros on the CUDA device are the kernel's only argument.
    token_buffer = torch.zeros(256, dtype=torch.int64, device='cuda')
    fill_kernel(token_buffer)


def test_fill_with_dynamic_region_kernel():
    """Build the variable-slice fill kernel and call it once on a CUDA tensor.

    The test makes no assertion on the tensor contents; it passes when
    building the kernel and calling it raise no exception.
    """
    fill_kernel = _fill_with_dynamic_region_kernel()
    # 256 int64 zeros on the CUDA device are the kernel's only argument.
    token_buffer = torch.zeros(256, dtype=torch.int64, device='cuda')
    fill_kernel(token_buffer)


# When this file is executed directly, hand control to tilelang's test entry
# point.
if __name__ == '__main__':
    tilelang.testing.main()