graph.py 5.47 KB
Newer Older
Gan Quan's avatar
Gan Quan committed
1
2
3
import networkx as nx
import torch as T
import torch.nn as NN
Gan Quan's avatar
Gan Quan committed
4
from util import *
Gan Quan's avatar
Gan Quan committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

class DiGraph(nx.DiGraph, NN.Module):
    '''
    A networkx directed graph that is also a torch ``NN.Module``, used to run
    message passing over per-node / per-edge tensors.

    Reserved attributes:
    * state: node state vectors during message passing iterations.
        Edges do not have "state vectors"; the edge "state" field is reserved
        for storing messages (written by ``step``).
    * tag: node-/edge-specific feature tensors or other data.  Tags created by
        ``init_node_tag_with`` / ``init_edge_tag_with`` are registered as
        module parameters, so they appear in ``self.parameters()``.
    '''
    def __init__(self, data=None, **attr):
        NN.Module.__init__(self)
        nx.DiGraph.__init__(self, data=data, **attr)

        # (edges, message_func, batched) triples, consumed by step().
        self.message_funcs = []
        # (nodes, update_func, batched) triples, consumed by step().
        self.update_funcs = []

    def add_node(self, n, state=None, tag=None, attr_dict=None, **attr):
        '''Add node ``n`` with the reserved ``state`` and ``tag`` attributes.

        ``attr_dict`` and any extra keyword attributes are forwarded to
        ``nx.DiGraph.add_node``.
        '''
        # Forward the caller's tag (it used to be hard-coded to None, silently
        # discarding it), matching add_nodes_from.
        nx.DiGraph.add_node(self, n, state=state, tag=tag, attr_dict=attr_dict, **attr)

    def add_nodes_from(self, nodes, state=None, tag=None, **attr):
        '''Add every node in ``nodes``, all sharing the given ``state``/``tag``.'''
        nx.DiGraph.add_nodes_from(self, nodes, state=state, tag=tag, **attr)

    def add_edge(self, u, v, tag=None, attr_dict=None, **attr):
        '''Add the directed edge ``u -> v`` with the reserved ``tag`` attribute.'''
        nx.DiGraph.add_edge(self, u, v, tag=tag, attr_dict=attr_dict, **attr)

    def add_edges_from(self, ebunch, tag=None, attr_dict=None, **attr):
        '''Add every edge in ``ebunch``, all sharing the given ``tag``.'''
        nx.DiGraph.add_edges_from(self, ebunch, tag=tag, attr_dict=attr_dict, **attr)

    def _nodes_or_all(self, nodes='all'):
        '''Return ``self.nodes()`` when ``nodes`` is the string 'all', else ``nodes``.'''
        # isinstance guard: ``==`` between a tensor/array and a str may not
        # produce a plain bool.
        return self.nodes() if isinstance(nodes, str) and nodes == 'all' else nodes

    def _edges_or_all(self, edges='all'):
        '''Return ``self.edges()`` when ``edges`` is the string 'all', else ``edges``.'''
        return self.edges() if isinstance(edges, str) and edges == 'all' else edges

    def _node_tag_name(self, v):
        '''Module parameter name under which node ``v``'s tag is registered.'''
        return '(%s)' % v

    def _edge_tag_name(self, u, v):
        '''Module parameter name under which edge ``u -> v``'s tag is registered.'''
        # Keep endpoint order: in a directed graph (u, v) and (v, u) are
        # distinct edges with distinct tag tensors.  The former min/max
        # normalization made their names collide, so the second registration
        # replaced the first and remove_edge_tag failed on the second delete.
        # NOTE: parameter names (state_dict keys) change for edges with u > v.
        return '(%s, %s)' % (u, v)

    def zero_node_state(self, state_dims, batch_size=None, nodes='all'):
        '''Set the ``state`` of each node in ``nodes`` to a fresh zero tensor.

        The shape is ``state_dims``, prefixed by ``batch_size`` when one is
        given.  Each node gets its own tensor, wrapped with ``tovar``.
        '''
        shape = (
                [batch_size] + list(state_dims)
                if batch_size is not None
                else state_dims
                )
        nodes = self._nodes_or_all(nodes)

        for v in nodes:
            self.node[v]['state'] = tovar(T.zeros(shape))

    def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
        '''Create and register a trainable tag for each node in ``nodes``.

        A zero tensor of ``shape``/``dtype`` is wrapped with ``tovar`` and
        ``NN.Parameter``, then passed as ``init_func(param, *args)``.  The
        return value becomes the node's tag and is registered as a module
        parameter, so ``init_func`` must return an ``NN.Parameter`` (e.g. an
        in-place initializer that returns its argument).
        '''
        nodes = self._nodes_or_all(nodes)

        for v in nodes:
            self.node[v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
            self.register_parameter(self._node_tag_name(v), self.node[v]['tag'])

    def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
        '''Create and register a trainable tag for each edge in ``edges``.

        Same contract as ``init_node_tag_with``, applied to ``self[u][v]['tag']``.
        '''
        edges = self._edges_or_all(edges)

        for u, v in edges:
            self[u][v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
            self.register_parameter(self._edge_tag_name(u, v), self[u][v]['tag'])

    def remove_node_tag(self, nodes='all'):
        '''Unregister (via ``delattr``) and delete the tag of each node in ``nodes``.'''
        nodes = self._nodes_or_all(nodes)

        for v in nodes:
            delattr(self, self._node_tag_name(v))
            del self.node[v]['tag']

    def remove_edge_tag(self, edges='all'):
        '''Unregister (via ``delattr``) and delete the tag of each edge in ``edges``.'''
        edges = self._edges_or_all(edges)

        for u, v in edges:
            delattr(self, self._edge_tag_name(u, v))
            del self[u][v]['tag']

    def edge_tags(self):
        '''Yield the tag of every edge, in ``self.edges()`` order.'''
        for u, v in self.edges():
            yield self[u][v]['tag']

    def node_tags(self):
        '''Yield the tag of every node, in ``self.nodes()`` order.'''
        for v in self.nodes():
            yield self.node[v]['tag']

    def states(self):
        '''Yield the state of every node, in ``self.nodes()`` order.'''
        for v in self.nodes():
            yield self.node[v]['state']

    def named_edge_tags(self):
        '''Yield ``((u, v), tag)`` for every edge.'''
        for u, v in self.edges():
            yield ((u, v), self[u][v]['tag'])

    def named_node_tags(self):
        '''Yield ``(v, tag)`` for every node.'''
        for v in self.nodes():
            yield (v, self.node[v]['tag'])

    def named_states(self):
        '''Yield ``(v, state)`` for every node.'''
        for v in self.nodes():
            yield (v, self.node[v]['state'])

    def register_message_func(self, message_func, edges='all', batched=False):
        '''
        Register a message function applied by ``step`` to ``edges``.

        message function: accepts a source state tensor and an edge tag tensor,
        and returns a message tensor, which is stored as the edge's ``state``.
        batched: whether to do a single batched computation instead of
        iterating.  When True, the function receives the source states stacked
        along a new first dimension, plus the stacked edge tags (or None when
        no edge has a tag).  Row i of its result is assigned to the i-th edge.

        'all' is resolved to ``self.edges()`` at registration time.  Whether
        edges added later are included depends on whether that call returns a
        snapshot or a live view (networkx-version dependent).
        '''
        self.message_funcs.append((self._edges_or_all(edges), message_func, batched))

    def register_update_func(self, update_func, nodes='all', batched=False):
        '''
        Register an update function applied by ``step`` to ``nodes``.

        update function: accepts a node attribute dictionary (including state
        and tag) and the result of ``self.in_edges(v, data=True)`` (tuples of
        source node, target node, edge attribute dictionary).  Its return
        value becomes the node's new ``state``.
        batched: accepted for symmetry with ``register_message_func`` but
        currently not used by ``step``.
        '''
        self.update_funcs.append((self._nodes_or_all(nodes), update_func, batched))

    def step(self):
        '''
        Run one round of message passing.

        1. Each registered message function writes a message into the
           ``state`` of each of its edges.
        2. Each registered update function then replaces the ``state`` of each
           of its nodes with ``f(node_attr_dict, in_edges_with_data)``.
        '''
        # update message
        for ebunch, f, batched in self.message_funcs:
            if batched:
                if len(ebunch) == 0:
                    # T.stack rejects an empty list; there is nothing to send.
                    continue
                # FIXME: need to optimize since we are repeatedly stacking and
                # unpacking
                source = T.stack([self.node[u]['state'] for u, _ in ebunch])
                edge_tags = [self[u][v]['tag'] for u, v in ebunch]
                # Pass None when no edge carries a tag; otherwise stack the
                # tags collected above rather than reading them a second time.
                if all(t is None for t in edge_tags):
                    edge_tag = None
                else:
                    edge_tag = T.stack(edge_tags)
                message = f(source, edge_tag)
                # Scatter row i of the batched message back onto the i-th edge.
                for i, (u, v) in enumerate(ebunch):
                    self[u][v]['state'] = message[i]
            else:
                for u, v in ebunch:
                    self[u][v]['state'] = f(
                            self.node[u]['state'],
                            self[u][v]['tag']
                            )

        # update state
        # TODO: does it make sense to batch update the nodes?  The ``batched``
        # flag is currently ignored for update functions.
        for vbunch, f, _batched in self.update_funcs:
            for v in vbunch:
                self.node[v]['state'] = f(self.node[v], self.in_edges(v, data=True))