# test_subgraph.py: tests for DGLGraph subgraph extraction,
# copy_from_parent and merge.
import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.graph import DGLGraph

# Feature width shared by every node ('h') and edge ('l') tensor in these tests.
D: int = 5

def generate_graph(grad=False):
    """Build the 10-node, 17-edge graph shared by the subgraph tests.

    Node 0 is the source and node 9 the sink.  For every middle node ``i``
    in 1..8 an edge ``0 -> i`` is added, immediately followed by
    ``i -> 9``.  A final back edge ``9 -> 0`` is added last.  The tests
    assume edge ids follow this insertion order (``0->1`` is 0, ``1->9``
    is 1, ..., ``9->0`` is 16).

    Args:
        grad: passed as ``requires_grad`` for both feature tensors.

    Returns:
        The DGLGraph, with node feature ``'h'`` of shape (10, D) and edge
        feature ``'l'`` of shape (17, D), both drawn from ``th.randn``.
    """
    num_nodes = 10
    source, sink = 0, num_nodes - 1
    g = DGLGraph()
    g.add_nodes(num_nodes)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(source + 1, sink):
        g.add_edge(source, i)
        g.add_edge(i, sink)
    # add a back flow from 9 to 0
    g.add_edge(sink, source)
    # one in-edge and one out-edge per middle node, plus the back edge (= 17)
    num_edges = 2 * (num_nodes - 2) + 1
    ncol = Variable(th.randn(num_nodes, D), requires_grad=grad)
    ecol = Variable(th.randn(num_edges, D), requires_grad=grad)
    g.set_n_repr({'h' : ncol})
    g.set_e_repr({'l' : ecol})
    return g

def test_basics():
    """Check subgraph induction, explicit feature copy and write isolation."""
    g = generate_graph()
    # Snapshot the parent features: the isolation checks at the end would
    # pass vacuously if they compared against live references that a
    # subgraph write had mutated in place.
    h = g.get_n_repr()['h'].clone()
    ecol = g.get_e_repr()['l'].clone()

    nid = [0, 2, 3, 6, 7, 9]
    sg = g.subgraph(nid)
    # Parent edge table (src, dst, eid).  Column "nid" marks the edges
    # induced by ``nid`` above; column "5..9" marks the edges induced by
    # nodes [5, 6, 7, 8, 9] (used in test_merge).
    #   src dst eid   nid  5..9
    #    0   1   0
    #    1   9   1
    #    0   2   2     x
    #    2   9   3     x
    #    0   3   4     x
    #    3   9   5     x
    #    0   4   6
    #    4   9   7
    #    0   5   8
    #    5   9   9           x
    #    0   6  10     x
    #    6   9  11     x     x
    #    0   7  12     x
    #    7   9  13     x     x
    #    0   8  14
    #    8   9  15           x
    #    9   0  16     x
    expected_eid = {2, 3, 4, 5, 10, 11, 12, 13, 16}
    parent_eid = sg.parent_eid
    assert set(parent_eid.numpy()) == expected_eid

    # the subgraph is empty initially
    assert len(sg.get_n_repr()) == 0
    assert len(sg.get_e_repr()) == 0
    # the data is copied only after an explicit copy_from_parent()
    sg.copy_from_parent()
    assert len(sg.get_n_repr()) == 1
    assert len(sg.get_e_repr()) == 1
    assert th.allclose(h[nid], sg.get_n_repr()['h'])
    assert th.allclose(ecol[parent_eid], sg.get_e_repr()['l'])

    # Updating node/edge features on the subgraph should NOT be
    # reflected in the parent graph.
    sg.set_n_repr({'h' : th.zeros((len(nid), D))})
    sg.set_e_repr({'l' : th.zeros((len(expected_eid), D))})
    assert th.allclose(h, g.get_n_repr()['h'])
    assert th.allclose(ecol, g.get_e_repr()['l'])

def test_merge():
    """Check that merging subgraphs writes their features back to the parent.

    The expected vectors encode two assumptions. Contributions from
    overlapping subgraphs are summed. A subgraph that never set a given
    feature contributes nothing to it.
    """
    g = generate_graph()
    g.set_n_repr({'h' : th.zeros((10, D))})
    g.set_e_repr({'l' : th.zeros((17, D))})
    # subgraphs
    sg1 = g.subgraph([0, 2, 3, 6, 7, 9])       # 6 nodes, 9 induced edges
    sg1.set_n_repr({'h' : th.ones((6, D))})
    sg1.set_e_repr({'l' : th.ones((9, D))})

    sg2 = g.subgraph([0, 2, 3, 4])             # node features only
    sg2.set_n_repr({'h' : th.ones((4, D)) * 2})

    sg3 = g.subgraph([5, 6, 7, 8, 9])          # edges 9, 11, 13, 15; edge features only
    sg3.set_e_repr({'l' : th.ones((4, D)) * 3})

    g.merge([sg1, sg2, sg3])

    h = g.get_n_repr()['h'][:, 0]
    e = g.get_e_repr()['l'][:, 0]
    # nodes 0, 2, 3: sg1 (1) + sg2 (2) = 3; node 4: sg2 only = 2;
    # nodes 6, 7, 9: sg1 only = 1; nodes 1, 5, 8 untouched = 0.
    assert th.allclose(h, th.tensor([3., 0., 3., 3., 2., 0., 1., 1., 0., 1.]))
    # edges 11, 13: sg1 (1) + sg3 (3) = 4; edges 9, 15: sg3 only = 3;
    # edges 2-5, 10, 12, 16: sg1 only = 1; the rest untouched = 0.
    assert th.allclose(e,
            th.tensor([0., 0., 1., 1., 1., 1., 0., 0., 0., 3., 1., 4., 1., 4., 0., 3., 1.]))

if __name__ == '__main__':
    # Script entry point: runs the tests directly without a test runner.
    test_basics()
    # NOTE(review): test_merge is deliberately not run from the script entry
    # point in the original; confirm whether it should be enabled here too.
    #test_merge()