"""Base graph class specialized for neural networks on graphs."""

from collections import defaultdict
import networkx as nx
from networkx.classes.digraph import DiGraph

import dgl.backend as F
from dgl.backend import Tensor
import dgl.utils as utils

# Reserved keys used inside the node/edge attribute dictionaries of DGLGraph.
# Edge attribute under which sendto() stores the message it computed.
__MSG__ = "__msg__"
# Default node attribute name used by set_repr() and get_repr().
__REPR__ = "__repr__"
# Edge attribute holding a message function registered for that edge.
__MFUNC__ = "__mfunc__"
# Node attribute holding an update function registered for that node.
__UFUNC__ = "__ufunc__"

class DGLGraph(DiGraph):
    """Base graph class specialized for neural networks on graphs.

    Node and edge states live in the networkx node/edge attribute
    dictionaries. A message function computes a message for an edge, an
    update function merges the messages of incoming edges into the state of
    the destination node, and an optional readout function is applied to
    the states of selected nodes and edges.

    TODO(minjie): document of multi-node and multi-edge syntax.

    Parameters
    ----------
    graph_data : graph data
        Data to initialize graph. Same as networkx's semantics.
    attr : keyword arguments, optional
        Attributes to add to graph as key=value pairs.
    """
    def __init__(self, graph_data=None, **attr):
        super(DGLGraph, self).__init__(graph_data, **attr)
        # Graph-wide defaults. A function stored on an individual edge or
        # node (see register_message_func / register_update_func) takes
        # precedence over these when messages are sent or received.
        self.m_func = None
        self.u_func = None
        self.readout_func = None

    def init_reprs(self, h_init=None):
        """Set the default representation of every node to ``h_init``.

        Deprecated: set node attributes directly instead.

        Parameters
        ----------
        h_init : any, optional
            Value stored on every node.
        """
        print("[DEPRECATED]: please directly set node attrs "
              "(e.g. g.nodes[node]['x'] = val).")
        for n in self.nodes:
            self.set_repr(n, h_init)

    def set_repr(self, u, h_u, name=__REPR__):
        """Store ``h_u`` on node ``u`` under the attribute ``name``.

        Deprecated: set node attributes directly instead.

        Parameters
        ----------
        u : node
            An existing node of the graph.
        h_u : any
            The value to store.
        name : str, optional
            Attribute name to store the value under.
        """
        print("[DEPRECATED]: please directly set node attrs "
              "(e.g. g.nodes[node]['x'] = val).")
        assert u in self.nodes
        kwarg = {name: h_u}
        # add_node on an existing node only updates its attributes.
        self.add_node(u, **kwarg)

    def get_repr(self, u, name=__REPR__):
        """Return the attribute ``name`` of node ``u``.

        Deprecated: read node attributes directly instead.

        Parameters
        ----------
        u : node
            An existing node of the graph.
        name : str, optional
            Attribute name to read.

        Returns
        -------
        any
            The value stored on the node under ``name``.
        """
        print("[DEPRECATED]: please directly get node attrs "
              "(e.g. g.nodes[node]['x']).")
        assert u in self.nodes
        return self.nodes[u][name]

    def register_message_func(self, message_func, edges='all', batchable=False):
        """Register computation on edges.

        The message function should be compatible with following signature:

        (node_reprs, node_reprs, edge_reprs) -> edge_reprs

        It computes the new edge representations (the same concept as messages)
        using the representations of the source node, target node and the edge
        itself. All node_reprs and edge_reprs are dictionaries.

        Parameters
        ----------
        message_func : callable
            Message function on the edge.
        edges : str or iterable of edge keys
            The edges for which the message function is registered. Default
            ('all') registers a graph-wide default. Otherwise every element
            of ``edges`` is used as a key into ``self.edges``.
        batchable : bool
            Whether the provided message function allows batch computing.
            Currently unused by this method.

        Examples
        --------

        Register for all edges.
        >>> g.register_message_func(mfunc)

        Register for a specific edge.
        >>> g.register_message_func(mfunc, [(u, v)])

        Register for multiple edges.
        >>> g.register_message_func(mfunc, [(u1, v1), (u2, v2)])
        """
        # The isinstance guard avoids evaluating `==` on tensors/arrays,
        # where the comparison may be element-wise or ambiguous.
        if isinstance(edges, str) and edges == 'all':
            self.m_func = message_func
        else:
            # NOTE(review): unlike sendto(), a (sources, destinations) pair
            # is not expanded here; each element is indexed directly.
            # Confirm which calling convention is intended.
            for e in edges:
                self.edges[e][__MFUNC__] = message_func

    def register_update_func(self, update_func, nodes='all', batchable=False):
        """Register computation on nodes.

        The update function is called as:

        (node_reprs, messages) -> node_reprs

        It computes the new node representations using the node's own
        attribute dictionary and the list of messages stored on the
        selected in-coming edges. The returned value is merged into the
        node's attribute dictionary.

        Parameters
        ----------
        update_func : callable
            Update function on the node.
        nodes : str or iterable of nodes
            The nodes for which the update function is registered. Default
            ('all') registers a graph-wide default. Otherwise every element
            of ``nodes`` is used as a key into ``self.nodes``.
        batchable : bool
            Whether the provided update function allows batch computing.
            Currently unused by this method.

        Examples
        --------

        Register for all nodes.
        >>> g.register_update_func(ufunc)

        Register for a specific node.
        >>> g.register_update_func(ufunc, [u])

        Register for multiple nodes.
        >>> u = [u1, u2, u3, ...]
        >>> g.register_update_func(ufunc, u)
        """
        # The isinstance guard avoids evaluating `==` on tensors/arrays.
        if isinstance(nodes, str) and nodes == 'all':
            self.u_func = update_func
        else:
            # NOTE(review): a single bare node is not wrapped here (no
            # utils.node_iter), so it must be passed inside a container.
            for n in nodes:
                self.nodes[n][__UFUNC__] = update_func

    def register_readout_func(self, readout_func):
        """Register computation on the whole graph.

        The readout_func should be compatible with following signature:

        (node_reprs, edge_reprs) -> any

        It takes the representations of selected nodes and edges and
        returns readout values.

        NOTE: readout function can be implemented outside of DGLGraph.
        One can simply get the node/edge reprs of the graph and perform
        arbitrary computation.

        Parameters
        ----------
        readout_func : callable
            The readout function.

        See Also
        --------
        readout
        """
        self.readout_func = readout_func

    def readout(self, nodes='all', edges='all'):
        """Trigger the readout function on the specified nodes/edges.

        Parameters
        ----------
        nodes : str or iterable of nodes
            The nodes to get reprs from. Default ('all') selects every node.
        edges : str or iterable of edge keys
            The edges to get reprs from. Default ('all') selects every edge.

        Returns
        -------
        any
            Whatever the registered readout function returns.

        Raises
        ------
        AssertionError
            If no readout function has been registered.
        """
        nodes = self._nodes_or_all(nodes)
        # Resolve the edge selection from `edges` (not from `nodes`).
        edges = self._edges_or_all(edges)
        assert self.readout_func is not None, \
            "Readout function is not registered."
        # TODO(minjie): tensorize following loop.
        nstates = [self.nodes[n] for n in nodes]
        estates = [self.edges[e] for e in edges]
        return self.readout_func(nstates, estates)

    def sendto(self, u, v):
        """Trigger the message function on edge u->v

        The message is computed by the function registered on the edge,
        falling back to the graph-wide message function, and is stored in
        the edge's attribute dictionary.

        Parameters
        ----------
        u : node, container or tensor
            The source node(s).
        v : node, container or tensor
            The destination node(s).

        Raises
        ------
        AssertionError
            If no message function is available for an edge.
        """
        # TODO(minjie): tensorize the loop.
        for uu, vv in utils.edge_iter(u, v):
            f_msg = self.edges[uu, vv].get(__MFUNC__, self.m_func)
            assert f_msg is not None, \
                "message function not registered for edge (%s->%s)" % (uu, vv)
            m = f_msg(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
            self.edges[uu, vv][__MSG__] = m

    def recvfrom(self, u, preds=None):
        """Trigger the update function on node u.

        It computes the new node state using the messages and edge
        states from preds->u. If `u` is one node, `preds` is a list
        of predecessors. If `u` is a container or tensor of nodes,
        then `preds[i]` should be the predecessors of `u[i]`.

        Parameters
        ----------
        u : node, container or tensor
            The node to be updated.
        preds : container
            Nodes with pre-computed messages to u. Default is all
            the predecessors.

        Raises
        ------
        AssertionError
            If no update function is available for a node.
        """
        u_is_container = isinstance(u, list)
        u_is_tensor = isinstance(u, Tensor)
        # TODO(minjie): tensorize the loop.
        for i, uu in enumerate(utils.node_iter(u)):
            if preds is None:
                v = list(self.pred[uu])
            elif u_is_container or u_is_tensor:
                # One list of predecessors per node being updated.
                v = preds[i]
            else:
                v = preds
            # TODO(minjie): tensorize the message batching
            m = [self.edges[vv, uu][__MSG__] for vv in v]
            f_update = self.nodes[uu].get(__UFUNC__, self.u_func)
            assert f_update is not None, \
                "Update function not registered for node %s" % uu
            # Merge the returned attributes into the node's state. Uses
            # `self.nodes`, like the rest of the class; the `self.node`
            # alias was removed from networkx.
            self.nodes[uu].update(f_update(self.nodes[uu], m))

    def update_by_edge(self, u, v):
        """Trigger the message function on u->v and update v.

        Parameters
        ----------
        u : node, container or tensor
            The source node(s).
        v : node, container or tensor
            The destination node(s).
        """
        self.sendto(u, v)
        # TODO(minjie): tensorize the following loops.
        # Group the sources by destination so each destination is updated
        # once with all the messages that were just sent to it.
        preds = defaultdict(list)
        for uu, vv in utils.edge_iter(u, v):
            preds[vv].append(uu)
        if len(preds) == 1:
            # Single destination: recvfrom takes the flat predecessor list.
            dst = list(preds.keys())[0]
            src = preds[dst]
            self.recvfrom(dst, src)
        elif len(preds) > 1:
            # Several destinations: one predecessor list per destination.
            dst = list(preds.keys())
            src = [preds[d] for d in dst]
            self.recvfrom(dst, src)

    def update_to(self, u):
        """Pull messages from the node's predecessors and then update it.

        Parameters
        ----------
        u : node, container or tensor
            The node to be updated.
        """
        # TODO(minjie): tensorize the following code.
        for uu in utils.node_iter(u):
            assert uu in self.nodes
            preds = list(self.pred[uu])
            self.sendto(preds, uu)
            self.recvfrom(uu, preds)

    def update_from(self, u):
        """Send message from the node to its successors and update them.

        Parameters
        ----------
        u : node, container or tensor
            The node that sends out messages.
        """
        # TODO(minjie): tensorize the following code.
        for uu in utils.node_iter(u):
            assert uu in self.nodes
            for v in self.succ[uu]:
                self.update_by_edge(uu, v)

    def update_all(self):
        """Send messages through all the edges and update all nodes.
        """
        # TODO(minjie): tensorize the following code.
        u = [uu for uu, _ in self.edges]
        v = [vv for _, vv in self.edges]
        self.sendto(u, v)
        # preds is left as None so every node uses all of its predecessors.
        self.recvfrom(list(self.nodes()))

    def propagate(self, iterator='bfs', **kwargs):
        """Propagate messages and update nodes using iterator.

        A convenient function for passing messages and updating
        nodes according to the iterator. The iterator can be
        any of the pre-defined iterators ('bfs', 'dfs', 'pre-order',
        'mid-order', 'post-order'). The computation will be unrolled
        in the backend efficiently. User can also provide custom
        iterator that generates the edges and nodes.

        Parameters
        ----------
        iterator : str or generator of steps.
            The iterator of the graph. Each step of a custom iterator is a
            ``(u, v)`` pair that is passed to ``update_by_edge``.
        kwargs : keyword arguments, optional
            Arguments for pre-defined iterators.

        Raises
        ------
        NotImplementedError
            If ``iterator`` is a string; the pre-defined iterators are not
            implemented yet. NotImplementedError is a RuntimeError, so
            callers catching RuntimeError keep working.
        """
        if isinstance(iterator, str):
            # TODO Call pre-defined routine to unroll the computation.
            raise NotImplementedError('Not implemented.')
        else:
            # NOTE: the iteration can return multiple edges at each step.
            for u, v in iterator:
                self.update_by_edge(u, v)

    def draw(self):
        """Plot the graph using dot."""
        # Imported lazily so the graphviz dependency is only needed here.
        from networkx.drawing.nx_agraph import graphviz_layout

        pos = graphviz_layout(self, prog='dot')
        nx.draw(self, pos, with_labels=True)

    def _nodes_or_all(self, nodes='all'):
        """Return all nodes when ``nodes`` is 'all', else ``nodes`` itself."""
        # The isinstance guard avoids evaluating `==` on tensors/arrays.
        if isinstance(nodes, str) and nodes == 'all':
            return self.nodes()
        return nodes

    def _edges_or_all(self, edges='all'):
        """Return all edges when ``edges`` is 'all', else ``edges`` itself."""
        # The isinstance guard avoids evaluating `==` on tensors/arrays.
        if isinstance(edges, str) and edges == 'all':
            return self.edges()
        return edges