"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "034673bbeb00452ed7167df35adbee5d436d3d52"
Unverified commit 572b289e, authored by Zheng Zhang, committed by GitHub

Merge pull request #4 from BarclayII/gq-pytorch

more changes
Parents: 83e84e67, 51391012
 import networkx as nx
 import torch as T
 import torch.nn as NN
+from util import *

 class DiGraph(nx.DiGraph, NN.Module):
     '''
@@ -25,7 +26,7 @@ class DiGraph(nx.DiGraph, NN.Module):
     def add_edge(self, u, v, tag=None, attr_dict=None, **attr):
         nx.DiGraph.add_edge(self, u, v, tag=tag, attr_dict=attr_dict, **attr)

-    def add_edges_from(self, ebunch, tag=tag, attr_dict=None, **attr):
+    def add_edges_from(self, ebunch, tag=None, attr_dict=None, **attr):
         nx.DiGraph.add_edges_from(self, ebunch, tag=tag, attr_dict=attr_dict, **attr)

     def _nodes_or_all(self, nodes='all'):
@@ -49,20 +50,20 @@ class DiGraph(nx.DiGraph, NN.Module):
         nodes = self._nodes_or_all(nodes)
         for v in nodes:
-            self.node[v]['state'] = T.zeros(shape)
+            self.node[v]['state'] = tovar(T.zeros(shape))

     def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
         nodes = self._nodes_or_all(nodes)
         for v in nodes:
-            self.node[v]['tag'] = init_func(NN.Parameter(T.zeros(shape, dtype=dtype)), *args)
+            self.node[v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
             self.register_parameter(self._node_tag_name(v), self.node[v]['tag'])

     def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
         edges = self._edges_or_all(edges)
         for u, v in edges:
-            self[u][v]['tag'] = init_func(NN.Parameter(T.zeros(shape, dtype=dtype)), *args)
+            self[u][v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
             self.register_parameter(self._edge_tag_name(u, v), self[u][v]['tag'])

     def remove_node_tag(self, nodes='all'):
@@ -115,7 +116,7 @@ class DiGraph(nx.DiGraph, NN.Module):
         '''
         batched: whether to do a single batched computation instead of iterating
         update function: accepts a node attribute dictionary (including state and tag),
-        and a dictionary of edge attribute dictionaries
+        and a list of tuples (source node, target node, edge attribute dictionary)
         '''
         self.update_funcs.append((self._nodes_or_all(nodes), update_func, batched))
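Note: a minimal update callback matching this revised contract could look like the sketch below (hypothetical function; only the 'state' key and the (source, target, attrs) tuple layout come from the diff):

import torch as T

def mean_update(node_attrs, in_edges):
    # in_edges arrives as a list of (source, target, edge attr dict) tuples,
    # which is what self.in_edges(v, data=True) yields in step() below.
    m = T.stack([attrs['state'] for _, _, attrs in in_edges]).mean(0)
    # Combine the averaged incoming message with the node's own state.
    return T.relu(node_attrs['state'] + m)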
@@ -126,7 +127,11 @@ class DiGraph(nx.DiGraph, NN.Module):
             # FIXME: need to optimize since we are repeatedly stacking and
             # unpacking
             source = T.stack([self.node[u]['state'] for u, _ in ebunch])
-            edge_tag = T.stack([self[u][v]['tag'] for u, v in ebunch])
+            edge_tags = [self[u][v]['tag'] for u, v in ebunch]
+            if all(t is None for t in edge_tags):
+                edge_tag = None
+            else:
+                edge_tag = T.stack([self[u][v]['tag'] for u, v in ebunch])
             message = f(source, edge_tag)
             for i, (u, v) in enumerate(ebunch):
                 self[u][v]['state'] = message[i]
@@ -139,5 +144,6 @@ class DiGraph(nx.DiGraph, NN.Module):
         # update state
         # TODO: does it make sense to batch update the nodes?
-        for v, f in self.update_funcs:
-            self.node[v]['state'] = f(self.node[v], self[v])
+        for vbunch, f, batched in self.update_funcs:
+            for v in vbunch:
+                self.node[v]['state'] = f(self.node[v], self.in_edges(v, data=True))
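Likewise, a batched message function only needs to honor the f(source, edge_tag) call in step(): source is the stacked source-node states, and edge_tag is either the stacked tags or None under the new all-None branch. A trivial sketch, assuming per-node states of shape (batch, dims):

def copy_message(source, edge_tag):
    # source: (n_edges, batch, dims) stacked source states.
    # edge_tag: stacked edge tags, or None when no edge in the bunch has one.
    # Forward each source state unchanged as the per-edge message.
    return source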
 import torch as T
 import torch.nn as NN
-import networkx as nx
+import torch.nn.init as INIT
+import torch.nn.functional as F
+import numpy as NP
+import numpy.random as RNG
+from util import *
+from glimpse import create_glimpse
+from zoneout import ZoneoutLSTMCell
+from collections import namedtuple
+import os
 from graph import DiGraph
+import networkx as nx
+
+no_msg = os.getenv('NOMSG', False)
+
+def build_cnn(**config):
+    cnn_list = []
+    filters = config['filters']
+    kernel_size = config['kernel_size']
+    in_channels = config.get('in_channels', 3)
+    final_pool_size = config['final_pool_size']
+    for i in range(len(filters)):
+        module = NN.Conv2d(
+            in_channels if i == 0 else filters[i-1],
+            filters[i],
+            kernel_size,
+            padding=tuple((_ - 1) // 2 for _ in kernel_size),
+        )
+        INIT.xavier_uniform(module.weight)
+        INIT.constant(module.bias, 0)
+        cnn_list.append(module)
+        if i < len(filters) - 1:
+            cnn_list.append(NN.LeakyReLU())
+    cnn_list.append(NN.AdaptiveMaxPool2d(final_pool_size))
+    return NN.Sequential(*cnn_list)
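A quick shape check for build_cnn: every convolution is padded to preserve spatial size and the stack ends in an adaptive max pool, so the flattened feature size is filters[-1] * prod(final_pool_size) regardless of input resolution, which is what bh_all assumes below. A sketch with made-up sizes, assuming the build_cnn above is in scope:

import torch as T

cnn = build_cnn(filters=[16, 32], kernel_size=(3, 3), final_pool_size=(2, 2))
x = T.zeros(4, 3, 28, 28)   # a batch of 4 RGB images
print(cnn(x).size())        # torch.Size([4, 32, 2, 2])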
 class TreeGlimpsedClassifier(NN.Module):
     def __init__(self,
@@ -10,9 +46,13 @@ class TreeGlimpsedClassifier(NN.Module):
                  h_dims=128,
                  node_tag_dims=128,
                  edge_tag_dims=128,
-                 h_dims=128,
                  n_classes=10,
                  steps=5,
+                 filters=[16, 32, 64, 128, 256],
+                 kernel_size=(3, 3),
+                 final_pool_size=(2, 2),
+                 glimpse_type='gaussian',
+                 glimpse_size=(15, 15),
                  ):
         '''
         Basic idea:
@@ -33,20 +73,30 @@ class TreeGlimpsedClassifier(NN.Module):
         self.edge_tag_dims = edge_tag_dims
         self.h_dims = h_dims
         self.n_classes = n_classes
+        self.glimpse = create_glimpse(glimpse_type, glimpse_size)
+        self.steps = steps
+        self.cnn = build_cnn(
+            filters=filters,
+            kernel_size=kernel_size,
+            final_pool_size=final_pool_size,
+        )

         # Create graph of latent variables
         G = nx.balanced_tree(self.n_children, self.n_depth)
         nx.relabel_nodes(G,
-                {i: 'h%d' % i for i in range(self.G.nodes())},
+                {i: 'h%d' % i for i in range(len(G.nodes()))},
                 False
                 )
-        h_nodes_list = G.nodes()
+        self.h_nodes_list = h_nodes_list = G.nodes()
         for h in h_nodes_list:
             G.node[h]['type'] = 'h'

         b_nodes_list = ['b%d' % i for i in range(len(h_nodes_list))]
         y_nodes_list = ['y%d' % i for i in range(len(h_nodes_list))]
+        self.b_nodes_list = b_nodes_list
+        self.y_nodes_list = y_nodes_list

         hy_edge_list = [(h, y) for h, y in zip(h_nodes_list, y_nodes_list)]
-        hb_edge_list = [(h, b) for h, b in zip(h_nodes_list, y_nodes_list)]
+        hb_edge_list = [(h, b) for h, b in zip(h_nodes_list, b_nodes_list)]
         yh_edge_list = [(y, h) for y, h in zip(y_nodes_list, h_nodes_list)]
         bh_edge_list = [(b, h) for b, h in zip(b_nodes_list, h_nodes_list)]
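To see the wiring concretely: each h node gets its own b and y partner, and the old code zipped h against y_nodes_list when building hb_edge_list, which this hunk fixes. A plain-networkx illustration with a toy tree (n_children=2, n_depth=1; node names built directly rather than via relabeling):

import networkx as nx

G = nx.balanced_tree(2, 1)                      # nodes 0, 1, 2
h = ['h%d' % i for i in range(len(G.nodes()))]  # ['h0', 'h1', 'h2']
b = ['b%d' % i for i in range(len(h))]
y = ['y%d' % i for i in range(len(h))]
print(list(zip(h, y)))  # hy edges: [('h0', 'y0'), ('h1', 'y1'), ('h2', 'y2')]
print(list(zip(h, b)))  # hb edges: [('h0', 'b0'), ('h1', 'b1'), ('h2', 'b2')]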
@@ -65,21 +115,22 @@ class TreeGlimpsedClassifier(NN.Module):
                 edge_tag_dims,
                 T.nn.init.uniform_,
                 args=(-.01, .01),
-                edges=hy_edge_list + hb_edge_list + yh_edge_list + bh_edge_list
+                edges=hy_edge_list + hb_edge_list + bh_edge_list
                 )
+        self.G.init_edge_tag_with(
+                h_dims * n_classes,
+                T.nn.init.uniform_,
+                args=(-.01, .01),
+                edges=yh_edge_list
+                )

         # y -> h.  An attention over embeddings dynamically generated through edge tags
-        self.yh_emb = NN.Sequential(
-                NN.Linear(edge_tag_dims, h_dims),
-                NN.ReLU(),
-                NN.Linear(h_dims, n_classes * h_dims),
-                )
         self.G.register_message_func(self._y_to_h, edges=yh_edge_list, batched=True)

         # b -> h.  Projects b and edge tag to the same dimension, then concatenates and projects to h
         self.bh_1 = NN.Linear(self.glimpse.att_params, h_dims)
         self.bh_2 = NN.Linear(edge_tag_dims, h_dims)
-        self.bh_all = NN.Linear(3 * h_dims, h_dims)
+        self.bh_all = NN.Linear(2 * h_dims + filters[-1] * NP.prod(final_pool_size), h_dims)
         self.G.register_message_func(self._b_to_h, edges=bh_edge_list, batched=True)

         # h -> h.  Just passes h itself
@@ -87,12 +138,12 @@ class TreeGlimpsedClassifier(NN.Module):
         # h -> b.  Concatenates h with edge tag and go through MLP.
         # Produces Δb
-        self.hb = NN.Linear(hidden_layers + edge_tag_dims, self.glimpse.att_params)
+        self.hb = NN.Linear(h_dims + edge_tag_dims, self.glimpse.att_params)
         self.G.register_message_func(self._h_to_b, edges=hb_edge_list, batched=True)

         # h -> y.  Concatenates h with edge tag and go through MLP.
         # Produces Δy
-        self.hy = NN.Linear(hidden_layers + edge_tag_dims, self.n_classes)
+        self.hy = NN.Linear(h_dims + edge_tag_dims, self.n_classes)
         self.G.register_message_func(self._h_to_y, edges=hy_edge_list, batched=True)

         # b update: just adds the original b by Δb
@@ -111,13 +162,7 @@ class TreeGlimpsedClassifier(NN.Module):
         '''
         n_yh_edges, batch_size, _ = source.shape
-        if not self._yh_emb_cached:
-            self._yh_emb_cached = True
-            self._yh_emb_w = self.yh_emb(edge_tag)
-            self._yh_emb_w = self._yh_emb_w.reshape(
-                    n_yh_edges, self.n_classes, self.h_dims)
-        w = self._yh_emb_w[:, None]
+        w = edge_tag.reshape(n_yh_edges, 1, self.n_classes, self.h_dims)
         w = w.expand(n_yh_edges, batch_size, self.n_classes, self.h_dims)
         source = source[:, :, None, :]
         return (F.softmax(source) @ w).reshape(n_yh_edges, batch_size, self.h_dims)
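After this change the reshaped edge tag acts as a per-edge (n_classes x h_dims) embedding table, and the softmax over the source's class scores picks a convex combination of class embeddings, replacing the cached yh_emb MLP. A standalone shape check with made-up sizes (the softmax dim is made explicit here, whereas the diff relies on the implicit default):

import torch as T
import torch.nn.functional as F

n_edges, batch, n_classes, h_dims = 3, 4, 10, 128
source = T.randn(n_edges, batch, n_classes)   # class scores per yh edge
tag = T.randn(n_edges, n_classes * h_dims)    # flat per-edge tag

w = tag.reshape(n_edges, 1, n_classes, h_dims).expand(n_edges, batch, n_classes, h_dims)
attn = F.softmax(source, dim=-1)[:, :, None, :]   # (n_edges, batch, 1, n_classes)
msg = (attn @ w).reshape(n_edges, batch, h_dims)
print(msg.shape)   # torch.Size([3, 4, 128])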
@@ -128,10 +173,11 @@ class TreeGlimpsedClassifier(NN.Module):
         edge_tag: (n_bh_edges, edge_tag_dims)
         '''
         n_bh_edges, batch_size, _ = source.shape
+        # FIXME: really using self.x is a bad design here
         _, nchan, nrows, ncols = self.x.size()
-        source = source.reshape(-1, self.glimpse.att_params)
-        m_b = T.relu(self.bh_1(source))
+        _source = source.reshape(-1, self.glimpse.att_params)
+        m_b = T.relu(self.bh_1(_source))
         m_t = T.relu(self.bh_2(edge_tag))
         m_t = m_t[:, None, :].expand(n_bh_edges, batch_size, self.h_dims)
         m_t = m_t.reshape(-1, self.h_dims)
@@ -140,9 +186,12 @@ class TreeGlimpsedClassifier(NN.Module):
         # here, the dimension of @source is n_bh_edges (# of glimpses), then
         # batch size, so we transpose them
         g = self.glimpse(self.x, source.transpose(0, 1)).transpose(0, 1)
-        g = g.reshape(n_bh_edges * batch_size, nchan, nrows, ncols)
+        grows, gcols = g.size()[-2:]
+        g = g.reshape(n_bh_edges * batch_size, nchan, grows, gcols)
         phi = self.cnn(g).reshape(n_bh_edges * batch_size, -1)

+        # TODO: add an attribute (g) to h
         m = self.bh_all(T.cat([m_b, m_t, phi], 1))
         m = m.reshape(n_bh_edges, batch_size, self.h_dims)
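The concatenation feeding bh_all therefore has width 2 * h_dims + filters[-1] * prod(final_pool_size), matching the Linear layer resized above. A shape sketch with assumed sizes:

import torch as T

h_dims, feat = 128, 256 * 2 * 2   # feat = filters[-1] * prod(final_pool_size)
m_b = T.randn(6, h_dims)          # projected glimpse parameters
m_t = T.randn(6, h_dims)          # projected edge tags
phi = T.randn(6, feat)            # flattened CNN features
print(T.cat([m_b, m_t, phi], 1).size())   # torch.Size([6, 1280]) = (6, 2*h_dims + feat)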
@@ -156,40 +205,59 @@ class TreeGlimpsedClassifier(NN.Module):
         edge_tag = edge_tag[:, None]
         edge_tag = edge_tag.expand(n_hb_edges, batch_size, self.edge_tag_dims)
         I = T.cat([source, edge_tag], -1).reshape(n_hb_edges * batch_size, -1)
-        b = self.hb(I)
-        return db
+        db = self.hb(I)
+        return db.reshape(n_hb_edges, batch_size, -1)

     def _h_to_y(self, source, edge_tag):
         n_hy_edges, batch_size, _ = source.shape
         edge_tag = edge_tag[:, None]
-        edge_tag = edge_tag.expand(n_hb_edges, batch_size, self.edge_tag_dims)
-        I = T.cat([source, edge_tag], -1).reshape(n_hb_edges * batch_size, -1)
-        y = self.hy(I)
-        return dy
+        edge_tag = edge_tag.expand(n_hy_edges, batch_size, self.edge_tag_dims)
+        I = T.cat([source, edge_tag], -1).reshape(n_hy_edges * batch_size, -1)
+        dy = self.hy(I)
+        return dy.reshape(n_hy_edges, batch_size, -1)

     def _update_b(self, b, b_n):
-        return b['state'] + list(b_n.values())[0]['state']
+        return b['state'] + b_n[0][2]['state']

     def _update_y(self, y, y_n):
-        return y['state'] + list(y_n.values())[0]['state']
+        return y['state'] + y_n[0][2]['state']

     def _update_h(self, h, h_n):
-        m = T.stack([e['state'] for e in h_n]).mean(0)
-        return T.relu(h + m)
+        m = T.stack([e[2]['state'] for e in h_n]).mean(0)
+        return T.relu(h['state'] + m)

-    def forward(self, x):
+    def forward(self, x, y=None):
+        self.x = x
         batch_size = x.shape[0]
-        self.G.zero_node_state(self.h_dims, batch_size, nodes=h_nodes_list)
+        self.G.zero_node_state((self.h_dims,), batch_size, nodes=self.h_nodes_list)
+        self.G.zero_node_state((self.n_classes,), batch_size, nodes=self.y_nodes_list)

         full = self.glimpse.full().unsqueeze(0).expand(batch_size, self.glimpse.att_params)
         for v in self.G.nodes():
-            if G.node[v]['type'] == 'b':
+            if self.G.node[v]['type'] == 'b':
                 # Initialize bbox variables to cover the entire canvas
                 self.G.node[v]['state'] = full

-        self._yh_emb_cached = False
-
         for t in range(self.steps):
             self.G.step()
             # We don't change b of the root
             self.G.node['b0']['state'] = full
+
+        self.y_pre = T.stack(
+                [self.G.node['y%d' % i]['state'] for i in range(self.n_nodes - 1, self.n_nodes - self.n_leaves - 1, -1)],
+                1
+                )
+        self.v_B = T.stack(
+                [self.G.node['b%d' % i]['state'] for i in range(self.n_nodes)],
+                1,
+                )
+        self.y_logprob = F.log_softmax(self.y_pre)
+
+        return self.G.node['h0']['state']
+
+    @property
+    def n_nodes(self):
+        return (self.n_children ** self.n_depth - 1) // (self.n_children - 1)
+
+    @property
+    def n_leaves(self):
+        return self.n_children ** (self.n_depth - 1)
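A worked example of the two new properties, as used by the readout in forward(): with n_children=2 and n_depth=3,

n_children, n_depth = 2, 3
n_nodes = (n_children ** n_depth - 1) // (n_children - 1)    # 7
n_leaves = n_children ** (n_depth - 1)                       # 4
# forward() stacks the last n_leaves y nodes in descending index order:
print(list(range(n_nodes - 1, n_nodes - n_leaves - 1, -1)))  # [6, 5, 4, 3]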