mx.py 4.31 KB
Newer Older
zzhang-cn's avatar
zzhang-cn committed
1
import networkx as nx
zzhang-cn's avatar
zzhang-cn committed
2
3
4
5
6
7
8
#from networkx.classes.graph import Graph
from networkx.classes.digraph import DiGraph

import torch as th
#import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable as Var
zzhang-cn's avatar
zzhang-cn committed
9
10
11

# TODO: loss functions and training

zzhang-cn's avatar
zzhang-cn committed
12
class mx_Graph(DiGraph):
    """Directed graph whose nodes hold a representation and pass messages.

    Each node stores its representation as a node attribute (named ``'h'``
    by default); ``None`` (or a missing attribute) means "no representation
    yet".  Four pluggable callables drive message passing:

    * reduction -- combines a list of representations into one value
      (default: stack, then sum over dim 0);
    * message   -- builds the aggregate message arriving at a node
      (default: reduce the representations of its predecessors);
    * update    -- computes a node's new representation (default: ``x + m``);
    * readout   -- summarises the whole graph (default: reduce the
      representations of all nodes).
    """

    def __init__(self, *args, **kargs):
        """Build the graph, install the default functions, reset all reprs.

        All arguments are forwarded unchanged to ``DiGraph.__init__``.
        """
        super(mx_Graph, self).__init__(*args, **kargs)
        self.set_msg_func()
        self.set_gather_func()
        self.set_reduction_func()
        self.set_update_func()
        self.set_readout_func()
        self.init_reprs()

    def init_reprs(self, h_init=None):
        """Set the default representation of every current node to `h_init`."""
        for n in self.nodes:
            self.set_repr(n, h_init)

    def set_repr(self, u, h_u, name='h'):
        """Store `h_u` as attribute `name` of the existing node `u`."""
        assert u in self.nodes, "unknown node: %r" % (u,)
        # Write straight into the node's attribute dict; unlike
        # add_node(u, **{name: h_u}) this does not require `name` to be a str.
        self.nodes[u][name] = h_u

    def get_repr(self, u, name='h'):
        """Return attribute `name` of node `u`, or None if it was never set.

        Nodes added after construction carry no attribute until
        `init_reprs`/`set_repr` is called; they are reported as ``None``,
        the same value the rest of the class uses for "no representation".
        """
        assert u in self.nodes, "unknown node: %r" % (u,)
        return self.nodes[u].get(name)

    def set_reduction_func(self, func=None):
        """Install the function that merges a list of representations.

        `func` takes a list and returns the combined value.  When omitted,
        the default stacks the list with ``th.stack`` and sums over dim 0.
        """
        def _default_reduction_func(x_s):
            return th.sum(th.stack(x_s), dim=0)

        self._reduction_func = _default_reduction_func if func is None else func

    def set_gather_func(self, u=None):
        """Placeholder; currently does nothing."""

    def set_msg_func(self, func=None, u=None):
        """Install the function that gathers messages from neighbors.

        `func` is invoked as ``func(u=node)`` and returns the aggregate
        message for that node.  The default reduces the representations of
        the node's predecessors, ignoring those that are ``None``, and
        returns ``None`` when no predecessor has a representation.
        The `u` argument is currently unused.
        """
        def _default_msg_func(u):
            assert u in self.nodes, "unknown node: %r" % (u,)
            incoming = (self.get_repr(v) for v in self.pred[u])
            msg_gathered = [x for x in incoming if x is not None]
            if not msg_gathered:
                # Nothing to reduce; th.stack rejects an empty list.
                return None
            return self._reduction_func(msg_gathered)

        # TODO: per node message function
        self._msg_func = _default_msg_func if func is None else func

    def set_update_func(self, func=None, u=None):
        """Install the function that updates a node from an aggregate message.

        `func` is invoked as ``func(x, m)`` when the node already has a
        representation ``x``, and as ``func(m)`` when it has none (see
        `update_to`).  The default returns ``x + m`` in the first case and
        the message itself in the second.  The `u` argument is currently
        unused.
        """
        def _default_update_func(x, m=None):
            # Single-argument call: the only argument is the message and the
            # node has no previous representation to add it to.
            if m is None:
                return x
            return x + m

        # TODO: per node update function
        self._update_func = _default_update_func if func is None else func

    def set_readout_func(self, func=None):
        """Install the zero-argument readout function of the whole graph.

        The default reduces the representations of all nodes that have one
        and returns ``None`` when no node has a representation.
        """
        def _default_readout_func():
            candidates = (self.get_repr(n) for n in self.nodes)
            valid_hs = [h for h in candidates if h is not None]
            if not valid_hs:
                # Nothing to reduce; th.stack rejects an empty list.
                return None
            return self._reduction_func(valid_hs)

        self.readout_func = _default_readout_func if func is None else func

    def readout(self):
        """Return the result of the installed readout function."""
        return self.readout_func()

    def update_to(self, u):
        """Pull messages from 1-step away neighbors of u.

        If the message function yields ``None`` (no message available),
        the node's representation is left unchanged.
        """
        assert u in self.nodes, "unknown node: %r" % (u,)
        m = self._msg_func(u=u)
        if m is None:
            return
        x = self.get_repr(u)
        # TODO: ugly hack -- the update function is called with the message
        # only when the node has no representation yet.
        if x is None:
            y = self._update_func(m)
        else:
            y = self._update_func(x, m)
        self.set_repr(u, y)

    def update_from(self, u):
        """Update u's 1-step away neighbors."""
        assert u in self.nodes, "unknown node: %r" % (u,)
        # TODO: this asks v to pull from nodes other than
        # TODO: u, is this a good thing?
        for v in self.succ[u]:
            self.update_to(v)

    def print_all(self):
        """Print every node with its attribute dict, then a blank line."""
        for n in self.nodes:
            print(n, self.nodes[n])
        print()

if __name__ == '__main__':
    # Demo script: exercises one round of message passing with two
    # different update functions and prints node states before and after.
    th.random.manual_seed(0)

    # Disabled demos, kept for reference:
    #
    # # this makes a digraph with double edges
    # tg = nx.path_graph(10)
    # g = mx_Graph(tg)
    # g.print_all()
    #
    # # this makes a uni-edge tree
    # tr = nx.bfs_tree(nx.balanced_tree(2, 3), 0)
    # m_tr = mx_Graph(tr)
    # m_tr.print_all()

    print("testing GRU update")
    g = mx_Graph(nx.path_graph(3))
    g.set_update_func(nn.GRUCell(4, 4))
    for n in g:
        g.set_repr(n, Var(th.rand(2, 4)))

    print("\t**before:")
    g.print_all()
    g.update_from(0)
    g.update_from(1)
    print("\t**after:")
    g.print_all()

    print("\ntesting fwd update")
    g.clear()
    # Graph.add_path() is gone from current networkx graph classes; the
    # module-level helper adds the same nodes and edges.
    nx.add_path(g, [0, 1, 2])
    g.init_reprs()

    fwd_net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    g.set_update_func(fwd_net)

    g.set_repr(0, Var(th.rand(2, 4)))
    print("\t**before:")
    g.print_all()
    g.update_from(0)
    g.update_from(1)
    print("\t**after:")
    g.print_all()