# test_basics.py: basic tests for the DGL graph index.
from dgl import DGLError
from dgl.utils import toindex
from dgl.graph_index import create_graph_index
import networkx as nx

def test_edge_id():
    """Check edge-id queries on a graph index.

    Covers single-edge and multi-edge lookup via ``edge_id``, batched lookup
    via ``edge_ids`` (including one-element source/destination broadcasting),
    reverse lookup via ``find_edges``, and the state of the index after
    ``clear()``.
    """
    gi = create_graph_index(multigraph=False)
    assert not gi.is_multigraph()

    gi = create_graph_index(multigraph=True)

    gi.add_nodes(4)
    gi.add_edge(0, 1)
    eid = gi.edge_id(0, 1).tonumpy()
    assert len(eid) == 1
    assert eid[0] == 0
    assert gi.is_multigraph()

    # multiedges
    gi.add_edge(0, 1)
    eid = gi.edge_id(0, 1).tonumpy()
    assert len(eid) == 2
    assert eid[0] == 0
    assert eid[1] == 1

    gi.add_edges(toindex([0, 1, 1, 2]), toindex([2, 2, 2, 3]))
    src, dst, eid = gi.edge_ids(toindex([0, 0, 2, 1]), toindex([2, 1, 3, 2]))
    eid_answer = [2, 0, 1, 5, 3, 4]
    assert len(eid) == 6
    assert all(e == ea for e, ea in zip(eid, eid_answer))

    # find edges
    src, dst, eid = gi.find_edges(toindex([1, 3, 5]))
    assert len(src) == len(dst) == len(eid) == 3
    assert src[0] == 0 and src[1] == 1 and src[2] == 2
    assert dst[0] == 1 and dst[1] == 2 and dst[2] == 3
    assert eid[0] == 1 and eid[1] == 3 and eid[2] == 5

    # source broadcasting
    src, dst, eid = gi.edge_ids(toindex([0]), toindex([1, 2]))
    eid_answer = [0, 1, 2]
    assert len(eid) == 3
    assert all(e == ea for e, ea in zip(eid, eid_answer))

    # destination broadcasting
    src, dst, eid = gi.edge_ids(toindex([1, 0]), toindex([2]))
    eid_answer = [3, 4, 2]
    assert len(eid) == 3
    assert all(e == ea for e, ea in zip(eid, eid_answer))

    gi.clear()
    # the following assumes that grabbing nonexistent edge will throw an error
    try:
        gi.edge_id(0, 1)
    except DGLError:
        pass
    else:
        # Only reached when no exception was raised. Any exception other than
        # DGLError propagates unchanged instead of being masked by a failure
        # in cleanup code.
        raise AssertionError(
            'edge_id(0, 1) on a cleared graph index did not raise DGLError')

    # The index must be usable again after clear(), with edge ids restarting
    # from 0.
    gi.add_nodes(4)
    gi.add_edge(0, 1)
    eid = gi.edge_id(0, 1).tonumpy()
    assert len(eid) == 1
    assert eid[0] == 0

def test_nx():
    """Check conversion between graph indices and networkx graphs.

    Exercises ``to_networkx`` on a multigraph index and
    ``create_graph_index`` on ``DiGraph`` / ``MultiDiGraph`` inputs,
    including edge-free graphs and edges carrying an ``id`` attribute.
    """
    gi = create_graph_index(multigraph=True)

    gi.add_nodes(2)
    gi.add_edge(0, 1)
    nxg = gi.to_networkx()
    assert len(nxg.nodes) == 2
    # number_of_edges(u, v) counts the edges from u to v. The earlier
    # nxg.edges(0, 1) form passed 1 as the ``data`` argument and therefore
    # counted every out-edge of node 0 instead.
    assert nxg.number_of_edges(0, 1) == 1
    gi.add_edge(0, 1)
    nxg = gi.to_networkx()
    assert nxg.number_of_edges(0, 1) == 2

    nxg = nx.DiGraph()
    nxg.add_edge(0, 1)
    gi = create_graph_index(nxg)
    assert not gi.is_multigraph()
    assert gi.number_of_nodes() == 2
    assert gi.number_of_edges() == 1
    assert gi.edge_id(0, 1)[0] == 0

    nxg = nx.MultiDiGraph()
    nxg.add_edge(0, 1)
    nxg.add_edge(0, 1)
    gi = create_graph_index(nxg, multigraph=True)
    assert gi.is_multigraph()
    assert gi.number_of_nodes() == 2
    assert gi.number_of_edges() == 2
    assert 0 in gi.edge_id(0, 1)
    assert 1 in gi.edge_id(0, 1)

    # networkx graph with nodes but no edges
    nxg = nx.DiGraph()
    nxg.add_nodes_from(range(3))
    gi = create_graph_index(nxg)
    assert gi.number_of_nodes() == 3
    assert gi.number_of_edges() == 0

    # graph index with nodes but no edges
    gi = create_graph_index()
    gi.add_nodes(3)
    nxg = gi.to_networkx()
    assert len(nxg.nodes) == 3
    assert len(nxg.edges) == 0

    # networkx edges that carry an explicit ``id`` attribute
    nxg = nx.DiGraph()
    nxg.add_edge(0, 1, id=0)
    nxg.add_edge(1, 2, id=1)
    gi = create_graph_index(nxg)
    assert 0 in gi.edge_id(0, 1)
    assert 1 in gi.edge_id(1, 2)
    assert gi.number_of_edges() == 2
    assert gi.number_of_nodes() == 3

def test_predsucc():
    """Check ``predecessors`` and ``successors`` of node 0 on a multigraph.

    The edge list contains a duplicated (0, 1) edge and a duplicated
    self-loop on node 0; each query is expected to yield three entries.
    """
    graph = create_graph_index(multigraph=True)
    graph.add_nodes(4)

    edge_list = [(0, 1), (0, 1), (0, 2), (2, 0), (3, 0), (0, 0), (0, 0)]
    for u, v in edge_list:
        graph.add_edge(u, v)

    incoming = graph.predecessors(0)
    assert len(incoming) == 3
    for node in (2, 3, 0):
        assert node in incoming

    outgoing = graph.successors(0)
    assert len(outgoing) == 3
    for node in (1, 2, 0):
        assert node in outgoing

def test_create_from_elist():
    """Build a graph index from an edge list and check edge-id order.

    The i-th (src, dst) pair in the list is expected to get edge id ``i``.
    """
    elist = [(2, 1), (1, 0), (2, 0), (3, 0), (0, 2)]
    g = create_graph_index(elist)
    for i, (u, v) in enumerate(elist):
        assert g.edge_id(u, v)[0] == i
    # immutable graph
    # TODO: disabled due to torch support
    #g = create_graph_index(elist, readonly=True)
    #for i, (u, v) in enumerate(elist):
    #    print(u, v, g.edge_id(u, v)[0])
    #    assert g.edge_id(u, v)[0] == i

# Run every test in this module when the file is executed as a script.
if __name__ == '__main__':
    test_edge_id()
    test_nx()
    test_predsucc()
    test_create_from_elist()