"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "f689bdf2cd7c1d136d2010f944d039110c8b4609"
# File: test_tilelang_issue_830.py
# ruff: noqa

import torch
import tilelang
import tilelang.testing
import tilelang.language as T


@tilelang.jit
def _empty_kernel():
    # Factory for a kernel with an empty body: one block, 32 threads, no
    # statements. The tests call the factory and then invoke its result.

    @T.prim_func
    def empty_kernel():
        # The kernel scope is opened and immediately left without emitting
        # any work; `thread_idx` is bound but never used.
        with T.Kernel(1, threads=32) as thread_idx:
            pass

    return empty_kernel


@tilelang.testing.requires_cuda
def test_empty_kernel_lowering():
    """Build the empty kernel and launch it once; passing means no exception.

    The test has no explicit assertions: it only checks that a kernel with
    an empty body can be produced by ``_empty_kernel`` and invoked.
    """
    # Ensure a valid CUDA runtime context is current on this thread for the
    # target device before using driver API calls. Without this, calls like
    # cuModuleLoadData can fail with CUDA_ERROR_INVALID_CONTEXT, especially
    # for kernels that don't touch any device memory or streams beforehand
    # (e.g., "empty" kernels) and therefore haven't triggered context
    # creation implicitly.
    torch.cuda.set_device(0)

    kernel = _empty_kernel()
    kernel()


@tilelang.jit
def _empty_with_dead_code_kernel():
    # Factory for a kernel whose only statement is a load that is never used.
    # `num_tokens` is a dynamic (symbolic) extent shared by the tensor shape
    # and the kernel grid size below.
    num_tokens = T.dynamic("num_tokens")

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), "float32"]):
        with T.Kernel(num_tokens, threads=32) as pid:
            # Deliberately dead: `y` is assigned and never read, so the
            # kernel has no observable effect.
            y = x[pid]

    return buggy_kernel


@tilelang.testing.requires_cuda
def test_empty_with_dead_code_kernel():
    """Launch the dead-code kernel on a 128-element float32 CUDA tensor.

    Passing means that building and calling the kernel raised no exception.
    """
    # Build the kernel first, then allocate the input, in that order.
    dead_code_kernel = _empty_with_dead_code_kernel()
    tokens = torch.randn(
        (128,),
        dtype=torch.float32,
        device="cuda",
    )
    dead_code_kernel(tokens)


@tilelang.jit
def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False):
    # Factory returning one of two otherwise identical kernels that differ
    # only in how the `T.Kernel` block index is bound in the `with` statement:
    #   use_tuple_binding=True  -> `as (pid,)` (one-element tuple target)
    #   use_tuple_binding=False -> `as pid`    (plain name target, default)

    @T.prim_func
    def kernel_with_tuple_kernel_binding():
        with T.Kernel(1, threads=32) as (pid,):
            # NOTE(review): this is the Python builtin `print`, not a T.*
            # call; how the frontend treats it is not visible here.
            print(pid)
            pass

    @T.prim_func
    def kernel_with_scalar_kernel_binding():
        with T.Kernel(1, threads=32) as pid:
            # Same body as the tuple variant; only the binding form differs.
            print(pid)
            pass

    return kernel_with_tuple_kernel_binding if use_tuple_binding else kernel_with_scalar_kernel_binding


@tilelang.testing.requires_cuda
def test_empty_kernel_with_binding_variants():
    """Build and launch both binding variants; passing means no exception.

    The default (scalar ``as pid``) variant runs first, followed by the
    tuple (``as (pid,)``) variant.
    """
    # Select device 0 before building or launching anything, as the sibling
    # empty-kernel test does for kernels that take no tensor arguments.
    torch.cuda.set_device(0)

    kernel = _empty_kernel_with_binding_variants()
    kernel()

    tuple_kernel = _empty_kernel_with_binding_variants(use_tuple_binding=True)
    tuple_kernel()


# When executed directly as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()