"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "47c9380555fa1dd3c8c29ee544777a0464e965f6"
Commit 51391012 authored by Gan Quan's avatar Gan Quan
Browse files

graph and model change

parent 196e6a92
import networkx as nx import networkx as nx
import torch as T import torch as T
import torch.nn as NN import torch.nn as NN
from util import *
class DiGraph(nx.DiGraph, NN.Module): class DiGraph(nx.DiGraph, NN.Module):
''' '''
...@@ -25,7 +26,7 @@ class DiGraph(nx.DiGraph, NN.Module): ...@@ -25,7 +26,7 @@ class DiGraph(nx.DiGraph, NN.Module):
def add_edge(self, u, v, tag=None, attr_dict=None, **attr): def add_edge(self, u, v, tag=None, attr_dict=None, **attr):
nx.DiGraph.add_edge(self, u, v, tag=tag, attr_dict=attr_dict, **attr) nx.DiGraph.add_edge(self, u, v, tag=tag, attr_dict=attr_dict, **attr)
def add_edges_from(self, ebunch, tag=tag, attr_dict=None, **attr): def add_edges_from(self, ebunch, tag=None, attr_dict=None, **attr):
nx.DiGraph.add_edges_from(self, ebunch, tag=tag, attr_dict=attr_dict, **attr) nx.DiGraph.add_edges_from(self, ebunch, tag=tag, attr_dict=attr_dict, **attr)
def _nodes_or_all(self, nodes='all'): def _nodes_or_all(self, nodes='all'):
...@@ -49,20 +50,20 @@ class DiGraph(nx.DiGraph, NN.Module): ...@@ -49,20 +50,20 @@ class DiGraph(nx.DiGraph, NN.Module):
nodes = self._nodes_or_all(nodes) nodes = self._nodes_or_all(nodes)
for v in nodes: for v in nodes:
self.node[v]['state'] = T.zeros(shape) self.node[v]['state'] = tovar(T.zeros(shape))
def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()): def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
nodes = self._nodes_or_all(nodes) nodes = self._nodes_or_all(nodes)
for v in nodes: for v in nodes:
self.node[v]['tag'] = init_func(NN.Parameter(T.zeros(shape, dtype=dtype)), *args) self.node[v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._node_tag_name(v), self.node[v]['tag']) self.register_parameter(self._node_tag_name(v), self.node[v]['tag'])
def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()): def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
edges = self._edges_or_all(edges) edges = self._edges_or_all(edges)
for u, v in edges: for u, v in edges:
self[u][v]['tag'] = init_func(NN.Parameter(T.zeros(shape, dtype=dtype)), *args) self[u][v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._edge_tag_name(u, v), self[u][v]['tag']) self.register_parameter(self._edge_tag_name(u, v), self[u][v]['tag'])
def remove_node_tag(self, nodes='all'): def remove_node_tag(self, nodes='all'):
...@@ -115,7 +116,7 @@ class DiGraph(nx.DiGraph, NN.Module): ...@@ -115,7 +116,7 @@ class DiGraph(nx.DiGraph, NN.Module):
''' '''
batched: whether to do a single batched computation instead of iterating batched: whether to do a single batched computation instead of iterating
update function: accepts a node attribute dictionary (including state and tag), update function: accepts a node attribute dictionary (including state and tag),
and a dictionary of edge attribute dictionaries and a list of tuples (source node, target node, edge attribute dictionary)
''' '''
self.update_funcs.append((self._nodes_or_all(nodes), update_func, batched)) self.update_funcs.append((self._nodes_or_all(nodes), update_func, batched))
...@@ -126,7 +127,11 @@ class DiGraph(nx.DiGraph, NN.Module): ...@@ -126,7 +127,11 @@ class DiGraph(nx.DiGraph, NN.Module):
# FIXME: need to optimize since we are repeatedly stacking and # FIXME: need to optimize since we are repeatedly stacking and
# unpacking # unpacking
source = T.stack([self.node[u]['state'] for u, _ in ebunch]) source = T.stack([self.node[u]['state'] for u, _ in ebunch])
edge_tag = T.stack([self[u][v]['tag'] for u, v in ebunch]) edge_tags = [self[u][v]['tag'] for u, v in ebunch]
if all(t is None for t in edge_tags):
edge_tag = None
else:
edge_tag = T.stack([self[u][v]['tag'] for u, v in ebunch])
message = f(source, edge_tag) message = f(source, edge_tag)
for i, (u, v) in enumerate(ebunch): for i, (u, v) in enumerate(ebunch):
self[u][v]['state'] = message[i] self[u][v]['state'] = message[i]
...@@ -139,5 +144,6 @@ class DiGraph(nx.DiGraph, NN.Module): ...@@ -139,5 +144,6 @@ class DiGraph(nx.DiGraph, NN.Module):
# update state # update state
# TODO: does it make sense to batch update the nodes? # TODO: does it make sense to batch update the nodes?
for v, f in self.update_funcs: for vbunch, f, batched in self.update_funcs:
self.node[v]['state'] = f(self.node[v], self[v]) for v in vbunch:
self.node[v]['state'] = f(self.node[v], self.in_edges(v, data=True))
import torch as T import torch as T
import torch.nn as NN import torch.nn as NN
import networkx as nx import torch.nn.init as INIT
import torch.nn.functional as F
import numpy as NP
import numpy.random as RNG
from util import *
from glimpse import create_glimpse
from zoneout import ZoneoutLSTMCell
from collections import namedtuple
import os
from graph import DiGraph from graph import DiGraph
import networkx as nx
no_msg = os.getenv('NOMSG', False)
def build_cnn(**config):
    """Build a small convolutional feature extractor.

    Keyword config:
        filters (list[int]): output channel count of each conv layer.
        kernel_size (tuple[int, int]): kernel size shared by all conv layers.
        in_channels (int): channels of the input image (default 3).
        final_pool_size (tuple[int, int]): output spatial size of the trailing
            adaptive max-pool, so the feature size is fixed regardless of input.

    Returns:
        NN.Sequential: conv stack with a LeakyReLU between conv layers and an
        AdaptiveMaxPool2d at the end.
    """
    cnn_list = []
    filters = config['filters']
    kernel_size = config['kernel_size']
    in_channels = config.get('in_channels', 3)
    final_pool_size = config['final_pool_size']
    for i in range(len(filters)):
        module = NN.Conv2d(
            in_channels if i == 0 else filters[i - 1],
            filters[i],
            kernel_size,
            # "same" padding for odd kernel sizes, keeping spatial dims intact
            padding=tuple((k - 1) // 2 for k in kernel_size),
        )
        # xavier_uniform/constant (no trailing underscore) are deprecated in
        # torch; use the in-place variants instead.
        INIT.xavier_uniform_(module.weight)
        INIT.constant_(module.bias, 0)
        cnn_list.append(module)
        if i < len(filters) - 1:
            cnn_list.append(NN.LeakyReLU())
    cnn_list.append(NN.AdaptiveMaxPool2d(final_pool_size))
    return NN.Sequential(*cnn_list)
class TreeGlimpsedClassifier(NN.Module): class TreeGlimpsedClassifier(NN.Module):
def __init__(self, def __init__(self,
...@@ -10,9 +46,13 @@ class TreeGlimpsedClassifier(NN.Module): ...@@ -10,9 +46,13 @@ class TreeGlimpsedClassifier(NN.Module):
h_dims=128, h_dims=128,
node_tag_dims=128, node_tag_dims=128,
edge_tag_dims=128, edge_tag_dims=128,
h_dims=128,
n_classes=10, n_classes=10,
steps=5, steps=5,
filters=[16, 32, 64, 128, 256],
kernel_size=(3, 3),
final_pool_size=(2, 2),
glimpse_type='gaussian',
glimpse_size=(15, 15),
): ):
''' '''
Basic idea: Basic idea:
...@@ -33,20 +73,30 @@ class TreeGlimpsedClassifier(NN.Module): ...@@ -33,20 +73,30 @@ class TreeGlimpsedClassifier(NN.Module):
self.edge_tag_dims = edge_tag_dims self.edge_tag_dims = edge_tag_dims
self.h_dims = h_dims self.h_dims = h_dims
self.n_classes = n_classes self.n_classes = n_classes
self.glimpse = create_glimpse(glimpse_type, glimpse_size)
self.steps = steps
self.cnn = build_cnn(
filters=filters,
kernel_size=kernel_size,
final_pool_size=final_pool_size,
)
# Create graph of latent variables # Create graph of latent variables
G = nx.balanced_tree(self.n_children, self.n_depth) G = nx.balanced_tree(self.n_children, self.n_depth)
nx.relabel_nodes(G, nx.relabel_nodes(G,
{i: 'h%d' % i for i in range(self.G.nodes())}, {i: 'h%d' % i for i in range(len(G.nodes()))},
False False
) )
h_nodes_list = G.nodes() self.h_nodes_list = h_nodes_list = G.nodes()
for h in h_nodes_list: for h in h_nodes_list:
G.node[h]['type'] = 'h' G.node[h]['type'] = 'h'
b_nodes_list = ['b%d' % i for i in range(len(h_nodes_list))] b_nodes_list = ['b%d' % i for i in range(len(h_nodes_list))]
y_nodes_list = ['y%d' % i for i in range(len(h_nodes_list))] y_nodes_list = ['y%d' % i for i in range(len(h_nodes_list))]
self.b_nodes_list = b_nodes_list
self.y_nodes_list = y_nodes_list
hy_edge_list = [(h, y) for h, y in zip(h_nodes_list, y_nodes_list)] hy_edge_list = [(h, y) for h, y in zip(h_nodes_list, y_nodes_list)]
hb_edge_list = [(h, b) for h, b in zip(h_nodes_list, y_nodes_list)] hb_edge_list = [(h, b) for h, b in zip(h_nodes_list, b_nodes_list)]
yh_edge_list = [(y, h) for y, h in zip(y_nodes_list, h_nodes_list)] yh_edge_list = [(y, h) for y, h in zip(y_nodes_list, h_nodes_list)]
bh_edge_list = [(b, h) for b, h in zip(b_nodes_list, h_nodes_list)] bh_edge_list = [(b, h) for b, h in zip(b_nodes_list, h_nodes_list)]
...@@ -65,21 +115,22 @@ class TreeGlimpsedClassifier(NN.Module): ...@@ -65,21 +115,22 @@ class TreeGlimpsedClassifier(NN.Module):
edge_tag_dims, edge_tag_dims,
T.nn.init.uniform_, T.nn.init.uniform_,
args=(-.01, .01), args=(-.01, .01),
edges=hy_edge_list + hb_edge_list + yh_edge_list, bh_edge_list edges=hy_edge_list + hb_edge_list + bh_edge_list
)
self.G.init_edge_tag_with(
h_dims * n_classes,
T.nn.init.uniform_,
args=(-.01, .01),
edges=yh_edge_list
) )
# y -> h. An attention over embeddings dynamically generated through edge tags # y -> h. An attention over embeddings dynamically generated through edge tags
self.yh_emb = NN.Sequential(
NN.Linear(edge_tag_dims, h_dims),
NN.ReLU(),
NN.Linear(h_dims, n_classes * h_dims),
)
self.G.register_message_func(self._y_to_h, edges=yh_edge_list, batched=True) self.G.register_message_func(self._y_to_h, edges=yh_edge_list, batched=True)
# b -> h. Projects b and edge tag to the same dimension, then concatenates and projects to h # b -> h. Projects b and edge tag to the same dimension, then concatenates and projects to h
self.bh_1 = NN.Linear(self.glimpse.att_params, h_dims) self.bh_1 = NN.Linear(self.glimpse.att_params, h_dims)
self.bh_2 = NN.Linear(edge_tag_dims, h_dims) self.bh_2 = NN.Linear(edge_tag_dims, h_dims)
self.bh_all = NN.Linear(3 * h_dims, h_dims) self.bh_all = NN.Linear(2 * h_dims + filters[-1] * NP.prod(final_pool_size), h_dims)
self.G.register_message_func(self._b_to_h, edges=bh_edge_list, batched=True) self.G.register_message_func(self._b_to_h, edges=bh_edge_list, batched=True)
# h -> h. Just passes h itself # h -> h. Just passes h itself
...@@ -87,12 +138,12 @@ class TreeGlimpsedClassifier(NN.Module): ...@@ -87,12 +138,12 @@ class TreeGlimpsedClassifier(NN.Module):
# h -> b. Concatenates h with edge tag and go through MLP. # h -> b. Concatenates h with edge tag and go through MLP.
# Produces Δb # Produces Δb
self.hb = NN.Linear(hidden_layers + edge_tag_dims, self.glimpse.att_params) self.hb = NN.Linear(h_dims + edge_tag_dims, self.glimpse.att_params)
self.G.register_message_func(self._h_to_b, edges=hb_edge_list, batched=True) self.G.register_message_func(self._h_to_b, edges=hb_edge_list, batched=True)
# h -> y. Concatenates h with edge tag and go through MLP. # h -> y. Concatenates h with edge tag and go through MLP.
# Produces Δy # Produces Δy
self.hy = NN.Linear(hidden_layers + edge_tag_dims, self.n_classes) self.hy = NN.Linear(h_dims + edge_tag_dims, self.n_classes)
self.G.register_message_func(self._h_to_y, edges=hy_edge_list, batched=True) self.G.register_message_func(self._h_to_y, edges=hy_edge_list, batched=True)
# b update: just adds the original b by Δb # b update: just adds the original b by Δb
...@@ -111,13 +162,7 @@ class TreeGlimpsedClassifier(NN.Module): ...@@ -111,13 +162,7 @@ class TreeGlimpsedClassifier(NN.Module):
''' '''
n_yh_edges, batch_size, _ = source.shape n_yh_edges, batch_size, _ = source.shape
if not self._yh_emb_cached: w = edge_tag.reshape(n_yh_edges, 1, self.n_classes, self.h_dims)
self._yh_emb_cached = True
self._yh_emb_w = self.yh_emb(edge_tag)
self._yh_emb_w = self._yh_emb_w.reshape(
n_yh_edges, self.n_classes, self.h_dims)
w = self._yh_emb_w[:, None]
w = w.expand(n_yh_edges, batch_size, self.n_classes, self.h_dims) w = w.expand(n_yh_edges, batch_size, self.n_classes, self.h_dims)
source = source[:, :, None, :] source = source[:, :, None, :]
return (F.softmax(source) @ w).reshape(n_yh_edges, batch_size, self.h_dims) return (F.softmax(source) @ w).reshape(n_yh_edges, batch_size, self.h_dims)
...@@ -128,10 +173,11 @@ class TreeGlimpsedClassifier(NN.Module): ...@@ -128,10 +173,11 @@ class TreeGlimpsedClassifier(NN.Module):
edge_tag: (n_bh_edges, edge_tag_dims) edge_tag: (n_bh_edges, edge_tag_dims)
''' '''
n_bh_edges, batch_size, _ = source.shape n_bh_edges, batch_size, _ = source.shape
# FIXME: really using self.x is a bad design here
_, nchan, nrows, ncols = self.x.size() _, nchan, nrows, ncols = self.x.size()
source = source.reshape(-1, self.glimpse.att_params) _source = source.reshape(-1, self.glimpse.att_params)
m_b = T.relu(self.bh_1(source)) m_b = T.relu(self.bh_1(_source))
m_t = T.relu(self.bh_2(edge_tag)) m_t = T.relu(self.bh_2(edge_tag))
m_t = m_t[:, None, :].expand(n_bh_edges, batch_size, self.h_dims) m_t = m_t[:, None, :].expand(n_bh_edges, batch_size, self.h_dims)
m_t = m_t.reshape(-1, self.h_dims) m_t = m_t.reshape(-1, self.h_dims)
...@@ -140,9 +186,12 @@ class TreeGlimpsedClassifier(NN.Module): ...@@ -140,9 +186,12 @@ class TreeGlimpsedClassifier(NN.Module):
# here, the dimension of @source is n_bh_edges (# of glimpses), then # here, the dimension of @source is n_bh_edges (# of glimpses), then
# batch size, so we transpose them # batch size, so we transpose them
g = self.glimpse(self.x, source.transpose(0, 1)).transpose(0, 1) g = self.glimpse(self.x, source.transpose(0, 1)).transpose(0, 1)
g = g.reshape(n_bh_edges * batch_size, nchan, nrows, ncols) grows, gcols = g.size()[-2:]
g = g.reshape(n_bh_edges * batch_size, nchan, grows, gcols)
phi = self.cnn(g).reshape(n_bh_edges * batch_size, -1) phi = self.cnn(g).reshape(n_bh_edges * batch_size, -1)
# TODO: add an attribute (g) to h
m = self.bh_all(T.cat([m_b, m_t, phi], 1)) m = self.bh_all(T.cat([m_b, m_t, phi], 1))
m = m.reshape(n_bh_edges, batch_size, self.h_dims) m = m.reshape(n_bh_edges, batch_size, self.h_dims)
...@@ -156,40 +205,59 @@ class TreeGlimpsedClassifier(NN.Module): ...@@ -156,40 +205,59 @@ class TreeGlimpsedClassifier(NN.Module):
edge_tag = edge_tag[:, None] edge_tag = edge_tag[:, None]
edge_tag = edge_tag.expand(n_hb_edges, batch_size, self.edge_tag_dims) edge_tag = edge_tag.expand(n_hb_edges, batch_size, self.edge_tag_dims)
I = T.cat([source, edge_tag], -1).reshape(n_hb_edges * batch_size, -1) I = T.cat([source, edge_tag], -1).reshape(n_hb_edges * batch_size, -1)
b = self.hb(I) db = self.hb(I)
return db return db.reshape(n_hb_edges, batch_size, -1)
def _h_to_y(self, source, edge_tag): def _h_to_y(self, source, edge_tag):
n_hy_edges, batch_size, _ = source.shape n_hy_edges, batch_size, _ = source.shape
edge_tag = edge_tag[:, None] edge_tag = edge_tag[:, None]
edge_tag = edge_tag.expand(n_hb_edges, batch_size, self.edge_tag_dims) edge_tag = edge_tag.expand(n_hy_edges, batch_size, self.edge_tag_dims)
I = T.cat([source, edge_tag], -1).reshape(n_hb_edges * batch_size, -1) I = T.cat([source, edge_tag], -1).reshape(n_hy_edges * batch_size, -1)
y = self.hy(I) dy = self.hy(I)
return dy return dy.reshape(n_hy_edges, batch_size, -1)
def _update_b(self, b, b_n): def _update_b(self, b, b_n):
return b['state'] + list(b_n.values())[0]['state'] return b['state'] + b_n[0][2]['state']
def _update_y(self, y, y_n): def _update_y(self, y, y_n):
return y['state'] + list(y_n.values())[0]['state'] return y['state'] + y_n[0][2]['state']
def _update_h(self, h, h_n): def _update_h(self, h, h_n):
m = T.stack([e['state'] for e in h_n]).mean(0) m = T.stack([e[2]['state'] for e in h_n]).mean(0)
return T.relu(h + m) return T.relu(h['state'] + m)
def forward(self, x): def forward(self, x, y=None):
self.x = x
batch_size = x.shape[0] batch_size = x.shape[0]
self.G.zero_node_state(self.h_dims, batch_size, nodes=h_nodes_list) self.G.zero_node_state((self.h_dims,), batch_size, nodes=self.h_nodes_list)
self.G.zero_node_state((self.n_classes,), batch_size, nodes=self.y_nodes_list)
full = self.glimpse.full().unsqueeze(0).expand(batch_size, self.glimpse.att_params) full = self.glimpse.full().unsqueeze(0).expand(batch_size, self.glimpse.att_params)
for v in self.G.nodes(): for v in self.G.nodes():
if G.node[v]['type'] == 'b': if self.G.node[v]['type'] == 'b':
# Initialize bbox variables to cover the entire canvas # Initialize bbox variables to cover the entire canvas
self.G.node[v]['state'] = full self.G.node[v]['state'] = full
self._yh_emb_cached = False
for t in range(self.steps): for t in range(self.steps):
self.G.step() self.G.step()
# We don't change b of the root # We don't change b of the root
self.G.node['b0']['state'] = full self.G.node['b0']['state'] = full
self.y_pre = T.stack(
[self.G.node['y%d' % i]['state'] for i in range(self.n_nodes - 1, self.n_nodes - self.n_leaves - 1, -1)],
1
)
self.v_B = T.stack(
[self.G.node['b%d' % i]['state'] for i in range(self.n_nodes)],
1,
)
self.y_logprob = F.log_softmax(self.y_pre)
return self.G.node['h0']['state']
@property
def n_nodes(self):
    """Total number of nodes in the balanced latent tree (geometric sum).

    Closed form of ``sum(r**d for d in range(depth))`` with
    ``r = n_children`` and ``depth = n_depth``.

    NOTE(review): ``nx.balanced_tree(r, h)`` has ``(r**(h+1) - 1) // (r - 1)``
    nodes — confirm this formula matches the graph built in ``__init__``.
    """
    branching = self.n_children
    depth = self.n_depth
    return (branching ** depth - 1) // (branching - 1)
@property
def n_leaves(self):
    """Number of leaf nodes in the balanced latent tree.

    NOTE(review): ``nx.balanced_tree(r, h)`` has ``r**h`` leaves; this uses
    ``r**(h-1)`` — confirm the intended depth convention against ``__init__``.
    """
    leaf_depth = self.n_depth - 1
    return self.n_children ** leaf_depth
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment