test_basics.py 3.16 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from dgl import DGLError
from dgl.utils import toindex
from dgl.graph_index import create_graph_index
import networkx as nx

def test_edge_id():
    """Exercise edge-id queries on a graph index.

    Covers, in order: the ``multigraph`` flag, single ``edge_id`` lookups,
    parallel edges, batched ``edge_ids`` lookups, ``find_edges``, one-element
    source/destination broadcasting in ``edge_ids``, the ``DGLError`` expected
    when querying an edge after ``clear()``, and re-populating the cleared
    index.
    """
    gi = create_graph_index(multigraph=False)
    assert not gi.is_multigraph()

    gi = create_graph_index(multigraph=True)

    gi.add_nodes(4)
    gi.add_edge(0, 1)
    eid = gi.edge_id(0, 1).tolist()
    assert len(eid) == 1
    assert eid[0] == 0
    assert gi.is_multigraph()

    # multiedges: a second 0 -> 1 edge must be reported alongside the first
    gi.add_edge(0, 1)
    eid = gi.edge_id(0, 1).tolist()
    assert len(eid) == 2
    assert eid[0] == 0
    assert eid[1] == 1

    # Batched lookup: four (src, dst) queries, two of which hit parallel
    # edges, so six edge ids are expected in total.
    gi.add_edges(toindex([0, 1, 1, 2]), toindex([2, 2, 2, 3]))
    src, dst, eid = gi.edge_ids(toindex([0, 0, 2, 1]), toindex([2, 1, 3, 2]))
    eid_answer = [2, 0, 1, 5, 3, 4]
    assert len(eid) == 6
    assert all(e == ea for e, ea in zip(eid, eid_answer))

    # find edges
    src, dst, eid = gi.find_edges(toindex([1, 3, 5]))
    assert len(src) == len(dst) == len(eid) == 3
    assert src[0] == 0 and src[1] == 1 and src[2] == 2
    assert dst[0] == 1 and dst[1] == 2 and dst[2] == 3
    assert eid[0] == 1 and eid[1] == 3 and eid[2] == 5

    # source broadcasting: one source paired with every destination
    src, dst, eid = gi.edge_ids(toindex([0]), toindex([1, 2]))
    eid_answer = [0, 1, 2]
    assert len(eid) == 3
    assert all(e == ea for e, ea in zip(eid, eid_answer))

    # destination broadcasting: every source paired with one destination
    src, dst, eid = gi.edge_ids(toindex([1, 0]), toindex([2]))
    eid_answer = [3, 4, 2]
    assert len(eid) == 3
    assert all(e == ea for e, ea in zip(eid, eid_answer))

    gi.clear()
    # The following assumes that grabbing a nonexistent edge throws DGLError.
    # try/except/else is used so that any *other* exception type propagates
    # unchanged instead of being masked by a follow-up error, and so that a
    # silent success fails the test with an explicit message.
    try:
        gi.edge_id(0, 1)
    except DGLError:
        pass
    else:
        raise AssertionError(
            "edge_id() on a nonexistent edge did not raise DGLError")

    # The cleared index must be usable again, with edge ids restarting at 0.
    gi.add_nodes(4)
    gi.add_edge(0, 1)
    eid = gi.edge_id(0, 1).tolist()
    assert len(eid) == 1
    assert eid[0] == 0

def test_nx():
    """Check conversion between a graph index and networkx graphs.

    First converts a multigraph index to networkx and verifies node and
    parallel-edge counts, then builds graph indices from a ``DiGraph`` and a
    ``MultiDiGraph`` and verifies the multigraph flag, sizes and edge ids.
    """
    gi = create_graph_index(multigraph=True)

    gi.add_nodes(2)
    gi.add_edge(0, 1)
    nxg = gi.to_networkx()
    assert len(nxg.nodes) == 2
    # number_of_edges(u, v) counts exactly the 0 -> 1 edges. The previous
    # form, nxg.edges(0, 1), passed 1 as the `data` flag and therefore
    # counted every out-edge of node 0, matching only by coincidence.
    assert nxg.number_of_edges(0, 1) == 1
    gi.add_edge(0, 1)
    nxg = gi.to_networkx()
    assert nxg.number_of_edges(0, 1) == 2

    # networkx simple directed graph -> graph index
    nxg = nx.DiGraph()
    nxg.add_edge(0, 1)
    gi = create_graph_index(nxg)
    assert not gi.is_multigraph()
    assert gi.number_of_nodes() == 2
    assert gi.number_of_edges() == 1
    assert gi.edge_id(0, 1)[0] == 0

    # networkx multigraph -> graph index; the second positional argument is
    # the multigraph flag.
    nxg = nx.MultiDiGraph()
    nxg.add_edge(0, 1)
    nxg.add_edge(0, 1)
    gi = create_graph_index(nxg, True)
    assert gi.is_multigraph()
    assert gi.number_of_nodes() == 2
    assert gi.number_of_edges() == 2
    assert 0 in gi.edge_id(0, 1)
    assert 1 in gi.edge_id(0, 1)

def test_predsucc():
    """Verify predecessor/successor queries for node 0 of a multigraph index.

    The graph contains parallel edges and repeated self-loops; the asserted
    lengths of 3 show that each neighbouring node is expected only once.
    """
    index = create_graph_index(multigraph=True)
    index.add_nodes(4)

    # Insertion order is kept fixed; parallel edges (0 -> 1) and self-loops
    # (0 -> 0) are each added twice.
    edge_list = [(0, 1), (0, 1), (0, 2), (2, 0), (3, 0), (0, 0), (0, 0)]
    for tail, head in edge_list:
        index.add_edge(tail, head)

    # Nodes with an edge into node 0.
    incoming = index.predecessors(0)
    assert len(incoming) == 3
    for expected in (2, 3, 0):
        assert expected in incoming

    # Nodes reachable from node 0 over a single edge.
    outgoing = index.successors(0)
    assert len(outgoing) == 3
    for expected in (1, 2, 0):
        assert expected in outgoing


if __name__ == '__main__':
    # When run as a script, execute every test in file order; the first
    # failing assertion aborts the run.
    for run_test in (test_edge_id, test_nx, test_predsucc):
        run_test()