# test_line_graph.py -- tests for dgl.line_graph (plain and no-backtracking).

import torch as th
import networkx as nx
import numpy as np
import dgl

# Feature width for the edge/node representations built by the (currently
# disabled) line-graph test bodies in this module.
D: int = 5

def check_eq(a, b):
    """Return True if tensors ``a`` and ``b`` have equal shapes and close values.

    Values are compared with ``np.allclose`` using its default tolerances.
    Both tensors are detached and copied to CPU before conversion, because
    calling ``.numpy()`` directly raises on tensors that require grad or that
    live on a non-CPU device.

    Returns False without comparing values when the shapes differ.
    """
    if a.shape != b.shape:
        return False
    return np.allclose(a.detach().cpu().numpy(), b.detach().cpu().numpy())

def test_line_graph():
    """Line-graph node features should alias the original graph's edge features.

    FIXME: this test is disabled. It returns immediately, so it currently
    passes without checking anything. The intended body is kept in the
    string literal after the ``return``.

    NOTE(review): presumably disabled because the DGL calls in that body
    (``set_e_repr``, ``get_edge_id``, ``set_n_repr``, ``get_e_repr``) no
    longer match the library API -- confirm before re-enabling.
    """
    # FIXME: re-enable once the body below is ported to the current API.
    return
    """
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    G.set_e_repr(th.randn((2*N, D)))
    n_edges = G.number_of_edges()
    L = dgl.line_graph(G)
    assert L.number_of_nodes() == 2*N
    # update node features on line graph should reflect to edge features on
    # original graph.
    u = [0, 0, 2, 3]
    v = [1, 2, 0, 0]
    eid = G.get_edge_id(u, v)
    L.set_n_repr(th.zeros((4, D)), eid)
    assert check_eq(G.get_e_repr(u, v), th.zeros((4, D)))

    # adding a new node feature on line graph should also reflect to a new
    # edge feature on original graph
    data = th.randn(n_edges, D)
    L.set_n_repr({'w': data})
    assert check_eq(G.get_e_repr()['w'], data)
    """

def test_no_backtracking():
    """A no-backtracking line graph must not link an edge to its reverse.

    For every leaf ``i`` of a star graph, the line graph built with
    ``no_backtracking=True`` should have no edge between the original edges
    ``(0, i)`` and ``(i, 0)``, in either direction.

    FIXME: this test is disabled. It returns immediately, so it currently
    passes without checking anything. The intended body is kept in the
    string literal after the ``return``.

    NOTE(review): presumably disabled because the DGL calls in that body
    (``set_e_repr``, ``get_edge_id``) no longer match the library API --
    confirm before re-enabling. ``nx.star_graph(N)`` has leaves ``1..N``,
    so the preserved loop iterates ``range(1, N + 1)``; the earlier
    ``range(1, N)`` skipped the last leaf.
    """
    # FIXME: re-enable once the body below is ported to the current API.
    return
    """
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    G.set_e_repr(th.randn((2*N, D)))
    L = dgl.line_graph(G, no_backtracking=True)
    assert L.number_of_nodes() == 2*N
    for i in range(1, N + 1):
        e1 = G.get_edge_id(0, i)
        e2 = G.get_edge_id(i, 0)
        assert not L.has_edge(e1, e2)
        assert not L.has_edge(e2, e1)
    """

if __name__ == '__main__':
    # Script entry point: run each test in definition order.
    for _run_test in (test_line_graph, test_no_backtracking):
        _run_test()