graph.py 5.2 KB
Newer Older
Gan Quan's avatar
Gan Quan committed
1
2
3
import networkx as nx
import torch as T
import torch.nn as NN
Gan Quan's avatar
Gan Quan committed
4
from util import *
Gan Quan's avatar
Gan Quan committed
5

Gan Quan's avatar
Gan Quan committed
6
class DiGraph(NN.Module):
    '''
    Message-passing wrapper around a directed networkx graph.

    Reserved attributes:
    * state: node state vectors during message passing iterations.
        Edges do not have "state vectors"; on edges the "state" field is
        reserved for storing messages.
    * tag: node-/edge-specific feature tensors or other data.
    '''

    def __init__(self, graph):
        '''
        graph: a directed networkx graph (``step`` calls ``graph.in_edges``).
        Its node/edge attribute dicts hold the reserved fields above.
        '''
        NN.Module.__init__(self)

        self.G = graph
        # (bunch, function, batched) triples, applied in registration order by step().
        self.message_funcs = []
        self.update_funcs = []

    @property
    def _node_data(self):
        # Mapping node -> attribute dict. networkx >= 2.0 makes ``G.nodes``
        # subscriptable (and 2.4 removed the ``G.node`` alias); networkx 1.x
        # only offers ``G.node`` because ``G.nodes`` is a plain method there.
        nodes = self.G.nodes
        return nodes if hasattr(nodes, '__getitem__') else self.G.node

    def _nodes_or_all(self, nodes='all'):
        '''Resolve the ``'all'`` sentinel to every node; otherwise return ``nodes`` unchanged.'''
        # isinstance guard: ``tensor == 'all'`` / ``ndarray == 'all'`` is not a plain bool.
        if isinstance(nodes, str) and nodes == 'all':
            return self.G.nodes()
        return nodes

    def _edges_or_all(self, edges='all'):
        '''Resolve the ``'all'`` sentinel to every edge; otherwise return ``edges`` unchanged.'''
        if isinstance(edges, str) and edges == 'all':
            return self.G.edges()
        return edges

    def _node_tag_name(self, v):
        '''Parameter name under which node ``v``'s tag is registered.'''
        return '(%s)' % v

    def _edge_tag_name(self, u, v):
        '''Parameter name under which edge ``(u, v)``'s tag is registered.'''
        # Keep the direction: (u, v) and (v, u) are distinct edges with distinct
        # tag parameters. Sorting the endpoints made their names collide, so the
        # second register_parameter silently replaced the first.
        return '(%s, %s)' % (u, v)

    def zero_node_state(self, state_dims, batch_size=None, nodes='all'):
        '''
        Set the 'state' of each node in ``nodes`` to zeros.

        state_dims: int or sequence of ints giving the per-node state shape.
        batch_size: if not None, prepended as a leading dimension.
        '''
        dims = [state_dims] if isinstance(state_dims, int) else list(state_dims)
        shape = dims if batch_size is None else [batch_size] + dims
        node_data = self._node_data

        for v in self._nodes_or_all(nodes):
            node_data[v]['state'] = tovar(T.zeros(shape))

    def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
        '''
        Create a trainable 'tag' for each node in ``nodes``.

        ``init_func(param, *args)`` receives a zero-filled ``NN.Parameter`` of the
        given shape/dtype. Its return value is stored as the tag and registered on
        this module, so it must be an ``NN.Parameter`` (e.g. an in-place
        ``NN.init`` function that returns its argument).
        '''
        node_data = self._node_data

        for v in self._nodes_or_all(nodes):
            tag = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
            node_data[v]['tag'] = tag
            self.register_parameter(self._node_tag_name(v), tag)

    def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
        '''
        Create a trainable 'tag' for each edge in ``edges``.

        Same contract for ``init_func`` as ``init_node_tag_with``.
        '''
        for u, v in self._edges_or_all(edges):
            tag = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
            self.G[u][v]['tag'] = tag
            self.register_parameter(self._edge_tag_name(u, v), tag)

    def remove_node_tag(self, nodes='all'):
        '''Unregister and delete the 'tag' of each node in ``nodes``.'''
        node_data = self._node_data

        for v in self._nodes_or_all(nodes):
            delattr(self, self._node_tag_name(v))
            del node_data[v]['tag']

    def remove_edge_tag(self, edges='all'):
        '''Unregister and delete the 'tag' of each edge in ``edges``.'''
        for u, v in self._edges_or_all(edges):
            delattr(self, self._edge_tag_name(u, v))
            del self.G[u][v]['tag']

    @property
    def node(self):
        '''Per-node attribute dicts of the underlying graph, indexable by node.'''
        return self._node_data

    @property
    def edges(self):
        '''The underlying graph's ``edges`` accessor.'''
        return self.G.edges

    def edge_tags(self):
        '''Yield each edge's 'tag' (KeyError if an edge has none).'''
        for u, v in self.G.edges():
            yield self.G[u][v]['tag']

    def node_tags(self):
        '''Yield each node's 'tag' (KeyError if a node has none).'''
        node_data = self._node_data
        for v in self.G.nodes():
            yield node_data[v]['tag']

    def states(self):
        '''Yield each node's 'state'.'''
        node_data = self._node_data
        for v in self.G.nodes():
            yield node_data[v]['state']

    def named_edge_tags(self):
        '''Yield ``((u, v), tag)`` for each edge.'''
        for u, v in self.G.edges():
            yield ((u, v), self.G[u][v]['tag'])

    def named_node_tags(self):
        '''Yield ``(node, tag)`` for each node.'''
        node_data = self._node_data
        for v in self.G.nodes():
            yield (v, node_data[v]['tag'])

    def named_states(self):
        '''Yield ``(node, state)`` for each node.'''
        node_data = self._node_data
        for v in self.G.nodes():
            yield (v, node_data[v]['state'])

    def register_message_func(self, message_func, edges='all', batched=False):
        '''
        batched: whether to do a single batched computation instead of iterating
        message function: accepts the source state tensor and the edge tag tensor
        (None when the edge(s) carry no tag), and returns a message tensor.
        When batched, inputs are stacked along a new leading dimension, and
        message[i] is stored on the i-th edge.
        '''
        self.message_funcs.append((self._edges_or_all(edges), message_func, batched))

    def register_update_func(self, update_func, nodes='all', batched=False):
        '''
        batched: currently ignored by step(); update functions are always
        applied node by node.
        update function: accepts a node attribute dictionary (including state and tag),
        and a list of tuples (source node, target node, edge attribute dictionary)
        for the node's incoming edges, and returns the new node state.
        '''
        self.update_funcs.append((self._nodes_or_all(nodes), update_func, batched))

    def draw(self):
        '''Draw the graph with a Graphviz "dot" layout (networkx's nx_agraph, i.e. pygraphviz).'''
        from networkx.drawing.nx_agraph import graphviz_layout

        pos = graphviz_layout(self.G, prog='dot')
        nx.draw(self.G, pos, with_labels=True)

    def step(self):
        '''
        Run one round of message passing.

        First every message function writes a message into the 'state' field
        of each edge in its bunch. Then every update function replaces the
        'state' of each node in its bunch with ``f(node_attrs, in_edges)``.
        '''
        node_data = self._node_data

        # update message
        for ebunch, f, batched in self.message_funcs:
            if batched:
                # FIXME: need to optimize since we are repeatedly stacking and
                # unpacking
                # Materialize once: the bunch is walked several times below.
                ebunch = list(ebunch)
                if not ebunch:
                    # T.stack rejects an empty list, and there is nothing to send.
                    continue
                source = T.stack([node_data[u]['state'] for u, _ in ebunch])
                edge_tags = [self.G[u][v].get('tag', None) for u, v in ebunch]
                if all(t is None for t in edge_tags):
                    edge_tag = None
                else:
                    # Partially tagged bunches raise KeyError here.
                    edge_tag = T.stack([self.G[u][v]['tag'] for u, v in ebunch])
                message = f(source, edge_tag)
                for i, (u, v) in enumerate(ebunch):
                    self.G[u][v]['state'] = message[i]
            else:
                for u, v in ebunch:
                    # Untagged edges pass None, matching the batched path.
                    self.G[u][v]['state'] = f(
                            node_data[u]['state'],
                            self.G[u][v].get('tag', None)
                            )

        # update state
        # TODO: does it make sense to batch update the nodes?
        for vbunch, f, _batched in self.update_funcs:
            for v in vbunch:
                node_data[v]['state'] = f(node_data[v], list(self.G.in_edges(v, data=True)))