test_tilelang_capture.py 944 Bytes
Newer Older
Kuris's avatar
Kuris committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import tilelang.language as T
import tilelang.testing
import torch
import weakref
import gc


def test_tilelang_capture():
    """Assert that building a ``tilelang.jit`` kernel does not pin an unrelated tensor.

    A host tensor is created in this frame before the kernel is built, then
    deleted.  ``_kernel`` stays bound until the assertion, so anything the
    compiled kernel (or the JIT machinery) still references is reachable when
    the weak reference is checked.  The tensor must nevertheless have been
    collected.
    """
    jit_options = {
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    }

    @tilelang.jit(pass_configs=jit_options)
    def get_dummy_kernel():

        # Minimal single-block kernel; its body is parsed by TileLang, so it is
        # kept in the exact script form the DSL expects.
        @T.prim_func
        def dummy_kernel(a: T.Tensor[(1,), T.float32],):
            with T.Kernel(1) as _:
                a[0] = 1

        return dummy_kernel

    # The tensor must be alive while the kernel is being built.
    probe = torch.randn(1, 1024)
    probe_ref = weakref.ref(probe)
    _kernel = get_dummy_kernel()  # intentionally held until the check below
    del probe

    # Force a full collection, bracketed by CUDA cache flushes, before checking.
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.empty_cache()

    survivor = probe_ref()
    assert survivor is None, "A is not garbage collected"

    # NOTE: when this fails, ``objgraph.show_backrefs([survivor], max_depth=5)``
    # (third-party ``objgraph`` package) helps locate the lingering reference.


if __name__ == "__main__":
    # Delegate to tilelang.testing.main() when this file is executed directly.
    tilelang.testing.main()