# test.py
# Last committed by zzhang-cn.
import networkx as nx
from networkx.classes.digraph import DiGraph

if __name__ == '__main__':
    # Smoke-test script for the graph message-passing API: first a single
    # send/receive step on a 2-node path graph, then GRU-based updates on a
    # 3-node path graph. Nothing is asserted; the script only checks that the
    # calls run without raising.
    #
    # NOTE(review): mx_Graph and DefaultUpdateModule are neither defined nor
    # imported in this file. Presumably they come from a sibling project
    # module -- confirm the module name and import them before running.

    # `th` and `nn` are used below but were never bound in this file, which
    # raised NameError on the first statement. Import them next to their use.
    import torch as th
    import torch.nn as nn

    # Fixed seed so the th.rand() node representations are reproducible.
    th.random.manual_seed(0)

    print("testing vanilla RNN update")
    g_path = mx_Graph(nx.path_graph(2))
    # Give node 0 a random 2 x 128 representation, then pass it along the
    # single edge 0 -> 1 and read the graph out.
    g_path.set_repr(0, th.rand(2, 128))
    g_path.sendto(0, 1)
    g_path.recvfrom(1, [0])
    g_path.readout()

    # Disabled experiment, kept for reference:
    # # this makes a uni-edge tree
    # tr = nx.bfs_tree(nx.balanced_tree(2, 3), 0)
    # m_tr = mx_Graph(tr)
    # m_tr.print_all()

    print("testing GRU update")
    g = mx_Graph(nx.path_graph(3))
    update_net = DefaultUpdateModule(h_dims=4, net_type='gru')
    g.register_update_func(update_net)
    # Message function: a 4 -> 4 linear layer followed by ReLU, matching the
    # h_dims=4 used by the update module.
    msg_net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    g.register_message_func(msg_net)

    # Every node starts from a random 2 x 4 representation.
    for n in g:
        g.set_repr(n, th.rand(2, 4))

    # Readouts taken before and after one update pass from node 0; they are
    # kept as names for manual inspection and are not compared here.
    y_pre = g.readout()
    g.update_from(0)
    y_after = g.readout()

    # Re-register an update module built with n_func=2 and run two passes.
    # NOTE(review): presumably n_func=2 creates two update functions that are
    # used on successive passes -- confirm against DefaultUpdateModule.
    upd_nets = DefaultUpdateModule(h_dims=4, net_type='gru', n_func=2)
    g.register_update_func(upd_nets)
    g.update_from(0)
    g.update_from(0)