graph.py 5.14 KB
Newer Older
Gan Quan's avatar
Gan Quan committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import networkx as nx
import torch as T
import torch.nn as NN

class DiGraph(nx.DiGraph, NN.Module):
    '''
    A networkx directed graph that is also a torch Module, so node and edge
    tags can be registered as trainable parameters and message passing can be
    driven by `step`.

    Reserved attributes:
    * state: node state vectors during message passing iterations.
        Edges do not have "state vectors"; on an edge the "state" field is
        reserved for storing the message sent along that edge.
    * tag: node-/edge-specific feature tensors or other data
    '''
    def __init__(self, data=None, **attr):
        # Set up the torch Module bookkeeping, then the networkx graph dicts.
        NN.Module.__init__(self)
        nx.DiGraph.__init__(self, data=data, **attr)

        # Lists of (bunch, function, batched) triples filled by
        # register_message_func / register_update_func and consumed by step().
        self.message_funcs = []
        self.update_funcs = []

    def add_node(self, n, state=None, tag=None, attr_dict=None, **attr):
        '''
        Add node `n` with the given `state` and `tag`; `attr_dict` and any
        extra keyword attributes are forwarded to networkx.
        '''
        nx.DiGraph.add_node(self, n, state=state, tag=tag, attr_dict=attr_dict, **attr)

    def add_nodes_from(self, nodes, state=None, tag=None, **attr):
        '''
        Add every node in `nodes`, all sharing the given `state` and `tag`.
        '''
        nx.DiGraph.add_nodes_from(self, nodes, state=state, tag=tag, **attr)

    def add_edge(self, u, v, tag=None, attr_dict=None, **attr):
        '''
        Add the directed edge u -> v with the given `tag`.
        '''
        nx.DiGraph.add_edge(self, u, v, tag=tag, attr_dict=attr_dict, **attr)

    def add_edges_from(self, ebunch, tag=None, attr_dict=None, **attr):
        '''
        Add every edge in `ebunch`, all sharing the given `tag`.
        '''
        nx.DiGraph.add_edges_from(self, ebunch, tag=tag, attr_dict=attr_dict, **attr)

    def _nodes_or_all(self, nodes='all'):
        '''Return all nodes when `nodes` is the string 'all', else `nodes`.'''
        # isinstance guard: avoids elementwise `==` on tensors/arrays
        if isinstance(nodes, str) and nodes == 'all':
            return self.nodes()
        return nodes

    def _edges_or_all(self, edges='all'):
        '''Return all edges when `edges` is the string 'all', else `edges`.'''
        if isinstance(edges, str) and edges == 'all':
            return self.edges()
        return edges

    def _node_tag_name(self, v):
        '''Name under which node `v`'s tag is registered as a parameter.'''
        # (v,) so that tuple-valued nodes are formatted as one value
        return '(%s)' % (v,)

    def _edge_tag_name(self, u, v):
        '''Name under which edge u -> v's tag is registered as a parameter.'''
        # The graph is directed: (u, v) and (v, u) are distinct edges with
        # distinct tags, so the endpoints must keep their order.
        return '(%s, %s)' % (u, v)

    def zero_node_state(self, state_dims, batch_size=None, nodes='all'):
        '''
        Set the state of each selected node to a zero tensor.

        state_dims: shape of one state (an int or a sequence of ints)
        batch_size: if not None, prepend a batch dimension of this size
        nodes: iterable of nodes, or 'all' for every node
        '''
        if batch_size is not None:
            dims = [state_dims] if isinstance(state_dims, int) else list(state_dims)
            shape = [batch_size] + dims
        else:
            shape = state_dims
        nodes = self._nodes_or_all(nodes)

        for v in nodes:
            self.node[v]['state'] = T.zeros(shape)

    def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
        '''
        Create a zero Parameter of `shape`/`dtype` for each selected node,
        pass it through `init_func(param, *args)`, store the result as the
        node's tag and register it as a parameter of this module.

        init_func must return a Parameter (register_parameter rejects other
        values); in-place initializers such as torch.nn.init.uniform_ return
        the tensor they were given.
        '''
        nodes = self._nodes_or_all(nodes)

        for v in nodes:
            self.node[v]['tag'] = init_func(NN.Parameter(T.zeros(shape, dtype=dtype)), *args)
            self.register_parameter(self._node_tag_name(v), self.node[v]['tag'])

    def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
        '''
        Edge counterpart of init_node_tag_with: initialize, store and register
        a Parameter tag for each selected edge (u, v).
        '''
        edges = self._edges_or_all(edges)

        for u, v in edges:
            self[u][v]['tag'] = init_func(NN.Parameter(T.zeros(shape, dtype=dtype)), *args)
            self.register_parameter(self._edge_tag_name(u, v), self[u][v]['tag'])

    def remove_node_tag(self, nodes='all'):
        '''Unregister and delete the tag of each selected node.'''
        nodes = self._nodes_or_all(nodes)

        for v in nodes:
            delattr(self, self._node_tag_name(v))
            del self.node[v]['tag']

    def remove_edge_tag(self, edges='all'):
        '''Unregister and delete the tag of each selected edge.'''
        edges = self._edges_or_all(edges)

        for u, v in edges:
            delattr(self, self._edge_tag_name(u, v))
            del self[u][v]['tag']

    def edge_tags(self):
        '''Yield the tag of every edge.'''
        for u, v in self.edges():
            yield self[u][v]['tag']

    def node_tags(self):
        '''Yield the tag of every node.'''
        for v in self.nodes():
            yield self.node[v]['tag']

    def states(self):
        '''Yield the state of every node.'''
        for v in self.nodes():
            yield self.node[v]['state']

    def named_edge_tags(self):
        '''Yield ((u, v), tag) for every edge.'''
        for u, v in self.edges():
            yield ((u, v), self[u][v]['tag'])

    def named_node_tags(self):
        '''Yield (node, tag) for every node.'''
        for v in self.nodes():
            yield (v, self.node[v]['tag'])

    def named_states(self):
        '''Yield (node, state) for every node.'''
        for v in self.nodes():
            yield (v, self.node[v]['state'])

    def register_message_func(self, message_func, edges='all', batched=False):
        '''
        batched: whether to do a single batched computation instead of iterating
        message function: accepts source state tensor and edge tag tensor, and
        returns a message tensor.  When batched, it receives the stacked source
        states and stacked edge tags and must return one message per edge
        along dimension 0.

        The edge set is resolved (and copied into a list) at registration
        time; edges added afterwards are not included.
        '''
        self.message_funcs.append((list(self._edges_or_all(edges)), message_func, batched))

    def register_update_func(self, update_func, nodes='all', batched=False):
        '''
        batched: whether to do a single batched computation instead of iterating
        (currently ignored: nodes are always updated one at a time)
        update function: accepts a node attribute dictionary (including state
        and tag), and a dictionary mapping each predecessor u to the attribute
        dictionary of the incoming edge (u, v), whose 'state' is the message.

        The node set is resolved (and copied into a list) at registration
        time; nodes added afterwards are not included.
        '''
        self.update_funcs.append((list(self._nodes_or_all(nodes)), update_func, batched))

    def step(self):
        '''
        Run one round of message passing.

        1. For every registered message function and each of its edges
           (u, v), compute a message from u's state and the edge tag and
           store it in the edge's 'state' field.
        2. For every registered update function and each of its nodes v,
           replace v's state with f(v's attribute dict, v's incoming edges).
        '''
        # update message
        for ebunch, f, batched in self.message_funcs:
            if batched:
                if not ebunch:
                    # T.stack cannot stack an empty list
                    continue
                # FIXME: need to optimize since we are repeatedly stacking and
                # unpacking
                source = T.stack([self.node[u]['state'] for u, _ in ebunch])
                edge_tag = T.stack([self[u][v]['tag'] for u, v in ebunch])
                message = f(source, edge_tag)
                # Give each edge its own row of the batched result rather
                # than the whole batch.
                for (u, v), m in zip(ebunch, message):
                    self[u][v]['state'] = m
            else:
                for u, v in ebunch:
                    self[u][v]['state'] = f(
                            self.node[u]['state'],
                            self[u][v]['tag']
                            )

        # update state
        # TODO: does it make sense to batch update the nodes?  The `batched`
        # flag of update functions is currently ignored.
        for vbunch, f, _batched in self.update_funcs:
            for v in vbunch:
                # self.pred[v] maps each predecessor u to the data dict of
                # edge (u, v), i.e. the messages *sent to* v in step 1.
                self.node[v]['state'] = f(self.node[v], self.pred[v])