graph.py 7.67 KB
Newer Older
Gan Quan's avatar
Gan Quan committed
1
import networkx as nx
zzhang-cn's avatar
zzhang-cn committed
2
from networkx.classes.digraph import DiGraph
Gan Quan's avatar
Gan Quan committed
3
4


zzhang-cn's avatar
zzhang-cn committed
5
6
7
8
9
10
11
12
13
14
'''
Default modules (PyTorch-specific):
    - MessageModule: copy (identity)
    - UpdateModule: vanilla RNN (ReLU over a linear layer) or GRU cell
    - ReadoutModule: bag of words (sum over node states)
    - ReductionModule: bag of words (sum over messages)
'''
import torch as th
import torch.nn as nn
import torch.nn.functional as F
Gan Quan's avatar
Gan Quan committed
15

zzhang-cn's avatar
zzhang-cn committed
16
17
18
19
20
21
22
class DefaultMessageModule(nn.Module):
    """Identity message function.

    The message sent along an edge is the source node's representation,
    passed through unchanged (a "copy").
    """

    def __init__(self, *args, **kwargs):
        # Any constructor arguments are forwarded untouched to nn.Module.
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Return ``x`` itself as the message."""
        message = x
        return message
Gan Quan's avatar
Gan Quan committed
26

zzhang-cn's avatar
zzhang-cn committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class DefaultUpdateModule(nn.Module):
    """
    Default update module: computes a node's new state from its current
    state and its (summed) incoming messages.

    Network per step, chosen by ``net_type``:
        - 'gru': ``nn.GRUCell(h_dims, h_dims)`` applied as cell(m, x)
        - anything else (default 'fwd'): ReLU(Linear(2*h_dims -> h_dims))
          over the concatenation [m; x] along dim 1 (a vanilla RNN step)

    ``n_func`` independent networks are created; every call to ``forward``
    uses the next one, so at most ``n_func`` calls may be made before
    ``reset_f_idx`` is called.

    Keyword args:
        h_dims: state/message size (default 128).
        net_type: 'gru' or 'fwd' (default 'fwd').
        n_func: number of per-step networks (default 1).
    """
    def __init__(self, *args, **kwargs):
        super(DefaultUpdateModule, self).__init__()
        h_dims = self.h_dims = kwargs.get('h_dims', 128)
        net_type = self.net_type = kwargs.get('net_type', 'fwd')
        n_func = self.n_func = kwargs.get('n_func', 1)
        self.f_idx = 0
        self.reduce_func = DefaultReductionModule()
        # nn.ModuleList (not a plain list) registers the per-step networks as
        # submodules, so their weights appear in .parameters() (and thus get
        # optimized), are saved in state_dict(), and move with .to()/.cuda().
        if net_type == 'gru':
            self.net = nn.ModuleList(
                [nn.GRUCell(h_dims, h_dims) for _ in range(n_func)])
        else:
            self.net = nn.ModuleList(
                [nn.Linear(2 * h_dims, h_dims) for _ in range(n_func)])

    def forward(self, x, msgs):
        """Compute the updated node state.

        Args:
            x: current node state. Anything that is not a tensor (e.g. the
                ``None`` placeholder) is replaced by zeros shaped like
                ``msgs[0]``.
            msgs: non-empty list of same-shaped message tensors; they are
                summed by ``self.reduce_func``. For the 'fwd' network the
                tensors are concatenated along dim 1, so presumably
                (batch, h_dims) — TODO confirm with callers.
        Returns:
            The new state produced by the ``f_idx``-th network.
        """
        if not th.is_tensor(x):
            x = th.zeros_like(msgs[0])
        m = self.reduce_func(msgs)
        assert self.f_idx < self.n_func, (
            'update called more than n_func=%d times; call reset_f_idx() '
            'before reusing this module' % self.n_func)
        if self.net_type == 'gru':
            out = self.net[self.f_idx](m, x)
        else:
            out = F.relu(self.net[self.f_idx](th.cat((m, x), 1)))
        # Advance so the next call uses the next (independent) network.
        self.f_idx += 1
        return out

    def reset_f_idx(self):
        """Make the next ``forward`` call use the first network again."""
        self.f_idx = 0
Gan Quan's avatar
Gan Quan committed
59

zzhang-cn's avatar
zzhang-cn committed
60
61
62
63
64
65
66
class DefaultReductionModule(nn.Module):
    """Sum-reduction ("bag of words") over a list of tensors.

    The tensors must all have the same shape (``torch.stack`` requires it);
    the result has that shape too.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x_s):
        """Stack ``x_s`` on a new leading dim and sum that dim away."""
        return th.stack(x_s).sum(dim=0)
