Unverified commit a3ce6e2f, authored by Zheng Zhang and committed by GitHub
Browse files

Merge pull request #5 from BarclayII/gq-pytorch

Updates to support NetworkX (nx) 2.1
parents 572b289e 9c9ac7c9
...@@ -3,37 +3,25 @@ import torch as T ...@@ -3,37 +3,25 @@ import torch as T
import torch.nn as NN import torch.nn as NN
from util import * from util import *
class DiGraph(nx.DiGraph, NN.Module): class DiGraph(NN.Module):
''' '''
Reserved attributes: Reserved attributes:
* state: node state vectors during message passing iterations * state: node state vectors during message passing iterations
edges does not have "state vectors"; the "state" field is reserved for storing messages edges does not have "state vectors"; the "state" field is reserved for storing messages
* tag: node-/edge-specific feature tensors or other data * tag: node-/edge-specific feature tensors or other data
''' '''
def __init__(self, data=None, **attr): def __init__(self, graph):
NN.Module.__init__(self) NN.Module.__init__(self)
nx.DiGraph.__init__(self, data=data, **attr)
self.G = graph
self.message_funcs = [] self.message_funcs = []
self.update_funcs = [] self.update_funcs = []
def add_node(self, n, state=None, tag=None, attr_dict=None, **attr):
nx.DiGraph.add_node(self, n, state=state, tag=None, attr_dict=attr_dict, **attr)
def add_nodes_from(self, nodes, state=None, tag=None, **attr):
nx.DiGraph.add_nodes_from(self, nodes, state=state, tag=tag, **attr)
def add_edge(self, u, v, tag=None, attr_dict=None, **attr):
nx.DiGraph.add_edge(self, u, v, tag=tag, attr_dict=attr_dict, **attr)
def add_edges_from(self, ebunch, tag=None, attr_dict=None, **attr):
nx.DiGraph.add_edges_from(self, ebunch, tag=tag, attr_dict=attr_dict, **attr)
def _nodes_or_all(self, nodes='all'): def _nodes_or_all(self, nodes='all'):
return self.nodes() if nodes == 'all' else nodes return self.G.nodes() if nodes == 'all' else nodes
def _edges_or_all(self, edges='all'): def _edges_or_all(self, edges='all'):
return self.edges() if edges == 'all' else edges return self.G.edges() if edges == 'all' else edges
def _node_tag_name(self, v): def _node_tag_name(self, v):
return '(%s)' % v return '(%s)' % v
...@@ -50,59 +38,67 @@ class DiGraph(nx.DiGraph, NN.Module): ...@@ -50,59 +38,67 @@ class DiGraph(nx.DiGraph, NN.Module):
nodes = self._nodes_or_all(nodes) nodes = self._nodes_or_all(nodes)
for v in nodes: for v in nodes:
self.node[v]['state'] = tovar(T.zeros(shape)) self.G.node[v]['state'] = tovar(T.zeros(shape))
def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()): def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
nodes = self._nodes_or_all(nodes) nodes = self._nodes_or_all(nodes)
for v in nodes: for v in nodes:
self.node[v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args) self.G.node[v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._node_tag_name(v), self.node[v]['tag']) self.register_parameter(self._node_tag_name(v), self.G.node[v]['tag'])
def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()): def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
edges = self._edges_or_all(edges) edges = self._edges_or_all(edges)
for u, v in edges: for u, v in edges:
self[u][v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args) self.G[u][v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._edge_tag_name(u, v), self[u][v]['tag']) self.register_parameter(self._edge_tag_name(u, v), self.G[u][v]['tag'])
def remove_node_tag(self, nodes='all'): def remove_node_tag(self, nodes='all'):
nodes = self._nodes_or_all(nodes) nodes = self._nodes_or_all(nodes)
for v in nodes: for v in nodes:
delattr(self, self._node_tag_name(v)) delattr(self, self._node_tag_name(v))
del self.node[v]['tag'] del self.G.node[v]['tag']
def remove_edge_tag(self, edges='all'): def remove_edge_tag(self, edges='all'):
edges = self._edges_or_all(edges) edges = self._edges_or_all(edges)
for u, v in edges: for u, v in edges:
delattr(self, self._edge_tag_name(u, v)) delattr(self, self._edge_tag_name(u, v))
del self[u][v]['tag'] del self.G[u][v]['tag']
@property
def node(self):
return self.G.node
@property
def edges(self):
return self.G.edges
def edge_tags(self): def edge_tags(self):
for u, v in self.edges(): for u, v in self.G.edges():
yield self[u][v]['tag'] yield self.G[u][v]['tag']
def node_tags(self): def node_tags(self):
for v in self.nodes(): for v in self.G.nodes():
yield self.node[v]['tag'] yield self.G.node[v]['tag']
def states(self): def states(self):
for v in self.nodes(): for v in self.G.nodes():
yield self.node[v]['state'] yield self.G.node[v]['state']
def named_edge_tags(self): def named_edge_tags(self):
for u, v in self.edges(): for u, v in self.G.edges():
yield ((u, v), self[u][v]['tag']) yield ((u, v), self.G[u][v]['tag'])
def named_node_tags(self): def named_node_tags(self):
for v in self.nodes(): for v in self.G.nodes():
yield (v, self.node[v]['tag']) yield (v, self.G.node[v]['tag'])
def named_states(self): def named_states(self):
for v in self.nodes(): for v in self.G.nodes():
yield (v, self.node[v]['state']) yield (v, self.G.node[v]['state'])
def register_message_func(self, message_func, edges='all', batched=False): def register_message_func(self, message_func, edges='all', batched=False):
''' '''
...@@ -126,24 +122,24 @@ class DiGraph(nx.DiGraph, NN.Module): ...@@ -126,24 +122,24 @@ class DiGraph(nx.DiGraph, NN.Module):
if batched: if batched:
# FIXME: need to optimize since we are repeatedly stacking and # FIXME: need to optimize since we are repeatedly stacking and
# unpacking # unpacking
source = T.stack([self.node[u]['state'] for u, _ in ebunch]) source = T.stack([self.G.node[u]['state'] for u, _ in ebunch])
edge_tags = [self[u][v]['tag'] for u, v in ebunch] edge_tags = [self.G[u][v].get('tag', None) for u, v in ebunch]
if all(t is None for t in edge_tags): if all(t is None for t in edge_tags):
edge_tag = None edge_tag = None
else: else:
edge_tag = T.stack([self[u][v]['tag'] for u, v in ebunch]) edge_tag = T.stack([self.G[u][v]['tag'] for u, v in ebunch])
message = f(source, edge_tag) message = f(source, edge_tag)
for i, (u, v) in enumerate(ebunch): for i, (u, v) in enumerate(ebunch):
self[u][v]['state'] = message[i] self.G[u][v]['state'] = message[i]
else: else:
for u, v in ebunch: for u, v in ebunch:
self[u][v]['state'] = f( self.G[u][v]['state'] = f(
self.node[u]['state'], self.G.node[u]['state'],
self[u][v]['tag'] self.G[u][v]['tag']
) )
# update state # update state
# TODO: does it make sense to batch update the nodes? # TODO: does it make sense to batch update the nodes?
for vbunch, f, batched in self.update_funcs: for vbunch, f, batched in self.update_funcs:
for v in vbunch: for v in vbunch:
self.node[v]['state'] = f(self.node[v], self.in_edges(v, data=True)) self.G.node[v]['state'] = f(self.G.node[v], list(self.G.in_edges(v, data=True)))
...@@ -105,7 +105,7 @@ class TreeGlimpsedClassifier(NN.Module): ...@@ -105,7 +105,7 @@ class TreeGlimpsedClassifier(NN.Module):
G.add_edges_from(hy_edge_list) G.add_edges_from(hy_edge_list)
G.add_edges_from(hb_edge_list) G.add_edges_from(hb_edge_list)
self.G = DiGraph(G) self.G = DiGraph(nx.DiGraph(G))
hh_edge_list = [(u, v) hh_edge_list = [(u, v)
for u, v in self.G.edges() for u, v in self.G.edges()
if self.G.node[u]['type'] == self.G.node[v]['type'] == 'h'] if self.G.node[u]['type'] == self.G.node[v]['type'] == 'h']
...@@ -175,6 +175,7 @@ class TreeGlimpsedClassifier(NN.Module): ...@@ -175,6 +175,7 @@ class TreeGlimpsedClassifier(NN.Module):
n_bh_edges, batch_size, _ = source.shape n_bh_edges, batch_size, _ = source.shape
# FIXME: really using self.x is a bad design here # FIXME: really using self.x is a bad design here
_, nchan, nrows, ncols = self.x.size() _, nchan, nrows, ncols = self.x.size()
source, _ = self.glimpse.rescale(source, False)
_source = source.reshape(-1, self.glimpse.att_params) _source = source.reshape(-1, self.glimpse.att_params)
m_b = T.relu(self.bh_1(_source)) m_b = T.relu(self.bh_1(_source))
...@@ -232,23 +233,19 @@ class TreeGlimpsedClassifier(NN.Module): ...@@ -232,23 +233,19 @@ class TreeGlimpsedClassifier(NN.Module):
self.G.zero_node_state((self.h_dims,), batch_size, nodes=self.h_nodes_list) self.G.zero_node_state((self.h_dims,), batch_size, nodes=self.h_nodes_list)
self.G.zero_node_state((self.n_classes,), batch_size, nodes=self.y_nodes_list) self.G.zero_node_state((self.n_classes,), batch_size, nodes=self.y_nodes_list)
full = self.glimpse.full().unsqueeze(0).expand(batch_size, self.glimpse.att_params) self.G.zero_node_state((self.glimpse.att_params,), batch_size, nodes=self.b_nodes_list)
for v in self.G.nodes():
if self.G.node[v]['type'] == 'b':
# Initialize bbox variables to cover the entire canvas
self.G.node[v]['state'] = full
for t in range(self.steps): for t in range(self.steps):
self.G.step() self.G.step()
# We don't change b of the root # We don't change b of the root
self.G.node['b0']['state'] = full self.G.node['b0']['state'].zero_()
self.y_pre = T.stack( self.y_pre = T.stack(
[self.G.node['y%d' % i]['state'] for i in range(self.n_nodes - 1, self.n_nodes - self.n_leaves - 1, -1)], [self.G.node['y%d' % i]['state'] for i in range(self.n_nodes - 1, self.n_nodes - self.n_leaves - 1, -1)],
1 1
) )
self.v_B = T.stack( self.v_B = T.stack(
[self.G.node['b%d' % i]['state'] for i in range(self.n_nodes)], [self.glimpse.rescale(self.G.node['b%d' % i]['state'], False)[0] for i in range(self.n_nodes)],
1, 1,
) )
self.y_logprob = F.log_softmax(self.y_pre) self.y_logprob = F.log_softmax(self.y_pre)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment