test_convert.py 3.75 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import unittest

import backend as F
import dgl

from utils import parametrize_idtype


def get_nodes_by_ntype(nodes, ntype):
    """Return the subset of *nodes* whose ``"ntype"`` attribute equals *ntype*.

    Args:
        nodes: Mapping of NetworkX node ID to its attribute dict, as built by
            ``dict(nxg.nodes(data=True))``.
        ntype: Node type name to keep.

    Returns:
        A new dict holding only the matching entries, in the original
        iteration order.
    """
    return {k: v for k, v in nodes.items() if v["ntype"] == ntype}


def edge_attrs(edge):
    """Return the attribute dict of an edge produced by ``nxg.edges(data=True)``.

    NetworkX yields such edges as ``(src, dst, attrs)`` tuples.
    """
    attrs_position = 2  # index of the data dict inside the edge tuple
    return edge[attrs_position]


def get_edges_by_etype(edges, etype):
    """Return, as a list, the edges whose ``"etype"`` attribute equals *etype*.

    *edges* is an iterable of NetworkX ``(src, dst, attrs)`` tuples; the
    relative order of the matching edges is preserved.
    """
    return list(filter(lambda edge: edge_attrs(edge)["etype"] == etype, edges))


def check_attrs_for_nodes(nodes, attrs):
    """Return True iff every node's attribute keys are exactly the set *attrs*.

    *nodes* maps node IDs to attribute dicts. An empty mapping vacuously
    passes.
    """
    for node_data in nodes.values():
        if node_data.keys() != attrs:
            return False
    return True


def check_attr_values_for_nodes(nodes, attr_name, values):
    """Compare the ``attr_name`` feature of *nodes* against *values*.

    The per-node features are stacked along a new leading dimension, in the
    iteration order of *nodes*, and compared with ``F.allclose``.
    """
    per_node = [node_data[attr_name] for node_data in nodes.values()]
    return F.allclose(F.stack(per_node, 0), values)


def check_attrs_for_edges(edges, attrs):
    """Return True iff every edge's attribute keys are exactly the set *attrs*.

    *edges* is an iterable of NetworkX ``(src, dst, attrs)`` tuples. An
    empty iterable vacuously passes.
    """
    return not any(edge_attrs(edge).keys() != attrs for edge in edges)


def check_attr_values_for_edges(edges, attr_name, values):
    """Compare the ``attr_name`` feature of *edges* against *values*.

    Args:
        edges: Iterable of NetworkX ``(src, dst, attrs)`` tuples, all of one
            edge type, each carrying an ``"id"`` attribute.
        attr_name: Name of the edge feature to compare.
        values: Expected features, one row per edge, ordered by edge ID.

    Returns:
        Result of ``F.allclose`` on the stacked features and *values*.
    """
    # NetworkX iterates MultiDiGraph edges grouped by source node, not by
    # insertion/edge-ID order, so an unsorted list only lines up with
    # *values* when the sources happen to be ascending. Sort by the "id"
    # attribute (assumed to be the DGL edge ID -- the tests assert it is
    # present on every edge) so the rows match regardless of graph shape.
    ordered = sorted(edges, key=lambda edge: edge_attrs(edge)["id"])
    return F.allclose(
        F.stack([edge_attrs(e)[attr_name] for e in ordered], 0), values
    )


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="`to_networkx` does not support graphs on GPU",
)
@parametrize_idtype
def test_to_networkx(idtype):
    """Convert a small heterograph with ``dgl.to_networkx`` and inspect it.

    Checks that every node and edge survives the conversion and that each
    node/edge records its type under ``ntype``/``etype``. Also checks that
    the requested features (``n`` on nodes, ``e`` on edges) are copied
    unchanged for the types that carry them and are absent for the types
    that do not.
    """
    # TODO: adapt and move code from the _test_nx_conversion function in
    # tests/python/common/function/test_basics.py to here
    # (pending resolution of https://github.com/dmlc/dgl/issues/5735).
    ctx = F.ctx()
    follows_user = ("user", "follows", "user")
    follows_topic = ("user", "follows", "topic")
    plays_game = ("user", "plays", "game")
    g = dgl.heterograph(
        {
            follows_user: ([0, 1], [1, 2]),
            follows_topic: ([1, 1], [1, 2]),
            plays_game: ([0, 3], [3, 4]),
        },
        idtype=idtype,
        device=ctx,
    )

    # Random features; the draw order is kept fixed.
    game_feat = F.randn((5, 3))
    user_feat = F.randn((4, 2))
    follows_user_feat = F.randn((2, 3))
    plays_feat = F.randn((2, 2))

    g.nodes["game"].data["n"] = F.copy_to(game_feat, ctx=ctx)
    g.nodes["user"].data["n"] = F.copy_to(user_feat, ctx=ctx)
    g.edges[follows_user].data["e"] = F.copy_to(follows_user_feat, ctx=ctx)
    g.edges["plays"].data["e"] = F.copy_to(plays_feat, ctx=ctx)

    nx_graph = dgl.to_networkx(g, node_attrs=["n"], edge_attrs=["e"])

    # --- Nodes ---
    nodes = dict(nx_graph.nodes(data=True))
    assert len(nodes) == g.num_nodes()
    assert {data["ntype"] for data in nodes.values()} == set(g.ntypes)

    nodes_by_ntype = {}
    for ntype in g.ntypes:
        nodes_by_ntype[ntype] = get_nodes_by_ntype(nodes, ntype)
        assert g.num_nodes(ntype) == len(nodes_by_ntype[ntype])

    for ntype, feat in (("game", game_feat), ("user", user_feat)):
        assert check_attrs_for_nodes(nodes_by_ntype[ntype], {"ntype", "n"})
        assert check_attr_values_for_nodes(nodes_by_ntype[ntype], "n", feat)
    # "topic" nodes were given no features, so only the type is recorded.
    assert check_attrs_for_nodes(nodes_by_ntype["topic"], {"ntype"})

    # --- Edges ---
    edges = list(nx_graph.edges(data=True))
    assert len(edges) == g.num_edges()
    recorded_etypes = {edge_attrs(edge)["etype"] for edge in edges}
    assert recorded_etypes == set(g.canonical_etypes)

    edges_by_etype = {}
    for etype in g.canonical_etypes:
        edges_by_etype[etype] = get_edges_by_etype(edges, etype)
        assert g.num_edges(etype) == len(edges_by_etype[etype])

    featured_etypes = (
        (follows_user, follows_user_feat),
        (plays_game, plays_feat),
    )
    for etype, feat in featured_etypes:
        assert check_attrs_for_edges(
            edges_by_etype[etype], {"id", "etype", "e"}
        )
        assert check_attr_values_for_edges(edges_by_etype[etype], "e", feat)
    # ("user", "follows", "topic") edges were given no features.
    assert check_attrs_for_edges(edges_by_etype[follows_topic], {"id", "etype"})