Gan Quan's avatar
Gan Quan committed
72

zzhang-cn's avatar
zzhang-cn committed
73
74
75
76
77
78
79
80
class DefaultReadoutModule(nn.Module):
    """Graph-level readout: sums the given node states ("bag of words").

    The work is delegated to a ``DefaultReductionModule`` held as the
    ``reduce_func`` submodule.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reduce_func = DefaultReductionModule()

    def forward(self, x_s):
        """Reduce the list of node states ``x_s`` to a single tensor."""
        reduced = self.reduce_func(x_s)
        return reduced
Gan Quan's avatar
Gan Quan committed
84

zzhang-cn's avatar
zzhang-cn committed
85
86
87
88
89
90
91
92
93
94
95
96
class mx_Graph(DiGraph):
    '''
    A networkx DiGraph whose nodes carry representations (node attribute
    'state') and which supports message passing along its edges.

    Functions:
        - m_func: per edge (u, v), called as m_func(u['state']); default is
          a copy of u['state']
        - u_func: per node u, called as u_func(u['state'], [messages]);
          default is RNN(m, u['state'])
        - readout_func: called on a list of node states; default is their sum

    Overrides registered for a subset of edges/nodes are stored in the edge
    or node attribute dict under 'm_func' / 'u_func' and take precedence
    over the graph-wide defaults.
    '''
    def __init__(self, *args, **kwargs):
        super(mx_Graph, self).__init__(*args, **kwargs)
        self.m_func = DefaultMessageModule()
        self.u_func = DefaultUpdateModule()
        self.readout_func = DefaultReadoutModule()
        self.init_reprs()

    @staticmethod
    def _is_all(selector):
        """Return True iff ``selector`` is the sentinel string 'all'.

        The isinstance check avoids an element-wise ``==`` when a tensor or
        array of node/edge ids is passed instead.
        """
        return isinstance(selector, str) and selector == 'all'

    def init_reprs(self, h_init=None):
        """Set the 'state' of every node to ``h_init`` (same object for all)."""
        for n in self.nodes:
            self.set_repr(n, h_init)

    def set_repr(self, u, h_u, name='state'):
        """Store ``h_u`` as attribute ``name`` of existing node ``u``."""
        assert u in self.nodes
        kwarg = {name: h_u}
        self.add_node(u, **kwarg)

    def get_repr(self, u, name='state'):
        """Return attribute ``name`` of node ``u`` (KeyError if never set)."""
        assert u in self.nodes
        return self.nodes[u][name]

    def _nodes_or_all(self, nodes='all'):
        return self.nodes() if self._is_all(nodes) else nodes

    def _edges_or_all(self, edges='all'):
        return self.edges() if self._is_all(edges) else edges

    def register_message_func(self, message_func, edges='all', batched=False):
        '''
        Register a message function.

        Args:
            message_func: callable mapping the source node's state to the
                message stored on the edge (see ``sendto``).
            edges: 'all' to replace the graph-wide default, or an iterable
                of (u, v) edges that get ``message_func`` as a per-edge
                override. Only the given edges are affected.
            batched: whether to do a single batched computation instead of
                iterating (currently unused).
        '''
        if self._is_all(edges):
            self.m_func = message_func
        else:
            for e in edges:
                self.edges[e]['m_func'] = message_func

    def register_update_func(self, update_func, nodes='all', batched=False):
        '''
        Register an update function.

        Args:
            update_func: callable ``(state, messages) -> new_state`` where
                ``messages`` is the list of messages on the node's incoming
                edges (see ``recvfrom``).
            nodes: 'all' to replace the graph-wide default, or an iterable
                of nodes that get ``update_func`` as a per-node override.
            batched: whether to do a single batched computation instead of
                iterating (currently unused).
        '''
        if self._is_all(nodes):
            self.u_func = update_func
        else:
            for n in nodes:
                # Graph.node was removed in networkx 2.4; Graph.nodes works
                # in all 2.x releases and matches the rest of this class.
                self.nodes[n]['u_func'] = update_func

    def register_readout_func(self, readout_func):
        """Replace the graph-wide readout function."""
        self.readout_func = readout_func

    def readout(self, nodes='all', **kwargs):
        """Apply the readout function to the states of ``nodes``.

        Extra keyword arguments are forwarded to the readout function.
        """
        nodes_state = [self.get_repr(n) for n in self._nodes_or_all(nodes)]
        return self.readout_func(nodes_state, **kwargs)

    def sendto(self, u, v):
        """Compute message on edge u->v and store it as the edge's 'msg'.

        Args:
            u: source node
            v: destination node
        """
        f_msg = self.edges[(u, v)].get('m_func', self.m_func)
        m = f_msg(self.get_repr(u))
        self.edges[(u, v)]['msg'] = m

    def sendto_ebunch(self, ebunch):
        """Compute and store the message on every edge (u, v) in ``ebunch``.

        Args:
            ebunch: an iterable of (u, v) edges
        """
        for u, v in ebunch:
            self.sendto(u, v)

    def recvfrom(self, u, nodes):
        """Update u from the messages already stored on edges (v, u).

        Args:
            u: node to be updated
            nodes: nodes with pre-computed messages to u (KeyError if an
                edge has no 'msg' yet)
        """
        m = [self.edges[(v, u)]['msg'] for v in nodes]
        f_update = self.nodes[u].get('u_func', self.u_func)
        x_new = f_update(self.get_repr(u), m)
        self.set_repr(u, x_new)

    def update_by_edge(self, e):
        """Send along edge e = (u, v), then update v from u alone."""
        u, v = e
        self.sendto(u, v)
        self.recvfrom(v, [u])

    def update_to(self, u):
        """Pull messages from all 1-step predecessors of u and update u.

        Note: if u has no predecessors the update function receives an empty
        list; with the default modules that fails in the reduction
        (torch.stack of an empty list).
        """
        assert u in self.nodes
        for v in self.pred[u]:
            self.sendto(v, u)
        self.recvfrom(u, list(self.pred[u]))

    def update_from(self, u):
        """Update each successor of u.

        Each successor pulls from *all* of its predecessors, not only u.
        """
        assert u in self.nodes
        for v in self.succ[u]:
            self.update_to(v)

    def update_all_step(self):
        """One synchronous step: send on every edge, then update every node."""
        self.sendto_ebunch(self.edges)
        for u in self.nodes:
            self.recvfrom(u, list(self.pred[u]))

    def draw(self):
        """Draw the graph with node labels, using graphviz's 'dot' layout.

        The layout helper is imported locally, presumably so the graphviz
        dependency is only needed when drawing.
        """
        from networkx.drawing.nx_agraph import graphviz_layout

        pos = graphviz_layout(self, prog='dot')
        nx.draw(self, pos, with_labels=True)

    def set_reduction_func(self):
        """Install a stack-and-sum reduction as ``self._reduction_func``."""
        def _default_reduction_func(x_s):
            out = th.stack(x_s)
            out = th.sum(out, dim=0)
            return out
        self._reduction_func = _default_reduction_func

    def set_gather_func(self, u=None):
        """Placeholder; currently does nothing."""
        pass

    def print_all(self):
        """Print every node with its attribute dict, then a blank line."""
        for n in self.nodes:
            print(n, self.nodes[n])
        print()

if __name__ == '__main__':
    # Smoke tests: run this file directly to exercise the default modules.
    th.random.manual_seed(0)

    print("testing vanilla RNN update")
    g_path = mx_Graph(nx.path_graph(2))
    # Node 1 keeps its None state; the default update module replaces a
    # non-tensor state with zeros shaped like the incoming message.
    g_path.set_repr(0, th.rand(2, 128))
    g_path.sendto(0, 1)
    g_path.recvfrom(1, [0])
    g_path.readout()

    # To build a uni-edge (directed) tree instead:
    #     tr = nx.bfs_tree(nx.balanced_tree(2, 3), 0)
    #     m_tr = mx_Graph(tr)
    #     m_tr.print_all()

    print("testing GRU update")
    g = mx_Graph(nx.path_graph(3))
    update_net = DefaultUpdateModule(h_dims=4, net_type='gru')
    g.register_update_func(update_net)
    msg_net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    g.register_message_func(msg_net)

    for n in g:
        g.set_repr(n, th.rand(2, 4))

    y_pre = g.readout()
    g.update_from(0)
    y_after = g.readout()

    # With n_func=2 each forward call uses its own GRU cell, so at most two
    # updates may run before reset_f_idx() is needed. On this path graph
    # each update_from(0) triggers exactly one update (of node 1).
    upd_nets = DefaultUpdateModule(h_dims=4, net_type='gru', n_func=2)
    g.register_update_func(upd_nets)
    g.update_from(0)
    g.update_from(0)