Unverified commit 75a43e4a authored by Zheng Zhang, committed by GitHub

Merge pull request #7 from zzhang-cn/tensor

Tensor
parents 447d16bd fa900595
import networkx as nx
import torch as T
import torch.nn as NN
from util import *
class DiGraph(NN.Module):
'''
Reserved attributes:
    * state: node state vectors during message-passing iterations.
      Edges do not have state vectors; an edge's "state" field is reserved
      for storing the message passed along that edge
* tag: node-/edge-specific feature tensors or other data
'''
def __init__(self, graph):
NN.Module.__init__(self)
self.G = graph
self.message_funcs = []
self.update_funcs = []
def _nodes_or_all(self, nodes='all'):
return self.G.nodes() if nodes == 'all' else nodes
def _edges_or_all(self, edges='all'):
return self.G.edges() if edges == 'all' else edges
def _node_tag_name(self, v):
return '(%s)' % v
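    # Note (added comment): the edge name sorts its endpoints, so (u, v) and
    # (v, u) map to the same registered parameter name.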
def _edge_tag_name(self, u, v):
return '(%s, %s)' % (min(u, v), max(u, v))
def zero_node_state(self, state_dims, batch_size=None, nodes='all'):
shape = (
[batch_size] + list(state_dims)
if batch_size is not None
else state_dims
)
nodes = self._nodes_or_all(nodes)
for v in nodes:
self.G.node[v]['state'] = tovar(T.zeros(shape))
def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
nodes = self._nodes_or_all(nodes)
for v in nodes:
self.G.node[v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._node_tag_name(v), self.G.node[v]['tag'])
def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
edges = self._edges_or_all(edges)
for u, v in edges:
self.G[u][v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._edge_tag_name(u, v), self.G[u][v]['tag'])
def remove_node_tag(self, nodes='all'):
nodes = self._nodes_or_all(nodes)
for v in nodes:
delattr(self, self._node_tag_name(v))
del self.G.node[v]['tag']
def remove_edge_tag(self, edges='all'):
edges = self._edges_or_all(edges)
for u, v in edges:
delattr(self, self._edge_tag_name(u, v))
del self.G[u][v]['tag']
@property
def node(self):
return self.G.node
@property
def edges(self):
return self.G.edges
def edge_tags(self):
for u, v in self.G.edges():
yield self.G[u][v]['tag']
def node_tags(self):
for v in self.G.nodes():
yield self.G.node[v]['tag']
def states(self):
for v in self.G.nodes():
yield self.G.node[v]['state']
def named_edge_tags(self):
for u, v in self.G.edges():
yield ((u, v), self.G[u][v]['tag'])
def named_node_tags(self):
for v in self.G.nodes():
yield (v, self.G.node[v]['tag'])
def named_states(self):
for v in self.G.nodes():
yield (v, self.G.node[v]['state'])
def register_message_func(self, message_func, edges='all', batched=False):
'''
        batched: whether to do a single batched computation over all of the
        given edges instead of iterating edge by edge.
        message function: accepts the source node's state tensor and the
        edge's tag tensor (None in the batched case when no edge has a tag),
        and returns a message tensor
'''
self.message_funcs.append((self._edges_or_all(edges), message_func, batched))
def register_update_func(self, update_func, nodes='all', batched=False):
'''
        batched: whether to do a single batched computation instead of
        iterating (stored but not yet honored; see the TODO in step()).
        update function: accepts a node attribute dictionary (including
        state and tag) and a list of (source node, target node, edge
        attribute dictionary) tuples for the node's in-edges, and returns
        the node's new state
'''
self.update_funcs.append((self._nodes_or_all(nodes), update_func, batched))
def draw(self):
from networkx.drawing.nx_agraph import graphviz_layout
pos = graphviz_layout(self.G, prog='dot')
nx.draw(self.G, pos, with_labels=True)
def step(self):
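        # A step is synchronous: all message functions fire first, writing
        # messages onto edge 'state' fields; only then do the update
        # functions consume each node's in-edge messages.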
# update message
for ebunch, f, batched in self.message_funcs:
if batched:
# FIXME: need to optimize since we are repeatedly stacking and
# unpacking
source = T.stack([self.G.node[u]['state'] for u, _ in ebunch])
edge_tags = [self.G[u][v].get('tag', None) for u, v in ebunch]
if all(t is None for t in edge_tags):
edge_tag = None
else:
edge_tag = T.stack([self.G[u][v]['tag'] for u, v in ebunch])
message = f(source, edge_tag)
for i, (u, v) in enumerate(ebunch):
self.G[u][v]['state'] = message[i]
else:
for u, v in ebunch:
self.G[u][v]['state'] = f(
self.G.node[u]['state'],
self.G[u][v]['tag']
)
# update state
# TODO: does it make sense to batch update the nodes?
for vbunch, f, batched in self.update_funcs:
for v in vbunch:
self.G.node[v]['state'] = f(self.G.node[v], list(self.G.in_edges(v, data=True)))
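A minimal usage sketch of the DiGraph wrapper above (illustrative, not part of the commit; the demo name is hypothetical). It reuses this module's imports and assumes util.tovar, called inside zero_node_state, wraps a tensor into an autograd Variable. One synchronous step sends a linear message along a single edge and lets the receiver accumulate it:

def _demo_digraph_step():
    g = DiGraph(nx.DiGraph([('a', 'b')]))
    g.zero_node_state((4,), batch_size=1)        # every node starts at zeros
    lin = NN.Linear(4, 4)
    # batched=True tolerates missing edge tags (they arrive as None)
    g.register_message_func(lambda src, tag: lin(src), edges=[('a', 'b')], batched=True)
    g.register_update_func(
        lambda attr, in_edges: attr['state'] + sum(e[2]['state'] for e in in_edges),
        nodes=['b'])
    g.step()                                     # 'b' absorbs the message from 'a'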
import torch as T
import torch.nn as NN
import torch.nn.init as INIT
import torch.nn.functional as F
import numpy as NP
import numpy.random as RNG
from util import *
from glimpse import create_glimpse
from zoneout import ZoneoutLSTMCell
from collections import namedtuple
import os
from graph import DiGraph
import networkx as nx
no_msg = os.getenv('NOMSG', False)
def build_cnn(**config):
cnn_list = []
filters = config['filters']
kernel_size = config['kernel_size']
in_channels = config.get('in_channels', 3)
final_pool_size = config['final_pool_size']
for i in range(len(filters)):
module = NN.Conv2d(
in_channels if i == 0 else filters[i-1],
filters[i],
kernel_size,
padding=tuple((_ - 1) // 2 for _ in kernel_size),
)
INIT.xavier_uniform_(module.weight)
INIT.constant_(module.bias, 0)
cnn_list.append(module)
if i < len(filters) - 1:
cnn_list.append(NN.LeakyReLU())
cnn_list.append(NN.AdaptiveMaxPool2d(final_pool_size))
return NN.Sequential(*cnn_list)
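# Hypothetical shape check for build_cnn, not part of the commit. Whatever
# the input resolution, the trailing AdaptiveMaxPool2d pins the spatial size
# to final_pool_size, so the flattened feature size downstream is predictable.
def _demo_build_cnn():
    cnn = build_cnn(filters=[16, 32], kernel_size=(3, 3), final_pool_size=(2, 2))
    out = cnn(T.zeros(8, 3, 64, 64))             # -> (8, 32, 2, 2)
    assert tuple(out.shape) == (8, 32, 2, 2)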
class TreeGlimpsedClassifier(NN.Module):
def __init__(self,
n_children=2,
n_depth=3,
h_dims=128,
node_tag_dims=128,
edge_tag_dims=128,
n_classes=10,
steps=5,
filters=[16, 32, 64, 128, 256],
kernel_size=(3, 3),
final_pool_size=(2, 2),
glimpse_type='gaussian',
glimpse_size=(15, 15),
):
'''
Basic idea:
* We detect objects through an undirected graphical model.
* The graphical model consists of a balanced tree of latent variables h
* Each h is then connected to a bbox variable b and a class variable y
* b of the root is fixed to cover the entire canvas
* All other h, b and y are updated through message passing
        * The loss function (not implemented yet) should be either
          * multiset loss, or
          * maximum bipartite matching (as in the Order Matters paper)
'''
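        # Illustrative sizes (added comment): with n_children=2 and n_depth=2
        # the latent tree nx.balanced_tree(2, 2) has 7 h-nodes; each h_i is
        # paired with a box node b_i and a class node y_i, 21 nodes in total.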
NN.Module.__init__(self)
self.n_children = n_children
self.n_depth = n_depth
self.h_dims = h_dims
self.node_tag_dims = node_tag_dims
self.edge_tag_dims = edge_tag_dims
self.n_classes = n_classes
self.glimpse = create_glimpse(glimpse_type, glimpse_size)
self.steps = steps
self.cnn = build_cnn(
filters=filters,
kernel_size=kernel_size,
final_pool_size=final_pool_size,
)
# Create graph of latent variables
G = nx.balanced_tree(self.n_children, self.n_depth)
nx.relabel_nodes(G,
{i: 'h%d' % i for i in range(len(G.nodes()))},
False
)
self.h_nodes_list = h_nodes_list = list(G.nodes)
for h in h_nodes_list:
G.node[h]['type'] = 'h'
b_nodes_list = ['b%d' % i for i in range(len(h_nodes_list))]
y_nodes_list = ['y%d' % i for i in range(len(h_nodes_list))]
self.b_nodes_list = b_nodes_list
self.y_nodes_list = y_nodes_list
hy_edge_list = [(h, y) for h, y in zip(h_nodes_list, y_nodes_list)]
hb_edge_list = [(h, b) for h, b in zip(h_nodes_list, b_nodes_list)]
yh_edge_list = [(y, h) for y, h in zip(y_nodes_list, h_nodes_list)]
bh_edge_list = [(b, h) for b, h in zip(b_nodes_list, h_nodes_list)]
G.add_nodes_from(b_nodes_list, type='b')
G.add_nodes_from(y_nodes_list, type='y')
G.add_edges_from(hy_edge_list)
G.add_edges_from(hb_edge_list)
self.G = DiGraph(nx.DiGraph(G))
hh_edge_list = [(u, v)
for u, v in self.G.edges()
if self.G.node[u]['type'] == self.G.node[v]['type'] == 'h']
self.G.init_node_tag_with(node_tag_dims, T.nn.init.uniform_, args=(-.01, .01))
self.G.init_edge_tag_with(
edge_tag_dims,
T.nn.init.uniform_,
args=(-.01, .01),
edges=hy_edge_list + hb_edge_list + bh_edge_list
)
self.G.init_edge_tag_with(
h_dims * n_classes,
T.nn.init.uniform_,
args=(-.01, .01),
edges=yh_edge_list
)
# y -> h. An attention over embeddings dynamically generated through edge tags
self.G.register_message_func(self._y_to_h, edges=yh_edge_list, batched=True)
# b -> h. Projects b and edge tag to the same dimension, then concatenates and projects to h
self.bh_1 = NN.Linear(self.glimpse.att_params, h_dims)
self.bh_2 = NN.Linear(edge_tag_dims, h_dims)
self.bh_all = NN.Linear(2 * h_dims + filters[-1] * NP.prod(final_pool_size), h_dims)
self.G.register_message_func(self._b_to_h, edges=bh_edge_list, batched=True)
# h -> h. Just passes h itself
self.G.register_message_func(self._h_to_h, edges=hh_edge_list, batched=True)
        # h -> b. Concatenates h with the edge tag and passes the result
        # through a linear layer, producing Δb
self.hb = NN.Linear(h_dims + edge_tag_dims, self.glimpse.att_params)
self.G.register_message_func(self._h_to_b, edges=hb_edge_list, batched=True)
        # h -> y. Concatenates h with the edge tag and passes the result
        # through a linear layer, producing Δy
self.hy = NN.Linear(h_dims + edge_tag_dims, self.n_classes)
self.G.register_message_func(self._h_to_y, edges=hy_edge_list, batched=True)
        # b update: adds Δb to the current b
self.G.register_update_func(self._update_b, nodes=b_nodes_list, batched=False)
        # y update: likewise adds Δy to the current y
self.G.register_update_func(self._update_y, nodes=y_nodes_list, batched=False)
        # h update: adds the mean of the incoming messages to h, then applies a ReLU
self.G.register_update_func(self._update_h, nodes=h_nodes_list, batched=False)
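        # One G.step() thus (1) fires every registered message function
        # (y->h, b->h, h->h, h->b, h->y) and then (2) applies the three
        # update rules above, refining b, y and h jointly at each iteration.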
def _y_to_h(self, source, edge_tag):
'''
        source: (n_yh_edges, batch_size, n_classes) logits
edge_tag: (n_yh_edges, edge_tag_dims)
'''
n_yh_edges, batch_size, _ = source.shape
w = edge_tag.reshape(n_yh_edges, 1, self.n_classes, self.h_dims)
w = w.expand(n_yh_edges, batch_size, self.n_classes, self.h_dims)
source = source[:, :, None, :]
        # softmax over the class dimension (dim must be given explicitly)
        return (F.softmax(source, dim=-1) @ w).reshape(n_yh_edges, batch_size, self.h_dims)
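    # Hypothetical standalone check, not in the commit, of the shape algebra
    # used in _y_to_h: class logits attend over per-edge weight matrices
    # generated from the edge tag.
    @staticmethod
    def _demo_y_to_h_shapes(n_edges=3, batch=2, n_classes=10, h_dims=128):
        src = T.zeros(n_edges, batch, n_classes)
        tag = T.zeros(n_edges, n_classes * h_dims)
        w = tag.reshape(n_edges, 1, n_classes, h_dims).expand(n_edges, batch, n_classes, h_dims)
        msg = (F.softmax(src[:, :, None, :], dim=-1) @ w).reshape(n_edges, batch, h_dims)
        assert tuple(msg.shape) == (n_edges, batch, h_dims)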
def _b_to_h(self, source, edge_tag):
'''
        source: (n_bh_edges, batch_size, att_params) bounding-box parameters
edge_tag: (n_bh_edges, edge_tag_dims)
'''
n_bh_edges, batch_size, _ = source.shape
# FIXME: really using self.x is a bad design here
_, nchan, nrows, ncols = self.x.size()
source, _ = self.glimpse.rescale(source, False)
_source = source.reshape(-1, self.glimpse.att_params)
m_b = T.relu(self.bh_1(_source))
m_t = T.relu(self.bh_2(edge_tag))
m_t = m_t[:, None, :].expand(n_bh_edges, batch_size, self.h_dims)
m_t = m_t.reshape(-1, self.h_dims)
        # glimpse expects batch dimension first and glimpse dimension second;
        # here source is laid out (n_bh_edges, batch_size, ...), i.e.
        # glimpse-first, so we transpose the two leading dimensions
g = self.glimpse(self.x, source.transpose(0, 1)).transpose(0, 1)
grows, gcols = g.size()[-2:]
g = g.reshape(n_bh_edges * batch_size, nchan, grows, gcols)
phi = self.cnn(g).reshape(n_bh_edges * batch_size, -1)
# TODO: add an attribute (g) to h
m = self.bh_all(T.cat([m_b, m_t, phi], 1))
m = m.reshape(n_bh_edges, batch_size, self.h_dims)
return m
def _h_to_h(self, source, edge_tag):
return source
def _h_to_b(self, source, edge_tag):
n_hb_edges, batch_size, _ = source.shape
edge_tag = edge_tag[:, None]
edge_tag = edge_tag.expand(n_hb_edges, batch_size, self.edge_tag_dims)
I = T.cat([source, edge_tag], -1).reshape(n_hb_edges * batch_size, -1)
db = self.hb(I)
return db.reshape(n_hb_edges, batch_size, -1)
def _h_to_y(self, source, edge_tag):
n_hy_edges, batch_size, _ = source.shape
edge_tag = edge_tag[:, None]
edge_tag = edge_tag.expand(n_hy_edges, batch_size, self.edge_tag_dims)
I = T.cat([source, edge_tag], -1).reshape(n_hy_edges * batch_size, -1)
dy = self.hy(I)
return dy.reshape(n_hy_edges, batch_size, -1)
def _update_b(self, b, b_n):
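        # b_n is [(src, dst, edge_attrs)]; each b node has exactly one
        # in-edge (from its h), so b_n[0][2]['state'] is that message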
return b['state'] + b_n[0][2]['state']
def _update_y(self, y, y_n):
return y['state'] + y_n[0][2]['state']
def _update_h(self, h, h_n):
m = T.stack([e[2]['state'] for e in h_n]).mean(0)
return T.relu(h['state'] + m)
def forward(self, x, y=None):
self.x = x
batch_size = x.shape[0]
self.G.zero_node_state((self.h_dims,), batch_size, nodes=self.h_nodes_list)
self.G.zero_node_state((self.n_classes,), batch_size, nodes=self.y_nodes_list)
self.G.zero_node_state((self.glimpse.att_params,), batch_size, nodes=self.b_nodes_list)
for t in range(self.steps):
self.G.step()
# We don't change b of the root
self.G.node['b0']['state'].zero_()
self.y_pre = T.stack(
[self.G.node['y%d' % i]['state'] for i in range(self.n_nodes - 1, self.n_nodes - self.n_leaves - 1, -1)],
1
)
self.v_B = T.stack(
[self.glimpse.rescale(self.G.node['b%d' % i]['state'], False)[0] for i in range(self.n_nodes)],
1,
)
        self.y_logprob = F.log_softmax(self.y_pre, dim=-1)
return self.G.node['h0']['state']
    @property
    def n_nodes(self):
        # number of nodes in a complete n_children-ary tree of height n_depth,
        # i.e. len(nx.balanced_tree(n_children, n_depth).nodes())
        return (self.n_children ** (self.n_depth + 1) - 1) // (self.n_children - 1)
    @property
    def n_leaves(self):
        # the leaves are the n_children ** n_depth nodes at the deepest level
        return self.n_children ** self.n_depth
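A hypothetical sanity check (not in the commit) that the tree-size formulas above agree with networkx: for n_children=2 and n_depth=3, nx.balanced_tree(2, 3) has (2**4 - 1) // (2 - 1) = 15 nodes, of which 2**3 = 8 are leaves.

def _demo_tree_sizes():
    G = nx.balanced_tree(2, 3)
    assert len(G.nodes()) == 15
    assert sum(1 for v in G if G.degree(v) == 1) == 8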
 import networkx as nx
-from networkx.classes.graph import Graph
+#from networkx.classes.graph import Graph
+from networkx.classes.digraph import DiGraph
+import torch as th
+#import torch.nn.functional as F
+import torch.nn as nn
+from torch.autograd import Variable as Var
 # TODO: make representation numpy/tensor from pytorch
 # TODO: make message/update functions pytorch functions
 # TODO: loss functions and training
-class mx_Graph(Graph):
+class mx_Graph(DiGraph):
     def __init__(self, *args, **kargs):
         super(mx_Graph, self).__init__(*args, **kargs)
         self.set_msg_func()
+        self.set_gather_func()
+        self.set_reduction_func()
         self.set_update_func()
         self.set_readout_func()
         self.init_reprs()
@@ -26,47 +32,58 @@ class mx_Graph(Graph):
         assert u in self.nodes
         return self.nodes[u][name]
 
+    def set_reduction_func(self):
+        def _default_reduction_func(x_s):
+            out = th.stack(x_s)
+            out = th.sum(out, dim=0)
+            return out
+        self._reduction_func = _default_reduction_func
+
+    def set_gather_func(self, u=None):
+        pass
+
     def set_msg_func(self, func=None, u=None):
         """Function that gathers messages from neighbors"""
         def _default_msg_func(u):
             assert u in self.nodes
-            msg_gathered = 0
-            for v in self.adj[u]:
+            msg_gathered = []
+            for v in self.pred[u]:
                 x = self.get_repr(v)
                 if x is not None:
-                    msg_gathered += x
-            return msg_gathered
+                    msg_gathered.append(x)
+            return self._reduction_func(msg_gathered)
         # TODO: per node message function
         # TODO: 'sum' should be a separate function
         if func == None:
-            self.msg_func = _default_msg_func
+            self._msg_func = _default_msg_func
         else:
-            self.msg_func = func
+            self._msg_func = func
 
     def set_update_func(self, func=None, u=None):
         """
         Update function upon receiving an aggregate
         message from a node's neighbor
         """
-        def _default_update_func(u, m):
-            h_new = self.nodes[u]['h'] + m
-            self.set_repr(u, h_new)
+        def _default_update_func(x, m):
+            return x + m
         # TODO: per node update function
         if func == None:
-            self.update_func = _default_update_func
+            self._update_func = _default_update_func
         else:
-            self.update_func = func
+            self._update_func = func
 
     def set_readout_func(self, func=None):
         """Readout function of the whole graph"""
         def _default_readout_func():
-            readout = 0
-            for n in self.nodes:
-                readout += self.nodes[n]['h']
-            return readout
+            valid_hs = []
+            for x in self.nodes:
+                h = self.get_repr(x)
+                if h is not None:
+                    valid_hs.append(h)
+            return self._reduction_func(valid_hs)
+        #
         if func == None:
             self.readout_func = _default_readout_func
         else:
@@ -78,15 +95,21 @@ class mx_Graph(Graph):
     def update_to(self, u):
         """Pull messages from 1-step away neighbors of u"""
         assert u in self.nodes
-        m = self.msg_func(u=u)
-        self.update_func(u, m)
+        m = self._msg_func(u=u)
+        x = self.get_repr(u)
+        # TODO: ugly hack
+        if x is None:
+            y = self._update_func(m)
+        else:
+            y = self._update_func(x, m)
+        self.set_repr(u, y)
 
     def update_from(self, u):
         """Update u's 1-step away neighbors"""
         assert u in self.nodes
         # TODO: this asks v to pull from nodes other than
         # TODO: u, is this a good thing?
-        for v in self.adj[u]:
+        for v in self.succ[u]:
             self.update_to(v)
 
     def print_all(self):
@@ -95,25 +118,39 @@ class mx_Graph(Graph):
         print()
 
 if __name__ == '__main__':
+    th.random.manual_seed(0)
+    ''': this makes a digraph with double edges
     tg = nx.path_graph(10)
     g = mx_Graph(tg)
     g.print_all()
 
-    tr = nx.balanced_tree(2, 3)
+    # this makes a uni-edge tree
+    tr = nx.bfs_tree(nx.balanced_tree(2, 3), 0)
     m_tr = mx_Graph(tr)
     m_tr.print_all()
+    '''
 
+    print("testing GRU update")
     g = mx_Graph(nx.path_graph(3))
+    g.set_update_func(nn.GRUCell(4, 4))
     for n in g:
-        g.set_repr(n, int(n) + 10)
-    g.print_all()
-    print(g.readout())
+        g.set_repr(n, Var(th.rand(2, 4)))
 
-    print("before update:\t", g.nodes[0])
-    g.update_to(0)
-    print('after update:\t', g.nodes[0])
-    g.print_all()
-    print(g.readout())
+    print("\t**before:"); g.print_all()
+    g.update_from(0)
+    g.update_from(1)
+    print("\t**after:"); g.print_all()
+
+    print("\ntesting fwd update")
+    g.clear()
+    g.add_path([0, 1, 2])
+    g.init_reprs()
+    fwd_net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
+    g.set_update_func(fwd_net)
+    g.set_repr(0, Var(th.rand(2, 4)))
+    print("\t**before:"); g.print_all()
+    g.update_from(0)
+    g.update_from(1)
+    print("\t**after:"); g.print_all()
import networkx as nx
from networkx.classes.graph import Graph
from networkx.classes.digraph import DiGraph
# TODO: make representation numpy/tensor from pytorch
# TODO: make message/update functions pytorch functions
# TODO: loss functions and training
class mx_Graph(DiGraph):
def __init__(self, *args, **kargs):
super(mx_Graph, self).__init__(*args, **kargs)
self.set_msg_func()
self.set_update_func()
self.set_readout_func()
self.init_reprs()
def init_reprs(self, h_init=None):
for n in self.nodes:
self.set_repr(n, h_init)
def set_repr(self, u, h_u, name='h'):
assert u in self.nodes
kwarg = {name: h_u}
self.add_node(u, **kwarg)
def get_repr(self, u, name='h'):
assert u in self.nodes
return self.nodes[u][name]
def set_msg_func(self, func=None, u=None):
"""Function that gathers messages from neighbors"""
def _default_msg_func(u):
assert u in self.nodes
msg_gathered = 0
for v in self.adj[u]:
x = self.get_repr(v)
if x is not None:
msg_gathered += x
return msg_gathered
# TODO: per node message function
# TODO: 'sum' should be a separate function
        if func is None:
self.msg_func = _default_msg_func
else:
self.msg_func = func
def set_update_func(self, func=None, u=None):
"""
Update function upon receiving an aggregate
message from a node's neighbor
"""
def _default_update_func(u, m):
if self.nodes[u]['h'] is None:
h_new = m
else:
h_new = self.nodes[u]['h'] + m
self.set_repr(u, h_new)
# TODO: per node update function
        if func is None:
self.update_func = _default_update_func
else:
self.update_func = func
def set_readout_func(self, func=None):
"""Readout function of the whole graph"""
def _default_readout_func():
readout = 0
for n in self.nodes:
readout += self.nodes[n]['h']
return readout
        if func is None:
self.readout_func = _default_readout_func
else:
self.readout_func = func
def readout(self):
return self.readout_func()
def update_to(self, u):
"""Pull messages from 1-step away neighbors of u"""
assert u in self.nodes
m = self.msg_func(u=u)
self.update_func(u, m)
def update_from(self, u):
"""Update u's 1-step away neighbors"""
assert u in self.nodes
# TODO: this asks v to pull from nodes other than
# TODO: u, is this a good thing?
for v in self.adj[u]:
self.update_to(v)
def print_all(self):
for n in self.nodes:
print(n, self.nodes[n])
print()
if __name__ == '__main__':
tg = nx.path_graph(10)
g = mx_Graph(tg)
g.print_all()
tr = nx.balanced_tree(2, 3)
m_tr = mx_Graph(tr)
m_tr.print_all()
g = mx_Graph(nx.path_graph(3))
for n in g:
g.set_repr(n, int(n) + 10)
g.print_all()
print(g.readout())
print("before update:\t", g.nodes[0])
g.update_to(0)
print('after update:\t', g.nodes[0])
g.print_all()
print(g.readout())
g = mx_Graph(nx.bfs_tree(nx.path_graph(3), 0))
g.set_repr(0, 10)
g.print_all()
g.update_from(0)
g.print_all()
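A sketch of plugging a custom readout into this variant (hypothetical, not part of the file): readout() simply calls whatever zero-argument callable was registered, so any aggregation over node representations works.

g = mx_Graph(nx.path_graph(3))
for n in g:
    g.set_repr(n, float(n))
g.set_readout_func(lambda: max(g.get_repr(n) for n in g.nodes))
print(g.readout())   # 2.0, instead of 3.0 from the default sum readout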