"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "ff1af7f2627e07b229265f8e6ad8e536bf5d2827"
test_tilelang_capture.py 969 Bytes
Newer Older
Kuris's avatar
Kuris committed
1
2
3
4
5
6
7
8
9
10
11
12
import tilelang.language as T
import tilelang.testing
import torch
import weakref
import gc


def test_tilelang_capture():
    """Check that JIT-compiling a kernel does not keep caller locals alive.

    A tensor ``a`` is created in this frame before ``get_dummy_kernel`` is
    JIT-compiled. The returned kernel is kept alive while ``a`` is deleted
    and the garbage collector runs. If the JIT machinery, or the kernel
    object it returns, still held a reference to ``a``, the weak reference
    would still resolve and the assertion would fail. NOTE(review): the
    suspected leak path is tilelang.jit capturing the calling frame's locals;
    confirm against the tilelang.jit implementation.
    """

    @tilelang.jit(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    def get_dummy_kernel():
        # Minimal kernel whose only job is to be compiled: it writes 1 into a[0].
        @T.prim_func
        def dummy_kernel(
            a: T.Tensor[(1,), T.float32],
        ):
            with T.Kernel(1) as _:
                a[0] = 1

        return dummy_kernel

    # ``a`` is never passed to the kernel; it only serves as a sentinel object
    # living in this frame while the kernel is compiled.
    a = torch.randn(1, 1024)
    a_weak = weakref.ref(a)
    # Keep the compiled kernel alive for the rest of the test, so that any
    # reference it holds to ``a`` would be observed below.
    _kernel = get_dummy_kernel()
    del a
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.empty_cache()
    a_upgrade = a_weak()
    # If this fails, inspect what still references the tensor, e.g. with
    # objgraph: objgraph.show_backrefs([a_upgrade], max_depth=5)
    assert a_upgrade is None, "A is not garbage collected"


if __name__ == "__main__":
    # Run this module's tests through tilelang's test entry point.
    tilelang.testing.main()