# test_line_graph.py -- tests for DGLGraph.line_graph.
import torch as th
import networkx as nx
import numpy as np
import dgl

# Feature width (number of columns) of every node/edge representation tensor
# created in these tests.
D = 5

def test_line_graph():
    """Check that a shared line graph aliases the original graph's edge data.

    ``nx.star_graph(N)`` is an undirected star with ``N`` edges; the
    ``DGLGraph`` built from it holds one directed edge per direction, i.e.
    ``2 * N`` edges. With ``shared=True`` each node of the line graph ``L``
    stands for one edge of ``G``, so writing node features on ``L`` must be
    visible as edge features on ``G``.
    """
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    n_edges = G.number_of_edges()
    # Each undirected star edge becomes two directed edges.
    assert n_edges == 2 * N
    G.set_e_repr({'h' : th.randn((n_edges, D))})

    L = G.line_graph(shared=True)
    # One line-graph node per directed edge of G.
    assert L.number_of_nodes() == n_edges
    L.set_n_repr({'h' : th.randn((n_edges, D))})

    # Updating node features on the line graph should be reflected in the
    # edge features of the original graph.
    u = [0, 0, 2, 3]
    v = [1, 2, 0, 0]
    eid = G.edge_ids(u, v)
    L.set_n_repr({'h' : th.zeros((len(u), D))}, eid)
    assert th.allclose(G.get_e_repr(u, v)['h'], th.zeros((len(u), D)))

    # Adding a new node feature on the line graph should also add the
    # matching edge feature on the original graph.
    data = th.randn(n_edges, D)
    L.set_n_repr({'w': data})
    assert th.allclose(G.get_e_repr()['w'], data)

def test_no_backtracking():
    """Check that ``line_graph(backtracking=False)`` omits backtracking edges.

    For every leaf ``i`` of the star, the directed edges ``0 -> i`` and
    ``i -> 0`` form a backtracking pair. The non-backtracking line graph must
    not connect their line-graph nodes in either direction.
    """
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    L = G.line_graph(backtracking=False)
    # One line-graph node per directed edge of G (two per star edge).
    assert L.number_of_nodes() == 2 * N
    # nx.star_graph(N) has hub 0 and leaves 1..N inclusive, so the upper
    # bound must be N + 1 for the last leaf to be checked too.
    for i in range(1, N + 1):
        e1 = G.edge_id(0, i)
        e2 = G.edge_id(i, 0)
        assert not L.has_edge_between(e1, e2)
        assert not L.has_edge_between(e2, e1)

if __name__ == '__main__':
    # Script entry point: run both tests defined above, in order.
    for _run_test in (test_line_graph, test_no_backtracking):
        _run_test()