test_line_graph.py 1.33 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch as th
import networkx as nx
import numpy as np
import dgl

# Feature size (second dimension) of every node/edge representation tensor
# created in the tests below.
D = 5

def check_eq(a, b):
    """Return True when tensors ``a`` and ``b`` match in shape and values.

    Values are compared with ``np.allclose`` at its default tolerances.
    Shapes are compared first, and a mismatch returns False without ever
    converting either tensor to NumPy. This check matters because
    ``np.allclose`` alone would broadcast compatible shapes.
    """
    if a.shape != b.shape:
        return False
    lhs, rhs = a.numpy(), b.numpy()
    return np.allclose(lhs, rhs)

def test_line_graph():
    """Line-graph node features built with ``shared=True`` alias parent edge features.

    Builds a graph from ``nx.star_graph``, takes its line graph with
    ``shared=True`` and checks two things:

    * an indexed write of the default node representation on the line graph
      is visible as the corresponding edge representation of the original
      graph;
    * a brand-new ``'w'`` node field on the line graph appears as a ``'w'``
      edge field on the original graph.
    """
    n_leaves = 5
    star = dgl.DGLGraph(nx.star_graph(n_leaves))
    # The sizes below assume the star graph is stored with 2 * n_leaves
    # edges (one per direction of each spoke).
    star.set_e_repr(th.randn((2 * n_leaves, D)))
    num_edges = star.number_of_edges()

    lg = star.line_graph(shared=True)
    assert lg.number_of_nodes() == 2 * n_leaves
    lg.set_n_repr(th.randn((2 * n_leaves, D)))

    # Updating node features on the line graph should be reflected in the
    # edge features of the original graph.
    src, dst = [0, 0, 2, 3], [1, 2, 0, 0]
    edge_idx = star.edge_ids(src, dst)
    lg.set_n_repr(th.zeros((len(src), D)), edge_idx)
    assert check_eq(star.get_e_repr(src, dst), th.zeros((len(src), D)))

    # Adding a new node feature on the line graph should likewise create a
    # matching edge feature on the original graph.
    w_feat = th.randn(num_edges, D)
    lg.set_n_repr({'w': w_feat})
    assert check_eq(star.get_e_repr()['w'], w_feat)

def test_no_backtracking():
    """A non-backtracking line graph has no edge between an edge and its reverse.

    For every spoke ``0 <-> i`` of the star graph, the line graph built with
    ``backtracking=False`` must contain neither ``(0->i) -> (i->0)`` nor
    ``(i->0) -> (0->i)``.
    """
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    G.set_e_repr(th.randn((2 * N, D)))
    L = G.line_graph(backtracking=False)
    assert L.number_of_nodes() == 2 * N
    # nx.star_graph(N) labels the hub 0 and the leaves 1..N, so the range
    # must include N itself. The former range(1, N) skipped the last spoke.
    for i in range(1, N + 1):
        e1 = G.edge_id(0, i)
        e2 = G.edge_id(i, 0)
        assert not L.has_edge_between(e1, e2)
        assert not L.has_edge_between(e2, e1)
Minjie Wang's avatar
Minjie Wang committed
44
45
46
47

if __name__ == '__main__':
    # Allow running this file directly without a test runner; the cases run
    # in declaration order.
    for _case in (test_line_graph, test_no_backtracking):
        _case()