test_line_graph.py 1.23 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
4
import torch as th
import networkx as nx
import numpy as np
import dgl
5
import utils as U
Minjie Wang's avatar
Minjie Wang committed
6
7
8
9
10
11

D = 5


def test_line_graph():
    """Check that a shared line graph aliases the original graph's edge data.

    ``G.line_graph(shared=True)`` turns every edge of ``G`` into a node of
    ``L``.  Writing node features on ``L`` (either overwriting a subset of an
    existing feature or adding a brand-new feature) must be visible as the
    corresponding edge features on ``G``.
    """
    N = 5
    # Star graph: hub node 0 plus N leaves.  The undirected networkx edges
    # become one directed edge per direction in the DGL graph, hence 2 * N.
    G = dgl.DGLGraph(nx.star_graph(N))
    n_edges = G.number_of_edges()
    assert n_edges == 2 * N
    G.edata['h'] = th.randn((n_edges, D))

    L = G.line_graph(shared=True)
    # Every edge of G is a node of L.
    assert L.number_of_nodes() == n_edges
    L.ndata['h'] = th.randn((n_edges, D))

    # update node features on line graph should reflect to edge features on
    # original graph.
    u = [0, 0, 2, 3]
    v = [1, 2, 0, 0]
    eid = G.edge_ids(u, v)
    # Size the replacement from the number of selected edges rather than a
    # hard-coded count, so u/v can be edited without touching the shapes.
    zeros = th.zeros((len(u), D))
    L.nodes[eid].data['h'] = zeros
    assert U.allclose(G.edges[u, v].data['h'], zeros)

    # adding a new node feature on line graph should also reflect to a new
    # edge feature on original graph
    data = th.randn(n_edges, D)
    L.ndata['w'] = data
    assert U.allclose(G.edata['w'], data)
Minjie Wang's avatar
Minjie Wang committed
30
31
32
33

def test_no_backtracking():
    """Check that ``backtracking=False`` removes links between reverse edges.

    For every hub/leaf pair of a star graph, the line-graph nodes for the
    edges ``(0, leaf)`` and ``(leaf, 0)`` must not be connected in either
    direction.
    """
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    L = G.line_graph(backtracking=False)
    assert L.number_of_nodes() == 2 * N
    # nx.star_graph(N) labels the hub 0 and the leaves 1..N inclusive, so the
    # upper bound must be N + 1 for the last leaf to be checked as well.
    for leaf in range(1, N + 1):
        e1 = G.edge_id(0, leaf)
        e2 = G.edge_id(leaf, 0)
        assert not L.has_edge_between(e1, e2)
        assert not L.has_edge_between(e2, e1)
Minjie Wang's avatar
Minjie Wang committed
41
42
43
44

if __name__ == '__main__':
    # Allow running this module directly, without a test runner; the tests
    # execute in the same order as they are defined above.
    for _test_fn in (test_line_graph, test_no_backtracking):
        _test_fn()