Unverified Commit a3ce6e2f authored by Zheng Zhang's avatar Zheng Zhang Committed by GitHub
Browse files

Merge pull request #5 from BarclayII/gq-pytorch

updates to support nx 2.1
parents 572b289e 9c9ac7c9
......@@ -3,37 +3,25 @@ import torch as T
import torch.nn as NN
from util import *
class DiGraph(nx.DiGraph, NN.Module):
class DiGraph(NN.Module):
'''
Reserved attributes:
* state: node state vectors during message passing iterations
edges does not have "state vectors"; the "state" field is reserved for storing messages
* tag: node-/edge-specific feature tensors or other data
'''
def __init__(self, data=None, **attr):
def __init__(self, graph):
NN.Module.__init__(self)
nx.DiGraph.__init__(self, data=data, **attr)
self.G = graph
self.message_funcs = []
self.update_funcs = []
def add_node(self, n, state=None, tag=None, attr_dict=None, **attr):
nx.DiGraph.add_node(self, n, state=state, tag=None, attr_dict=attr_dict, **attr)
def add_nodes_from(self, nodes, state=None, tag=None, **attr):
nx.DiGraph.add_nodes_from(self, nodes, state=state, tag=tag, **attr)
def add_edge(self, u, v, tag=None, attr_dict=None, **attr):
nx.DiGraph.add_edge(self, u, v, tag=tag, attr_dict=attr_dict, **attr)
def add_edges_from(self, ebunch, tag=None, attr_dict=None, **attr):
nx.DiGraph.add_edges_from(self, ebunch, tag=tag, attr_dict=attr_dict, **attr)
def _nodes_or_all(self, nodes='all'):
return self.nodes() if nodes == 'all' else nodes
return self.G.nodes() if nodes == 'all' else nodes
def _edges_or_all(self, edges='all'):
return self.edges() if edges == 'all' else edges
return self.G.edges() if edges == 'all' else edges
def _node_tag_name(self, v):
return '(%s)' % v
......@@ -50,59 +38,67 @@ class DiGraph(nx.DiGraph, NN.Module):
nodes = self._nodes_or_all(nodes)
for v in nodes:
self.node[v]['state'] = tovar(T.zeros(shape))
self.G.node[v]['state'] = tovar(T.zeros(shape))
def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
nodes = self._nodes_or_all(nodes)
for v in nodes:
self.node[v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._node_tag_name(v), self.node[v]['tag'])
self.G.node[v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._node_tag_name(v), self.G.node[v]['tag'])
def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
edges = self._edges_or_all(edges)
for u, v in edges:
self[u][v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._edge_tag_name(u, v), self[u][v]['tag'])
self.G[u][v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._edge_tag_name(u, v), self.G[u][v]['tag'])
def remove_node_tag(self, nodes='all'):
nodes = self._nodes_or_all(nodes)
for v in nodes:
delattr(self, self._node_tag_name(v))
del self.node[v]['tag']
del self.G.node[v]['tag']
def remove_edge_tag(self, edges='all'):
edges = self._edges_or_all(edges)
for u, v in edges:
delattr(self, self._edge_tag_name(u, v))
del self[u][v]['tag']
del self.G[u][v]['tag']
@property
def node(self):
return self.G.node
@property
def edges(self):
return self.G.edges
def edge_tags(self):
for u, v in self.edges():
yield self[u][v]['tag']
for u, v in self.G.edges():
yield self.G[u][v]['tag']
def node_tags(self):
for v in self.nodes():
yield self.node[v]['tag']
for v in self.G.nodes():
yield self.G.node[v]['tag']
def states(self):
for v in self.nodes():
yield self.node[v]['state']
for v in self.G.nodes():
yield self.G.node[v]['state']
def named_edge_tags(self):
for u, v in self.edges():
yield ((u, v), self[u][v]['tag'])
for u, v in self.G.edges():
yield ((u, v), self.G[u][v]['tag'])
def named_node_tags(self):
for v in self.nodes():
yield (v, self.node[v]['tag'])
for v in self.G.nodes():
yield (v, self.G.node[v]['tag'])
def named_states(self):
for v in self.nodes():
yield (v, self.node[v]['state'])
for v in self.G.nodes():
yield (v, self.G.node[v]['state'])
def register_message_func(self, message_func, edges='all', batched=False):
'''
......@@ -126,24 +122,24 @@ class DiGraph(nx.DiGraph, NN.Module):
if batched:
# FIXME: need to optimize since we are repeatedly stacking and
# unpacking
source = T.stack([self.node[u]['state'] for u, _ in ebunch])
edge_tags = [self[u][v]['tag'] for u, v in ebunch]
source = T.stack([self.G.node[u]['state'] for u, _ in ebunch])
edge_tags = [self.G[u][v].get('tag', None) for u, v in ebunch]
if all(t is None for t in edge_tags):
edge_tag = None
else:
edge_tag = T.stack([self[u][v]['tag'] for u, v in ebunch])
edge_tag = T.stack([self.G[u][v]['tag'] for u, v in ebunch])
message = f(source, edge_tag)
for i, (u, v) in enumerate(ebunch):
self[u][v]['state'] = message[i]
self.G[u][v]['state'] = message[i]
else:
for u, v in ebunch:
self[u][v]['state'] = f(
self.node[u]['state'],
self[u][v]['tag']
self.G[u][v]['state'] = f(
self.G.node[u]['state'],
self.G[u][v]['tag']
)
# update state
# TODO: does it make sense to batch update the nodes?
for vbunch, f, batched in self.update_funcs:
for v in vbunch:
self.node[v]['state'] = f(self.node[v], self.in_edges(v, data=True))
self.G.node[v]['state'] = f(self.G.node[v], list(self.G.in_edges(v, data=True)))
......@@ -105,7 +105,7 @@ class TreeGlimpsedClassifier(NN.Module):
G.add_edges_from(hy_edge_list)
G.add_edges_from(hb_edge_list)
self.G = DiGraph(G)
self.G = DiGraph(nx.DiGraph(G))
hh_edge_list = [(u, v)
for u, v in self.G.edges()
if self.G.node[u]['type'] == self.G.node[v]['type'] == 'h']
......@@ -175,6 +175,7 @@ class TreeGlimpsedClassifier(NN.Module):
n_bh_edges, batch_size, _ = source.shape
# FIXME: really using self.x is a bad design here
_, nchan, nrows, ncols = self.x.size()
source, _ = self.glimpse.rescale(source, False)
_source = source.reshape(-1, self.glimpse.att_params)
m_b = T.relu(self.bh_1(_source))
......@@ -232,23 +233,19 @@ class TreeGlimpsedClassifier(NN.Module):
self.G.zero_node_state((self.h_dims,), batch_size, nodes=self.h_nodes_list)
self.G.zero_node_state((self.n_classes,), batch_size, nodes=self.y_nodes_list)
full = self.glimpse.full().unsqueeze(0).expand(batch_size, self.glimpse.att_params)
for v in self.G.nodes():
if self.G.node[v]['type'] == 'b':
# Initialize bbox variables to cover the entire canvas
self.G.node[v]['state'] = full
self.G.zero_node_state((self.glimpse.att_params,), batch_size, nodes=self.b_nodes_list)
for t in range(self.steps):
self.G.step()
# We don't change b of the root
self.G.node['b0']['state'] = full
self.G.node['b0']['state'].zero_()
self.y_pre = T.stack(
[self.G.node['y%d' % i]['state'] for i in range(self.n_nodes - 1, self.n_nodes - self.n_leaves - 1, -1)],
1
)
self.v_B = T.stack(
[self.G.node['b%d' % i]['state'] for i in range(self.n_nodes)],
[self.glimpse.rescale(self.G.node['b%d' % i]['state'], False)[0] for i in range(self.n_nodes)],
1,
)
self.y_logprob = F.log_softmax(self.y_pre)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment