# test_heterograph-index.py
import unittest

import backend as F

import dgl
import pytest
from dgl import DGLError
from utils import parametrize_idtype


def create_test_heterograph(idtype):
    """Return a small heterograph fixture placed on the test context.

    Node types: 3 ``user``, 2 ``game`` and 2 ``developer`` nodes.
    Canonical edge types (the metagraph):
        ('user', 'follows', 'user')
        ('user', 'plays', 'game')
        ('user', 'wishes', 'game')
        ('developer', 'develops', 'game')

    Args:
        idtype: ID dtype the graph is created with.

    Returns:
        The heterograph, built on ``F.ctx()`` with the requested ``idtype``.
        Both properties are asserted before returning.
    """
    relations = {
        ("user", "follows", "user"): ([0, 1], [1, 2]),
        ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "wishes", "game"): ([0, 2], [1, 0]),
        ("developer", "develops", "game"): ([0, 1], [0, 1]),
    }
    graph = dgl.heterograph(relations, idtype=idtype, device=F.ctx())

    # Sanity-check that the requested dtype and device were honoured.
    assert graph.idtype == idtype
    assert graph.device == F.ctx()
    return graph


@unittest.skipIf(
    F._default_context_str == "cpu", reason="Need gpu for this test"
)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch",
    reason="Pinning graph outplace only supported for PyTorch",
)
@parametrize_idtype
def test_pin_memory(idtype):
    """Check out-of-place pinning of a graph's underlying index.

    ``g._graph.pin_memory()`` returns a new pinned index instead of pinning
    in place, so every case below reassigns ``g._graph`` before asserting
    ``is_pinned()``.

    Covered cases:
        * a CPU heterograph carrying node/edge features;
        * format clones of a pinned graph;
        * a GPU graph, which must raise ``DGLError``;
        * an empty homograph;
        * a heterograph in which one relation type has no edges.
    """
    g = create_test_heterograph(idtype)
    g.nodes["user"].data["h"] = F.ones((3, 5))
    g.nodes["game"].data["i"] = F.ones((2, 5))
    g.edges["plays"].data["e"] = F.ones((4, 4))
    g = g.to(F.cpu())
    assert not g.is_pinned()

    # Test pinning a CPU graph. pin_memory() is out-of-place: calling it
    # without keeping the result must leave the original graph unpinned.
    g._graph.pin_memory()
    assert not g.is_pinned()
    g._graph = g._graph.pin_memory()
    assert g.is_pinned()
    assert g.device == F.cpu()

    # When cloning with a new (different) format, e.g. g.formats("csc"),
    # the new graphs must not be pinned.
    assert not g.formats("csc").is_pinned()
    assert not g.formats("csr").is_pinned()
    # "coo" is the default format and thus not cloned, so it stays pinned.
    assert g.formats("coo").is_pinned()

    # Pinning a GPU graph must raise.
    g1 = g.to(F.cuda())
    with pytest.raises(DGLError):
        g1._graph.pin_memory()

    # Test pinning an empty homograph. Pass the parametrized idtype (as g3
    # does below) so the empty case is exercised for every idtype.
    g2 = dgl.graph(([], []), idtype=idtype)
    assert not g2.is_pinned()
    g2._graph = g2._graph.pin_memory()
    assert g2.is_pinned()

    # Test pinning a heterograph in which one relation type has 0 edges.
    g3 = dgl.heterograph(
        {("a", "b", "c"): ([0, 1], [1, 2]), ("c", "d", "c"): ([], [])}
    ).astype(idtype)
    g3._graph = g3._graph.pin_memory()
    assert g3.is_pinned()


if __name__ == "__main__":
    # Running this file directly delegates to pytest, which collects and
    # runs the tests defined in this module; its exit code is propagated.
    raise SystemExit(pytest.main([__file__]))