# test_pin_memory.py
import backend as F
import dgl
import pytest
import torch

@pytest.mark.skipif(F._default_context_str == 'cpu', reason="Need gpu for this test.")
def test_pin_noncontiguous():
    """Pinning a non-contiguous tensor in place must raise ``dgl.DGLError``.

    Skipped when the backend's default context is CPU.
    """
    # Transposing a freshly allocated tensor produces a non-contiguous view
    # (checked by the assert below).
    t = torch.empty([10, 100]).transpose(0, 1)

    assert not t.is_contiguous()
    assert not F.is_pinned(t)

    with pytest.raises(dgl.DGLError):
        dgl.utils.pin_memory_inplace(t)

@pytest.mark.skipif(F._default_context_str == 'cpu', reason="Need gpu for this test.")
def test_pin_view():
    """Pinning a contiguous slice (view) of a larger tensor must raise ``dgl.DGLError``.

    Skipped when the backend's default context is CPU.
    """
    t = torch.empty([100, 10])
    # A row slice is contiguous, but it is a view that shares t's storage
    # rather than owning its own.
    v = t[10:20]

    assert v.is_contiguous()
    assert not F.is_pinned(t)

    with pytest.raises(dgl.DGLError):
        dgl.utils.pin_memory_inplace(v)

@pytest.mark.skipif(F._default_context_str == 'cpu', reason='Need gpu for this test.')
def test_unpin_automatically():
    """Deleting the NDArray returned by ``pin_memory_inplace`` unpins the tensor.

    Skipped when the backend's default context is CPU.
    """
    # run a sufficient number of iterations such that the memory pool should be
    # re-used
    for _ in range(10):
        t = torch.ones(10000, 10)
        assert not F.is_pinned(t)
        nd = dgl.utils.pin_memory_inplace(t)
        assert F.is_pinned(t)
        del nd
        # dgl.ndarray will unpin its data upon destruction
        assert not F.is_pinned(t)
        del t

@pytest.mark.skipif(F._default_context_str == 'cpu', reason='Need gpu for this test.')
def test_pin_unpin_column():
    """Pinning a graph pins its node columns; replacing a column drops that pin.

    After ``pin_memory_()``, every node column is expected to be pinned by DGL
    and to hold an NDArray handle. Assigning a fresh tensor to ``ndata['x']``
    keeps the graph marked as pinned, but the new column is neither pinned by
    DGL nor backed by an NDArray handle.

    Skipped when the backend's default context is CPU.
    """
    graph = dgl.graph(([1, 2, 3, 4], [0, 0, 0, 0]))

    graph.ndata['x'] = torch.randn(graph.num_nodes())
    graph.pin_memory_()
    assert graph.is_pinned()
    assert graph.ndata['x'].is_pinned()
    for column in graph._node_frames[0].values():
        assert column.pinned_by_dgl
        assert column._data_nd is not None

    # Overwriting the column should unpin the old ndata['x'].
    graph.ndata['x'] = torch.randn(graph.num_nodes())
    assert graph.is_pinned()
    for column in graph._node_frames[0].values():
        assert not column.pinned_by_dgl
        assert column._data_nd is None
    assert not graph.ndata['x'].is_pinned()

if __name__ == "__main__":
    # Run the tests directly, without pytest. Note that the GPU-only skipif
    # markers only take effect when the tests are collected by pytest.
    test_pin_noncontiguous()
    test_pin_view()
    test_unpin_automatically()
    test_pin_unpin_column()