Commit 2455934a authored by Minjie Wang's avatar Minjie Wang
Browse files

Check in a runnable top-down model using the new API

parent 8e2e68df
from .mnist import MNISTMulti
from .wrapper import wrap_output
import torch as T
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from itertools import product
from util import *
import os
import cv2
import numpy as NP
import numpy.random as RNG
def mnist_bbox(data):
    '''Compute the tight bounding box of the nonzero region of a 2D image.

    data: 2D Tensor (n_rows, n_cols)
    returns: FloatTensor [cx, cy, w, h] in absolute pixel coordinates
    '''
    n_rows, n_cols = data.size()
    # A column/row is "empty" when its maximum value is zero.
    col_empty = data.max(0)[0] == 0
    row_empty = data.max(1)[0] == 0
    # cumprod of the empty mask stays 1 only for the leading run of empty
    # columns/rows, so its sum counts how many of them precede the content.
    left = T.cumprod(col_empty, 0).sum()
    top = T.cumprod(row_empty, 0).sum()
    # The trailing run is counted the same way on the reversed mask.
    right = n_cols - T.cumprod(reverse(col_empty, 0), 0).sum()
    bottom = n_rows - T.cumprod(reverse(row_empty, 0), 0).sum()
    return T.FloatTensor([
        (left + right) / 2,
        (top + bottom) / 2,
        right - left,
        bottom - top,
    ])
class MNISTMulti(Dataset):
    """Synthetic multi-digit MNIST dataset.

    Pastes `n_digits` randomly resized MNIST digits onto a larger
    (image_rows x image_cols) canvas, records per-digit labels and
    bounding boxes, and caches the generated tensors under `dir_`.
    """
    # Cache directory (relative to CWD) for the generated .pt files.
    dir_ = 'multi'
    # Per-split RNG seeds so each split is reproducible independently.
    seeds = {'train': 1000, 'valid': 2000, 'test': 3000}
    # Maps the `mode` argument to the prefix used by the *_file properties.
    attr_prefix = {'train': 'training', 'valid': 'valid', 'test': 'test'}
    n_classes = 10

    @property
    def _meta(self):
        # Cache filename suffix encoding the generation parameters.
        return '%d-%d-%d-%d.pt' % (
                self.image_rows,
                self.image_cols,
                self.n_digits,
                self.backrand)

    @property
    def training_file(self):
        # Path of the cached training split.
        return os.path.join(self.dir_, 'training-' + self._meta)

    @property
    def test_file(self):
        # Path of the cached test split.
        return os.path.join(self.dir_, 'test-' + self._meta)

    @property
    def valid_file(self):
        # Path of the cached validation split.
        return os.path.join(self.dir_, 'valid-' + self._meta)

    def __init__(self,
                 root,
                 mode='train',
                 transform=None,
                 target_transform=None,
                 download=False,
                 image_rows=100,
                 image_cols=100,
                 n_digits=1,
                 size_multiplier=1,
                 backrand=0):
        """
        root: directory passed through to torchvision's MNIST.
        mode: which split this instance serves ('train'/'valid'/'test').
        backrand: if nonzero, fill background pixels with uniform noise
            in [0, backrand).
        """
        self.mode = mode
        self.image_rows = image_rows
        self.image_cols = image_cols
        self.n_digits = n_digits
        self.backrand = backrand
        # Fast path: load the cached tensors for this split if present.
        if os.path.exists(self.dir_):
            if os.path.isfile(self.dir_):
                raise NotADirectoryError(self.dir_)
            elif os.path.exists(getattr(self, self.attr_prefix[mode] + '_file')):
                data = T.load(getattr(self, self.attr_prefix[mode] + '_file'))
                for k in data:
                    setattr(self, mode + '_' + k, data[k])
                self.size = getattr(self, mode + '_data').size()[0]
                return
        elif not os.path.exists(self.dir_):
            os.makedirs(self.dir_)
        # The validation split is carved out of the MNIST training set;
        # its size shrinks as n_digits grows.
        valid_src_size = 10000 // n_digits
        # All three splits are always generated and saved, even though only
        # `mode`'s tensors are kept on this instance.
        for _mode in ['train', 'valid', 'test']:
            _train = (_mode != 'test')
            mnist = MNIST(root, _train, transform, target_transform, download)
            if _mode == 'train':
                src_data = mnist.train_data[:-valid_src_size]
                src_labels = mnist.train_labels[:-valid_src_size]
            elif _mode == 'valid':
                src_data = mnist.train_data[-valid_src_size:]
                src_labels = mnist.train_labels[-valid_src_size:]
            elif _mode == 'test':
                src_data = mnist.test_data
                src_labels = mnist.test_labels
            # Fork the global RNG so generation is deterministic per split
            # without disturbing outside RNG state.
            with T.random.fork_rng():
                T.random.manual_seed(self.seeds[_mode])
                n_samples, n_rows, n_cols = src_data.size()
                n_new_samples = n_samples * n_digits
                data = T.ByteTensor(n_new_samples, image_rows, image_cols).zero_()
                labels = T.LongTensor(n_new_samples, n_digits).zero_()
                locs = T.LongTensor(n_new_samples, n_digits, 4).zero_()
                # i indexes the digit slot; j indexes the canvas batch chunk.
                for i, j in product(range(n_digits), range(n_digits * size_multiplier)):
                    # NOTE(review): pos_rows/pos_cols are computed but never
                    # used below — looks like dead code; confirm.
                    pos_rows = (T.LongTensor(n_samples).random_() %
                                (image_rows - n_rows))
                    pos_cols = (T.LongTensor(n_samples).random_() %
                                (image_cols - n_cols))
                    perm = T.randperm(n_samples)
                    for k, idx in zip(
                            range(n_samples * j, n_samples * (j + 1)), perm):
                        # Random digit size between 2/3 of the original and full.
                        cur_rows = RNG.randint(n_rows // 3 * 2, n_rows)
                        # NOTE(review): lower bound uses n_rows, not n_cols —
                        # harmless for square MNIST (28x28) but presumably a typo.
                        cur_cols = RNG.randint(n_rows // 3 * 2, n_cols)
                        row = RNG.randint(image_rows - cur_rows)
                        col = RNG.randint(image_cols - cur_cols)
                        cur_data = T.from_numpy(
                                cv2.resize(
                                    src_data[idx].numpy(),
                                    (cur_cols, cur_rows))
                                )
                        # Paste only the nonzero pixels so digits can overlap.
                        data[k, row:row+cur_rows, col:col+cur_cols][cur_data != 0] = cur_data[cur_data != 0]
                        labels[k, i] = src_labels[idx]
                        # bbox is digit-local; shift it to canvas coordinates.
                        locs[k, i] = mnist_bbox(cur_data)
                        locs[k, i, 0] += col
                        locs[k, i, 1] += row
                if backrand:
                    # Fill the still-black background with uniform noise in
                    # [0, backrand). NOTE(review): `+=` on a ByteTensor could
                    # wrap around if any masked pixel were nonzero; here the
                    # mask restricts it to zero pixels, so it is safe.
                    data += (data.new(*data.size()).random_() % backrand) * (data == 0)
                T.save({
                    'data': data,
                    'labels': labels,
                    'locs': locs,
                    }, getattr(self, self.attr_prefix[_mode] + '_file'))
                if _mode == mode:
                    setattr(self, mode + '_data', data)
                    setattr(self, mode + '_labels', labels)
                    setattr(self, mode + '_locs', locs)
                    self.size = data.size()[0]

    def __len__(self):
        # Number of generated canvases in this split.
        return self.size

    def __getitem__(self, i):
        # Returns (data, labels, locs) for sample i of the active split.
        return tuple(getattr(self, self.mode + '_' + k)[i] for k in ['data', 'labels', 'locs'])
from torch.utils.data import DataLoader
from functools import wraps
def wrap_output(dataloader, output_wrapper):
    """Patch `dataloader.collate_fn` in place so every collated batch is
    unpacked into `output_wrapper` before being returned.

    Returns the same dataloader object for convenience.
    """
    original = dataloader.collate_fn

    @wraps(original)
    def collate_and_wrap(input_):
        return output_wrapper(*original(input_))

    dataloader.collate_fn = collate_and_wrap
    return dataloader
import torch as T
import torch.nn.functional as F
from torch.distributions import Normal
class LogNormal(Normal):
    """Log-normal distribution: if z ~ Normal, samples are exp(z)."""

    def sample(self):
        return Normal.sample(self).exp()

    def sample_n(self, n):
        return Normal.sample_n(self, n).exp()

    def log_prob(self, x):
        # Change of variables: p(x) = N(log x) * |d(log x)/dx| = N(log x) / x,
        # so the log-density picks up a -log(x) Jacobian term.
        log_x = T.log(x)
        return Normal.log_prob(self, log_x) - log_x
class SigmoidNormal(Normal):
    """Sigmoid-transformed Normal: if z ~ Normal, samples are sigmoid(z)."""

    def sample(self):
        return F.sigmoid(Normal.sample(self))

    def sample_n(self, n):
        return F.sigmoid(Normal.sample_n(self, n))

    def log_prob(self, x):
        # sigmoid^{-1}(x) = log(x) - log(1 - x); the Jacobian of the inverse
        # transform contributes -log(x) - log(1 - x). Small epsilons guard
        # against log(0) at the boundaries.
        log_x = T.log(x + 1e-8)
        log_one_minus_x = T.log(1 - x + 1e-8)
        return Normal.log_prob(self, log_x - log_one_minus_x) - log_x - log_one_minus_x
import torch as T
import torch.nn.functional as F
import torch.nn as NN
from util import *
from distributions import LogNormal, SigmoidNormal
def gaussian_masks(c, d, s, len_, glim_len):
    '''Build 1D Gaussian attention masks.

    c, d, s: 2D Tensor (batch_size, n_glims) — center, stride, and std of
        the Gaussians, in absolute image coordinates
    len_, glim_len: int — image axis length and glimpse axis length
    returns: 4D Tensor (batch_size, n_glims, len_, glim_len);
        each column (fixed glimpse index) is a normalized 1D Gaussian
        over the image axis
    '''
    batch_size, n_glims = c.size()
    # The original HART code did not shift the coordinates by
    # glim_len / 2. The generated Gaussian attention does not
    # correspond to the actual crop of the bbox.
    # Possibly a bug?
    # R: glimpse-pixel offsets, centered around 0 -> (1, 1, 1, glim_len)
    R = tovar(T.arange(0, glim_len).view(1, 1, 1, -1) - glim_len / 2)
    # C: image-pixel coordinates -> (batch_size, n_glims, len_, 1)
    C = T.arange(0, len_).view(1, 1, -1, 1)
    C = C.expand(batch_size, n_glims, len_, 1)
    C = tovar(C)
    c = c[:, :, None, None]
    d = d[:, :, None, None]
    s = s[:, :, None, None]
    # Gaussian centers for each glimpse pixel: center + offset * stride.
    cr = c + R * d
    #sr = tovar(T.ones(cr.size())) * s
    sr = s
    mask = C - cr
    mask = (-0.5 * (mask / sr) ** 2).exp()
    # Normalize each Gaussian over the image axis (dim 2); epsilon avoids
    # division by zero when a Gaussian falls entirely off the image.
    mask = mask / (mask.sum(2, keepdim=True) + 1e-8)
    return mask
def extract_gaussian_glims(x, a, glim_size):
    '''Extract Gaussian-attention glimpses from an image batch.

    x: 4D Tensor (batch_size, nchannels, nrows, ncols)
    a: 3D Tensor (batch_size, n_glims, att_params) in absolute coordinates
    att_params: (cx, cy, dx, dy, sx, sy)
    returns:
        5D Tensor (batch_size, n_glims, nchannels, n_glim_rows, n_glim_cols)
    '''
    batch_size, n_glims, _ = a.size()
    cx, cy, dx, dy, sx, sy = T.unbind(a, -1)
    _, nchannels, nrows, ncols = x.size()
    n_glim_rows, n_glim_cols = glim_size
    # (batch_size, n_glims, nrows, n_glim_rows)
    Fy = gaussian_masks(cy, dy, sy, nrows, n_glim_rows)
    # (batch_size, n_glims, ncols, n_glim_cols)
    Fx = gaussian_masks(cx, dx, sx, ncols, n_glim_cols)
    # (batch_size, n_glims, 1, nrows, n_glim_rows)
    Fy = Fy.unsqueeze(2)
    # (batch_size, n_glims, 1, ncols, n_glim_cols)
    Fx = Fx.unsqueeze(2)
    # (batch_size, 1, nchannels, nrows, ncols)
    x = x.unsqueeze(1)
    # Separable attention: contract the row axis with Fy^T and the column
    # axis with Fx, broadcasting over glimpses and channels.
    # (batch_size, n_glims, nchannels, n_glim_rows, n_glim_cols)
    g = Fy.transpose(-1, -2) @ x @ Fx
    return g
# softplus(0), kept as a normalization constant; presumably for the
# softplus-based rescaling variants left commented out below — confirm.
softplus_zero = F.softplus(tovar([0]))
class GaussianGlimpse(NN.Module):
    """Differentiable Gaussian-attention glimpse extractor.

    Attention parameters per glimpse: (cx, cy, dx, dy, sx, sy) — center,
    stride, and Gaussian std per axis, in relative image scale.
    """
    att_params = 6

    def __init__(self, glim_size):
        NN.Module.__init__(self)
        # (n_glim_rows, n_glim_cols) of the extracted glimpse.
        self.glim_size = glim_size

    @classmethod
    def full(cls):
        # Attention covering the whole image: centered, unit stride.
        return tovar([0.5, 0.5, 1, 1, 0.5, 0.5])
        #return tovar([0.5, 0.5, 1, 1, 0.1, 0.1])

    @classmethod
    def rescale(cls, x, glimpse_sample):
        """Map raw network outputs `x` (..., 6) to valid attention params.

        If `glimpse_sample`, the bbox corners are perturbed with Normal
        noise (std 0.1) and the per-corner log-probabilities are returned;
        otherwise logprob is 0.
        """
        if not glimpse_sample:
            # Deterministic path: shift raw outputs to sit near the
            # full-image attention; squash the stds into (0, 1).
            y = [
                #F.sigmoid(x[..., 0]), # cx
                #F.sigmoid(x[..., 1]), # cy
                #F.sigmoid(x[..., 2]) * 2,
                #F.sigmoid(x[..., 3]) * 2,
                #F.sigmoid(x[..., 4]),
                #F.sigmoid(x[..., 5]),
                x[..., 0] + 0.5,
                x[..., 1] + 0.5,
                x[..., 2] + 1,
                x[..., 3] + 1,
                F.sigmoid(x[..., 4]),
                F.sigmoid(x[..., 5]),
                #T.zeros_like(x[..., 4]) + 0.1,
                #T.zeros_like(x[..., 5]) + 0.1,
            ]
            logprob = 0
        else:
            # Stochastic path: squash center/stride, zero the std logits.
            y = [
                F.sigmoid(x[..., 0]), # cx
                F.sigmoid(x[..., 1]), # cy
                F.sigmoid(x[..., 2]) * 2,
                F.sigmoid(x[..., 3]) * 2,
                T.zeros_like(x[..., 4]),
                T.zeros_like(x[..., 5]),
            ]
            # Convert (center, size) to corner form, add Normal noise there.
            diag = T.stack([
                y[0] - y[2] / 2,
                y[1] - y[3] / 2,
                y[0] + y[2] / 2,
                y[1] + y[3] / 2,
            ], -1)
            diagN = T.distributions.Normal(
                    diag, T.ones_like(diag) * 0.1)
            diag = diagN.sample()
            diag_logprob = diagN.log_prob(diag)
            # std logits are zero, so s = sigmoid(0) = 0.5 per axis.
            s = F.sigmoid(T.stack([y[4], y[5]], -1))
            #sSN = SigmoidNormal(s, T.ones_like(s) * 0.05)
            #s = sSN.sample()
            #s_logprob = sSN.log_prob(s)
            s_logprob = T.zeros_like(s)
            # Back to (center, size) form from the sampled corners.
            y = [
                (diag[..., 0] + diag[..., 2]) / 2,
                (diag[..., 1] + diag[..., 3]) / 2,
                diag[..., 2] - diag[..., 0],
                diag[..., 3] - diag[..., 1],
                s[..., 0],
                s[..., 1],
            ]
            logprob = T.cat([diag_logprob, s_logprob], -1)
        return T.stack(y, -1), logprob

    @classmethod
    def absolute_to_relative(cls, att, absolute):
        # Express `att` in the frame of the `absolute` attention window.
        C_x, C_y, D_x, D_y, S_x, S_y = T.unbind(absolute, -1)
        c_x, c_y, d_x, d_y, s_x, s_y = T.unbind(att, -1)
        return T.stack([
            (c_x - C_x) / D_x + 0.5,
            (c_y - C_y) / D_y + 0.5,
            d_x / D_x,
            d_y / D_y,
            s_x / D_x,
            s_y / D_y,
        ], -1)

    @classmethod
    def relative_to_absolute(cls, att, relative):
        # Inverse of absolute_to_relative.
        C_x, C_y, D_x, D_y, S_x, S_y = T.unbind(relative, -1)
        c_x, c_y, d_x, d_y, s_x, s_y = T.unbind(att, -1)
        return T.stack([
            (c_x - 0.5) * D_x + C_x,
            (c_y - 0.5) * D_y + C_y,
            d_x * D_x,
            d_y * D_y,
            s_x * D_x,
            s_y * D_y
        ], -1)

    def forward(self, x, spatial_att):
        '''
        x: 4D Tensor (batch_size, nchannels, n_image_rows, n_image_cols)
        spatial_att: 3D Tensor (batch_size, n_glims, att_params) relative scales
        returns: 5D Tensor of extracted glimpses
        '''
        # (batch_size, n_glims, att_params)
        absolute_att = self._to_absolute_attention(spatial_att, x.size()[-2:])
        glims = extract_gaussian_glims(x, absolute_att, self.glim_size)
        return glims

    def att_to_bbox(self, spatial_att, x_size):
        '''
        spatial_att: (..., 6) [cx, cy, dx, dy, sx, sy] relative scales ]0, 1[
        return: (..., 4) [cx, cy, w, h] absolute scales
        '''
        cx = spatial_att[..., 0] * x_size[1]
        cy = spatial_att[..., 1] * x_size[0]
        w = T.abs(spatial_att[..., 2]) * (x_size[1] - 1)
        h = T.abs(spatial_att[..., 3]) * (x_size[0] - 1)
        bbox = T.stack([cx, cy, w, h], -1)
        return bbox

    def bbox_to_att(self, bbox, x_size):
        '''
        bbox: (..., 4) [cx, cy, w, h] absolute scales
        return: (..., 6) [cx, cy, dx, dy, sx, sy] relative scales ]0, 1[
        '''
        cx = bbox[..., 0] / x_size[1]
        cy = bbox[..., 1] / x_size[0]
        dx = bbox[..., 2] / (x_size[1] - 1)
        dy = bbox[..., 3] / (x_size[0] - 1)
        # std defaults to half the bbox extent per axis.
        sx = bbox[..., 2] * 0.5 / x_size[1]
        sy = bbox[..., 3] * 0.5 / x_size[0]
        spatial_att = T.stack([cx, cy, dx, dy, sx, sy], -1)
        return spatial_att

    def _to_axis_attention(self, image_len, glim_len, c, d, s):
        # Scale one axis from relative to absolute pixel coordinates.
        c = c * image_len
        d = d * (image_len - 1) / (glim_len - 1)
        s = (s + 1e-5) * image_len / glim_len
        return c, d, s

    def _to_absolute_attention(self, params, x_size):
        '''
        params: 3D Tensor (batch_size, n_glims, att_params), relative scales
        returns the same parameters in absolute image coordinates
        '''
        n_image_rows, n_image_cols = x_size
        n_glim_rows, n_glim_cols = self.glim_size
        # Even indices are x-axis params, odd indices y-axis params.
        cx, dx, sx = T.unbind(params[..., ::2], -1)
        cy, dy, sy = T.unbind(params[..., 1::2], -1)
        cx, dx, sx = self._to_axis_attention(
                n_image_cols, n_glim_cols, cx, dx, sx)
        cy, dy, sy = self._to_axis_attention(
                n_image_rows, n_glim_rows, cy, dy, sy)
        # ap is now the absolute coordinate/scale on image
        # (batch_size, n_glims, att_params)
        ap = T.stack([cx, cy, dx, dy, sx, sy], -1)
        return ap
class BilinearGlimpse(NN.Module):
    """Bilinear-interpolation glimpse extractor.

    Attention parameters per glimpse: (cx, cy, w, h) — bbox center and
    size in relative image scale.
    """
    att_params = 4

    def __init__(self, glim_size):
        NN.Module.__init__(self)
        # (n_glim_rows, n_glim_cols) of the extracted glimpse.
        self.glim_size = glim_size

    @classmethod
    def full(cls):
        # Attention covering the whole image.
        return tovar([0.5, 0.5, 1, 1])

    @classmethod
    def rescale(cls, x, glimpse_sample):
        """Map raw network outputs `x` (..., 4) to valid bbox params;
        optionally perturb the corners with Normal noise (std 0.1) and
        return the corner log-probabilities."""
        y = [
            F.sigmoid(x[..., 0]), # cx
            F.sigmoid(x[..., 1]), # cy
            #F.softplus(x[..., 2]) / softplus_zero, #dx
            #F.softplus(x[..., 3]) / softplus_zero, #dy
            F.sigmoid(x[..., 2]) * 2,
            F.sigmoid(x[..., 3]) * 2,
            #x[..., 2].exp(),
            #x[..., 3].exp(),
        ]
        if glimpse_sample:
            # Perturb in corner form, then convert back to center/size.
            diag = T.stack([
                y[0] - y[2] / 2,
                y[1] - y[3] / 2,
                y[0] + y[2] / 2,
                y[1] + y[3] / 2,
            ], -1)
            diagN = T.distributions.Normal(
                    diag, T.ones_like(diag) * 0.1)
            diag = diagN.sample()
            diag_logprob = diagN.log_prob(diag)
            y = [
                (diag[..., 0] + diag[..., 2]) / 2,
                (diag[..., 1] + diag[..., 3]) / 2,
                diag[..., 2] - diag[..., 0],
                diag[..., 3] - diag[..., 1],
            ]
        else:
            diag_logprob = 0
        return T.stack(y, -1), diag_logprob

    def forward(self, x, spatial_att):
        '''
        x: 4D Tensor (batch_size, nchannels, n_image_rows, n_image_cols)
        spatial_att: 3D Tensor (batch_size, n_glims, att_params) relative scales
        returns: 5D Tensor (batch_size, n_glims, nchannels, crow, ccol)
        '''
        nsamples, nchan, xrow, xcol = x.size()
        nglims = spatial_att.size()[1]
        x = x[:, None].contiguous()
        crow, ccol = self.glim_size
        cx, cy, w, h = T.unbind(spatial_att, -1)
        # Relative -> absolute pixel coordinates.
        cx = cx * xcol
        cy = cy * xrow
        w = w * xcol
        h = h * xrow
        # Per-axis sampling stride between adjacent glimpse pixels.
        dx = w / (ccol - 1)
        dy = h / (crow - 1)
        cx = cx[:, :, None]
        cy = cy[:, :, None]
        dx = dx[:, :, None]
        dy = dy[:, :, None]
        # Absolute sample locations of each glimpse pixel along each axis.
        mx = cx + dx * (tovar(T.arange(ccol))[None, None, :] - (ccol - 1) / 2)
        my = cy + dy * (tovar(T.arange(crow))[None, None, :] - (crow - 1) / 2)
        a = tovar(T.arange(xcol))
        b = tovar(T.arange(xrow))
        # Triangular (hat) interpolation weights from image pixels to
        # glimpse sample locations, per axis.
        ax = (1 - T.abs(a.view(1, 1, -1, 1) - mx[:, :, None, :])).clamp(min=0)
        ax = ax[:, :, None, :, :]
        ax = ax.expand(nsamples, nglims, nchan, xcol, ccol).contiguous().view(-1, xcol, ccol)
        by = (1 - T.abs(b.view(1, 1, -1, 1) - my[:, :, None, :])).clamp(min=0)
        by = by[:, :, None, :, :]
        by = by.expand(nsamples, nglims, nchan, xrow, crow).contiguous().view(-1, xrow, crow)
        # Separable bilinear interpolation as two batched matmuls.
        bilin = by.permute(0, 2, 1) @ x.view(-1, xrow, xcol) @ ax
        return bilin.view(nsamples, nglims, nchan, crow, ccol)

    @classmethod
    def absolute_to_relative(cls, att, absolute):
        # Express `att` in the frame of the `absolute` attention window.
        C_x, C_y, D_x, D_y = T.unbind(absolute, -1)
        c_x, c_y, d_x, d_y = T.unbind(att, -1)
        return T.stack([
            (c_x - C_x) / D_x + 0.5,
            (c_y - C_y) / D_y + 0.5,
            d_x / D_x,
            d_y / D_y,
        ], -1)

    @classmethod
    def relative_to_absolute(cls, att, relative):
        # Inverse of absolute_to_relative.
        C_x, C_y, D_x, D_y = T.unbind(relative, -1)
        c_x, c_y, d_x, d_y = T.unbind(att, -1)
        return T.stack([
            (c_x - 0.5) * D_x + C_x,
            (c_y - 0.5) * D_y + C_y,
            d_x * D_x,
            d_y * D_y,
        ], -1)
# Registry of available glimpse extractors, keyed by a short name.
glimpse_table = {
    'gaussian': GaussianGlimpse,
    'bilinear': BilinearGlimpse,
}


def create_glimpse(name, size):
    """Instantiate the glimpse extractor registered under `name` with the
    given glimpse `size` (rows, cols)."""
    glimpse_cls = glimpse_table[name]
    return glimpse_cls(size)
import networkx as nx
from glimpse import create_glimpse
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as MODELS
import torch.nn.init as INIT
from util import USE_CUDA, cuda
import numpy as np
import skorch
from viz import VisdomWindowManager
import matplotlib.pyplot as plt
from dgl.graph import DGLGraph
# Global minibatch size shared by training and the __main__ block below.
batch_size = 32
# NOTE(review): constructed at import time — this opens a connection to a
# Visdom server on port 10248 as a side effect of importing this module.
wm = VisdomWindowManager(port=10248)
def dfs_walk(tree, curr, l):
    """Append to `l` the (parent, child) and (child, parent) edge pairs of a
    depth-first traversal of `tree` rooted at `curr`.

    `tree` must expose networkx-style successors via `tree.succ[node]`.
    """
    for child in tree.succ[curr]:
        l.append((curr, child))   # descend into the subtree
        dfs_walk(tree, child, l)
        l.append((child, curr))   # return edge after the subtree is done
def build_cnn(**config):
    """Build a small Xavier-initialized conv stack ending in adaptive max-pool.

    config keys:
        filters: list of output channel counts, one per conv layer
        kernel_size: (kh, kw) shared by all conv layers
        in_channels: input channels (default 3)
        final_pool_size: target (rows, cols) of the trailing AdaptiveMaxPool2d
    returns: nn.Sequential
    """
    filters = config['filters']
    kernel_size = config['kernel_size']
    # "Same" padding for odd kernel sizes.
    padding = tuple((k - 1) // 2 for k in kernel_size)
    layers = []
    prev_channels = config.get('in_channels', 3)
    for depth, out_channels in enumerate(filters):
        conv = nn.Conv2d(prev_channels, out_channels, kernel_size,
                         padding=padding)
        INIT.xavier_uniform_(conv.weight)
        INIT.constant_(conv.bias, 0)
        layers.append(conv)
        # No activation after the last conv; the pool follows directly.
        if depth < len(filters) - 1:
            layers.append(nn.LeakyReLU())
        prev_channels = out_channels
    layers.append(nn.AdaptiveMaxPool2d(config['final_pool_size']))
    return nn.Sequential(*layers)
def build_resnet_cnn(**config):
    """Truncate an untrained ResNet-18 to its first `n_layers` children and
    append an adaptive max-pool of `final_pool_size`.

    returns: nn.Sequential
    """
    backbone = MODELS.resnet18(pretrained=False)
    layers = list(backbone.children())[:config['n_layers']]
    layers.append(nn.AdaptiveMaxPool2d(config['final_pool_size']))
    return nn.Sequential(*layers)
def init_canvas(n_nodes):
    """Create a 2x4 matplotlib grid (16x8 inches) for node visualizations.

    NOTE(review): `n_nodes` is accepted but unused — the grid is fixed at 8
    cells regardless; confirm against callers.
    """
    figure, axes = plt.subplots(2, 4)
    figure.set_size_inches(16, 8)
    return figure, axes
def display_image(fig, ax, i, im, title):
    """Show tensor image `im` (C, H, W) in cell `i` of the 2x4 axes grid
    `ax`, grayscale, with the given title. `fig` is unused here."""
    panel = ax[i // 4, i % 4]
    image = im.detach().cpu().numpy().transpose(1, 2, 0)
    panel.imshow(image, cmap='gray', vmin=0, vmax=1)
    panel.set_title(title)
class MessageModule(nn.Module):
    """Edge message: relay the source node's hidden state and its proposed
    next bounding box to the destination node."""

    # NOTE(minjie): message module signature change.
    def forward(self, src, dst, edge):
        # `dst` and `edge` are part of the required signature but unused.
        return src['h'], src['b_next']
class UpdateModule(nn.Module):
    """
    UpdateModule: per-node recurrent update that takes a glimpse of the
    image at the node's bounding box and refines state/bbox/prediction.

    Returns:
        h: new state
        b: new bounding box
        a: attention (for readout)
        y: prediction
    """
    def __init__(self, **config):
        #h_dims=128,
        #n_classes=10,
        #steps=5,
        #filters=[16, 32, 64, 128, 256],
        #kernel_size=(3, 3),
        #final_pool_size=(2, 2),
        #glimpse_type='gaussian',
        #glimpse_size=(15, 15),
        #cnn='resnet'
        #):
        super(UpdateModule, self).__init__()
        glimpse_type = config['glimpse_type']
        glimpse_size = config['glimpse_size']
        self.glimpse = create_glimpse(glimpse_type, glimpse_size)
        h_dims = config['h_dims']
        n_classes = config['n_classes']
        # Heads on top of the hidden state: bbox delta, class delta, attention.
        self.net_b = nn.Sequential(
                nn.Linear(h_dims, h_dims),
                nn.ReLU(),
                nn.Linear(h_dims, self.glimpse.att_params),
                )
        self.net_y = nn.Sequential(
                nn.Linear(h_dims, h_dims),
                nn.ReLU(),
                nn.Linear(h_dims, n_classes),
                )
        self.net_a = nn.Sequential(
                nn.Linear(h_dims, h_dims),
                nn.ReLU(),
                nn.Linear(h_dims, 1),
                )
        # GRU input is [glimpse features; mean incoming message].
        self.h_to_h = nn.GRUCell(h_dims * 2, h_dims)
        INIT.orthogonal_(self.h_to_h.weight_hh)
        cnn = config['cnn']
        final_pool_size = config['final_pool_size']
        if cnn == 'resnet':
            n_layers = config['n_layers']
            self.cnn = build_resnet_cnn(
                    n_layers=n_layers,
                    final_pool_size=final_pool_size,
                    )
            # NOTE(review): 128 assumes the truncated ResNet-18 ends with
            # 128 channels (n_layers=6) — confirm if n_layers changes.
            self.net_h = nn.Linear(128 * np.prod(final_pool_size), h_dims)
        else:
            filters = config['filters']
            kernel_size = config['kernel_size']
            self.cnn = build_cnn(
                    filters=filters,
                    kernel_size=kernel_size,
                    final_pool_size=final_pool_size,
                    )
            self.net_h = nn.Linear(filters[-1] * np.prod(final_pool_size), h_dims)
        self.max_recur = config.get('max_recur', 1)
        self.h_dims = h_dims

    def set_image(self, x):
        # Stash the current image batch; forward() reads it via self.x.
        self.x = x

    def forward(self, node_state, message):
        """node_state: dict with keys 'h', 'b', 'y', 'b_fix';
        message: list of (h, b_next) pairs from predecessor nodes."""
        h, b, y, b_fix = [node_state[k] for k in ['h', 'b', 'y', 'b_fix']]
        batch_size = h.shape[0]
        if len(message) == 0:
            # No incoming messages: use a zero message vector.
            h_m_avg = h.new(batch_size, self.h_dims).zero_()
        else:
            h_m, b_next = zip(*message)
            h_m_avg = T.stack(h_m).mean(0)
            # First visit adopts the mean proposed bbox; later visits keep
            # the bbox fixed at b_fix.
            b = T.stack(b_next).mean(0) if b_fix is None else b_fix
        b_new = b_fix = b
        h_new = h
        # Repeatedly glimpse at the current bbox and refine state/deltas.
        for i in range(self.max_recur):
            b_rescaled, _ = self.glimpse.rescale(b_new[:, None], False)
            g = self.glimpse(self.x, b_rescaled)[:, 0]
            h_in = T.cat([self.net_h(self.cnn(g).view(batch_size, -1)), h_m_avg], -1)
            h_new = self.h_to_h(h_in, h_new)
            db = self.net_b(h_new)
            dy = self.net_y(h_new)
            # Deltas are applied to the *fixed* b / incoming y each iteration.
            b_new = b + db
            y_new = y + dy
        a_new = self.net_a(h_new)
        return {'h': h_new, 'b': b, 'b_next': b_new, 'a': a_new, 'y': y_new, 'g': g, 'b_fix': b_fix, 'db': db}

    # NOTE(review): placeholder without `self` — appears unused; confirm.
    def update_local():
        pass
class ReadoutModule(nn.Module):
    '''
    Returns the logits of classes
    '''
    def __init__(self, *args, **kwarg):
        super(ReadoutModule, self).__init__()
        # Linear head used only on the pretraining path.
        self.y = nn.Linear(kwarg['h_dims'], kwarg['n_classes'])

    # NOTE(minjie): readout module signature change.
    def forward(self, nodes_state, edge_states, pretrain=False):
        """Pretrain: project the root node's hidden state to logits.
        Otherwise: stack the per-node logits along a new step axis."""
        if pretrain:
            assert len(nodes_state) == 1 # root only
            return self.y(nodes_state[0]['h'])
        return T.stack([state['y'] for state in nodes_state], 1)
class DFSGlimpseSingleObjectClassifier(nn.Module):
    """Top-down glimpse classifier: walks a small tree with DGL, taking a
    glimpse per node and reading out per-step class logits."""

    def __init__(self,
                 h_dims=128,
                 n_classes=10,
                 filters=[16, 32, 64, 128, 256],
                 kernel_size=(3, 3),
                 final_pool_size=(2, 2),
                 glimpse_type='gaussian',
                 glimpse_size=(15, 15),
                 cnn='cnn'
                 ):
        nn.Module.__init__(self)

        #self.T_MAX_RECUR = kwarg['steps']

        # Path graph of 3 nodes (branching factor 1, height 2).
        t = nx.balanced_tree(1, 2)
        t_uni = nx.bfs_tree(t, 0)
        self.G = DGLGraph(t)
        self.root = 0
        self.h_dims = h_dims
        self.n_classes = n_classes

        self.message_module = MessageModule()
        self.G.register_message_func(self.message_module) # default: just copy

        #self.update_module = UpdateModule(h_dims, n_classes, glimpse_size)
        self.update_module = UpdateModule(
                glimpse_type=glimpse_type,
                glimpse_size=glimpse_size,
                n_layers=6,
                h_dims=h_dims,
                n_classes=n_classes,
                final_pool_size=final_pool_size,
                filters=filters,
                kernel_size=kernel_size,
                cnn=cnn,
                max_recur=1,    # T_MAX_RECUR
                )
        self.G.register_update_func(self.update_module)

        self.readout_module = ReadoutModule(h_dims=h_dims, n_classes=n_classes)
        self.G.register_readout_func(self.readout_module)

        # Hard-coded downward walk over the 3-node path: 0 -> 1 -> 2.
        self.walk_list = [(0, 1), (1, 2)]
        #dfs_walk(t_uni, self.root, self.walk_list)

    def forward(self, x, pretrain=False):
        """x: (batch, channels, rows, cols) image batch.
        Returns readout logits; shape depends on `pretrain`."""
        batch_size = x.shape[0]

        # The update module reads the image from an attribute, not messages.
        self.update_module.set_image(x)

        init_states = {
                'h': x.new(batch_size, self.h_dims).zero_(),
                'b': x.new(batch_size, self.update_module.glimpse.att_params).zero_(),
                'b_next': x.new(batch_size, self.update_module.glimpse.att_params).zero_(),
                'a': x.new(batch_size, 1).zero_(),
                'y': x.new(batch_size, self.n_classes).zero_(),
                'g': None,
                'b_fix': None,
                'db': None,
                }
        for n in self.G.nodes():
            self.G.node[n].update(init_states)

        #TODO: the following two lines is needed for single object
        #TODO: but not useful or wrong for multi-obj
        # Trigger the root's update with an empty message list.
        self.G.recvfrom(self.root, [])

        if pretrain:
            return self.G.readout([self.root], pretrain=True)
        else:
            # XXX(minjie): could replace the following loop with propagate call.
            #for u, v in self.walk_list:
                #self.G.update_by_edge(u, v)
                # update local should be inside the update module
                #for i in self.T_MAX_RECUR:
                #    self.G.update_local(u)
            self.G.propagate(self.walk_list)
            return self.G.readout(pretrain=False)
class Net(skorch.NeuralNet):
    """skorch wrapper adding a bbox-delta L2 regularizer and an accuracy
    metric computed through get_loss at eval time."""

    def __init__(self, **kwargs):
        # Coefficient for the bbox-delta norm penalty added in train_step.
        self.reg_coef_ = kwargs.get('reg_coef', 1e-4)
        del kwargs['reg_coef']
        skorch.NeuralNet.__init__(self, **kwargs)

    def initialize_criterion(self):
        # Overriding this method to skip initializing criterion as we don't use it.
        pass

    def get_split_datasets(self, X, y=None, **fit_params):
        # Overriding this method to use our own dataloader to change the X
        # in signature to (train_dataset, valid_dataset)
        X_train, X_valid = X
        train = self.get_dataset(X_train, None)
        valid = self.get_dataset(X_valid, None)
        return train, valid

    def train_step(self, Xi, yi, **fit_params):
        """Run the standard skorch train step, then add the bbox-delta
        regularizer and record diagnostics to the history."""
        step = skorch.NeuralNet.train_step(self, Xi, yi, **fit_params)
        # NOTE(review): accesses G.nodes[v] here but G.node[n] elsewhere —
        # presumably both supported by this DGL version; confirm.
        dbs = [self.module_.G.nodes[v]['db'] for v in self.module_.G.nodes]
        reg = self.reg_coef_ * sum(db.norm(2, 1).mean() for db in dbs if db is not None)
        loss = step['loss'] + reg
        y_pred = step['y_pred']
        # get_loss with training=False returns the correct-prediction count.
        acc = self.get_loss(y_pred, yi, training=False)
        self.history.record_batch('max_param', max(p.abs().max().item() for p in self.module_.parameters()))
        self.history.record_batch('acc', acc.item())
        self.history.record_batch('reg', reg.item())
        return {
                'loss': loss,
                'y_pred': y_pred,
                }

    def get_loss(self, y_pred, y_true, X=None, training=False):
        """Training: mean cross-entropy over all steps' logits.
        Eval: number of correct final predictions (the step whose max logit
        is largest decides the class)."""
        batch_size, n_steps, _ = y_pred.shape
        if training:
            #return F.cross_entropy(y_pred, y_true)
            # Supervise every step with the same label.
            y_true = y_true[:, None].expand(batch_size, n_steps)
            return F.cross_entropy(
                    y_pred.reshape(batch_size * n_steps, -1),
                    y_true.reshape(-1)
                    )
        else:
            y_prob, y_cls = y_pred.max(-1)
            _, y_prob_maxind = y_prob.max(-1)
            y_cls_final = y_cls.gather(1, y_prob_maxind[:, None])[:, 0]
            return (y_cls_final == y_true).sum()
class Dump(skorch.callbacks.Callback):
    """Callback that tracks validation accuracy per epoch and pushes up to
    10 per-node glimpse visualizations per epoch to Visdom."""

    def initialize(self):
        self.epoch = 0
        self.batch = 0
        # `correct` accumulates get_loss's correct-count from eval batches.
        self.correct = 0
        self.total = 0
        self.best_acc = 0
        # Number of visualizations pushed this epoch (capped at 10).
        self.nviz = 0
        return self

    def on_epoch_begin(self, net, **kwargs):
        self.epoch += 1
        self.batch = 0
        self.correct = 0
        self.total = 0
        self.nviz = 0

    def on_batch_end(self, net, **kwargs):
        self.batch += 1
        if kwargs['training']:
            #print('#', self.epoch, self.batch, kwargs['loss'], kwargs['valid_loss'])
            pass
        else:
            # During validation, 'loss' is actually the correct count.
            self.correct += kwargs['loss'].item()
            self.total += kwargs['X'].shape[0]
            if self.nviz < 10:
                n_nodes = len(net.module_.G.nodes)
                fig, ax = init_canvas(n_nodes)
                #a = T.stack([net.module_.G.nodes[v]['a'] for v in net.module_.G.nodes], 1)
                #a = F.softmax(a, 1).detach().cpu().numpy()
                y = T.stack([net.module_.G.nodes[v]['y'] for v in net.module_.G.nodes], 1)
                y_val, y = y.max(-1)
                for i, n in enumerate(net.module_.G.nodes):
                    repr_ = net.module_.G.nodes[n]
                    g = repr_['g']
                    if g is None:
                        continue
                    # Rescale the node's raw bbox for display in the title.
                    b, _ = net.module_.update_module.glimpse.rescale(repr_['b'], False)
                    display_image(
                            fig,
                            ax,
                            i,
                            g[0],
                            np.array_str(
                                b[0].detach().cpu().numpy(),
                                precision=2, suppress_small=True) +
                            #'a=%.2f' % a[0, i, 0]
                            'y=%d (%.2f)' % (y[0, i], y_val[0, i])
                            )
                wm.display_mpl_figure(fig, win='viz{}'.format(self.nviz))
                self.nviz += 1

    def on_epoch_end(self, net, **kwargs):
        print('@', self.epoch, self.correct, '/', self.total)
        acc = self.correct / self.total
        # Record only improvements so the Checkpoint callback (monitoring
        # 'acc_best') saves on new bests.
        if self.best_acc < acc:
            self.best_acc = acc
            net.history.record('acc_best', acc)
        else:
            net.history.record('acc_best', None)
def data_generator(dataset, batch_size, shuffle):
    """Yield (image, label) minibatches from a MNISTMulti-style dataset:
    images replicated to 3 channels and scaled to [0, 1], labels squeezed
    to a 1D vector, both moved to GPU when enabled."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                        drop_last=True, num_workers=0)
    for images, labels, _bboxes in loader:
        batch = images[:, None].expand(
            images.shape[0], 3, images.shape[1], images.shape[2]
        ).float() / 255.
        yield cuda(batch), cuda(labels.squeeze(1))
if __name__ == "__main__":
    from datasets import MNISTMulti
    from torch.utils.data import DataLoader
    # NOTE(review): GridSearchCV is imported but never used below.
    from sklearn.model_selection import GridSearchCV

    # 200x200 canvases with a single clean digit each.
    mnist_train = MNISTMulti('.', n_digits=1, backrand=0, image_rows=200, image_cols=200, download=True)
    mnist_valid = MNISTMulti('.', n_digits=1, backrand=0, image_rows=200, image_cols=200, download=False, mode='valid')

    # Manual sweep over the bbox-delta regularization coefficient.
    for reg_coef in [0, 100, 1e-2, 0.1, 1, 1e-3]:
        print('Trying reg coef', reg_coef)
        net = Net(
                module=DFSGlimpseSingleObjectClassifier,
                criterion=None,
                max_epochs=50,
                reg_coef=reg_coef,
                optimizer=T.optim.RMSprop,
                #optimizer__weight_decay=1e-4,
                lr=1e-5,
                batch_size=batch_size,
                device='cuda' if USE_CUDA else 'cpu',
                callbacks=[
                    Dump(),
                    skorch.callbacks.Checkpoint(monitor='acc_best'),
                    skorch.callbacks.ProgressBar(postfix_keys=['train_loss', 'valid_loss', 'acc', 'reg']),
                    skorch.callbacks.GradientNormClipping(0.01),
                    #skorch.callbacks.LRScheduler('ReduceLROnPlateau'),
                    ],
                iterator_train=data_generator,
                iterator_train__shuffle=True,
                iterator_valid=data_generator,
                iterator_valid__shuffle=False,
                )

        #net.fit((mnist_train, mnist_valid), pretrain=True, epochs=50)
        # partial_fit keeps state across sweep iterations' warm start;
        # epochs here overrides max_epochs.
        net.partial_fit((mnist_train, mnist_valid), pretrain=False, epochs=500)
import torch as T
import torch.nn.functional as F
import numpy as NP
import os
# A non-empty $USE_CUDA environment variable enables GPU placement globally.
USE_CUDA = os.getenv('USE_CUDA', None)


def cuda(x, device=None, async_=False):
    """Move `x` to the GPU when USE_CUDA is set; otherwise return it as-is.

    NOTE(review): `device` and `async_` are accepted for call-site
    compatibility but ignored.
    """
    if USE_CUDA:
        return x.cuda()
    return x
def tovar(x, *args, dtype='float32', **kwargs):
    """Convert an array-like or tensor `x` into a (possibly CUDA) Variable.

    Non-tensor inputs are first converted through numpy with the given
    dtype; extra args/kwargs are forwarded to the Variable constructor.
    """
    tensor = x if T.is_tensor(x) else T.from_numpy(NP.array(x, dtype=dtype))
    return T.autograd.Variable(cuda(tensor), *args, **kwargs)
def tonumpy(x):
    """Best-effort conversion of a Variable or tensor to a numpy array;
    any other value is returned unchanged."""
    if isinstance(x, T.autograd.Variable):
        x = x.data
    if T.is_tensor(x):
        return x.cpu().numpy()
    return x
def create_onehot(idx, size):
    """Return a float one-hot matrix of shape `size`, with a 1 in column
    `idx[r]` of each row r."""
    canvas = tovar(T.zeros(*size))
    return canvas.scatter(1, idx.unsqueeze(1), 1)
def reverse(x, dim):
    """Return a copy of tensor `x` with the elements along `dim` reversed.

    Uses torch.flip, which avoids materializing the explicit index tensor
    the old index_select-based implementation built on every call, and
    works for any dtype/device.
    """
    return x.flip(dim)
def addbox(ax, b, ec, lw=1):
    """Draw bounding box `b` = [cx, cy, w, h] on matplotlib axes `ax` as an
    unfilled rectangle with edge color `ec` and line width `lw`."""
    import matplotlib.patches as PA
    corner = (b[0] - b[2] / 2, b[1] - b[3] / 2)
    ax.add_patch(PA.Rectangle(corner, b[2], b[3], ec=ec, fill=False, lw=lw))
def overlay(fore, fore_bbox, back):
    """Alpha-composite a foreground patch onto a background at a bbox.

    fore: (batch, 3, crows, ccols) foreground patches
    fore_bbox: (batch, 4) [cx, cy, w, h] in relative image scale
    back: (batch, 3, nrows, ncols) background
    returns: (batch, 3, nrows, ncols) composited image
    """
    batch_size = fore.size()[0]
    crows, ccols = fore.size()[-2:]
    cx, cy, w, h = T.unbind(fore_bbox, -1)
    # Inverse-map the bbox into grid_sample's [-1, 1] coordinate range:
    # sampling the foreground over [x1, x2] x [y1, y2] places it at the
    # bbox location on the background-sized output.
    x1 = -2 * cx / w
    x2 = 2 * (1 - cx) / w
    y1 = -2 * cy / h
    y2 = 2 * (1 - cy) / h
    x1 = x1[:, None]
    x2 = x2[:, None]
    y1 = y1[:, None]
    y2 = y2[:, None]
    nrows, ncols = back.size()[-2:]
    grid_x = x1 + (x2 - x1) * tovar(T.arange(ncols))[None, :] / (ncols - 1)
    grid_y = y1 + (y2 - y1) * tovar(T.arange(nrows))[None, :] / (nrows - 1)
    grid = T.stack([
        grid_x[:, None, :].expand(batch_size, nrows, ncols),
        grid_y[:, :, None].expand(batch_size, nrows, ncols),
        ], -1)
    # Append an all-ones alpha channel; outside the sampled region
    # grid_sample fills with zeros, making those pixels transparent.
    fore = T.cat([fore, tovar(T.ones(batch_size, 1, crows, ccols))], 1)
    fore = F.grid_sample(fore, grid)
    fore_rgb = fore[:, :3]
    fore_alpha = fore[:, 3:4]
    # Standard alpha blend of the warped foreground over the background.
    result = fore_rgb * fore_alpha + back * (1 - fore_alpha)
    return result
def intersection(a, b):
    """Overlap area of two boxes in center format (..., [cx, cy, w, h]).

    Broadcasts over leading dimensions; returns 0 where boxes are disjoint.
    """
    left = T.max(a[..., 0] - a[..., 2] / 2, b[..., 0] - b[..., 2] / 2)
    top = T.max(a[..., 1] - a[..., 3] / 2, b[..., 1] - b[..., 3] / 2)
    right = T.min(a[..., 0] + a[..., 2] / 2, b[..., 0] + b[..., 2] / 2)
    bottom = T.min(a[..., 1] + a[..., 3] / 2, b[..., 1] + b[..., 3] / 2)
    return (right - left).clamp(min=0) * (bottom - top).clamp(min=0)
def iou(a, b):
    """Intersection-over-union of center-format boxes a and b
    (..., [cx, cy, w, h]); broadcasts over leading dimensions."""
    inter = intersection(a, b)
    union = a[..., 2] * a[..., 3] + b[..., 2] * b[..., 3] - inter
    return inter / union
import visdom
import matplotlib.pyplot as PL
from util import *
import numpy as np
import cv2
def _fig_to_ndarray(fig):
    """Render a matplotlib figure to a (3, H, W) uint8 RGB array and close it.

    The figure is closed before returning so callers don't leak figures.
    """
    fig.canvas.draw()
    # np.fromstring on binary data is deprecated; np.frombuffer is the
    # supported equivalent. .copy() keeps the array writable like before
    # (frombuffer alone returns a read-only view of the render buffer).
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).copy()
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    #data = cv2.cvtColor(data, cv2.COLOR_RGB2BGR)
    data = data.transpose(2, 0, 1)
    PL.close(fig)
    return data
class VisdomWindowManager(visdom.Visdom):
    """Visdom client with helpers for incremental scalar plots and
    matplotlib figure/video windows."""

    def __init__(self, **kwargs):
        visdom.Visdom.__init__(self, **kwargs)
        # Per-plot point counts, last (value, t) pairs, and frame buffers.
        self.scalar_plot_length = {}
        self.scalar_plot_prev_point = {}
        self.mpl_figure_sequence = {}

    def append_scalar(self, name, value, t=None, opts=None):
        """Append one point to the scalar plot `name`; `t` defaults to the
        running point index. The first point is only buffered — a line
        segment is drawn once a second point arrives."""
        if self.scalar_plot_length.get(name, 0) == 0:
            # If we are creating a scalar plot, store the starting point but
            # don't plot anything yet
            self.close(name)
            t = 0 if t is None else t
            self.scalar_plot_length[name] = 0
        else:
            # If we have at least two values, then plot a segment
            t = self.scalar_plot_length[name] if t is None else t
            prev_v, prev_t = self.scalar_plot_prev_point[name]
            newopts = {'xlabel': 'time', 'ylabel': name}
            if opts is not None:
                newopts.update(opts)
            self.line(
                    X=np.array([prev_t, t]),
                    Y=np.array([prev_v, value]),
                    win=name,
                    update=None if not self.win_exists(name) else 'append',
                    opts=newopts
                    )
        self.scalar_plot_prev_point[name] = (value, t)
        self.scalar_plot_length[name] += 1

    def display_mpl_figure(self, fig, **kwargs):
        '''
        Display a matplotlib figure as a Visdom image window.
        Call this function before calling 'PL.show()' or 'PL.savefig()'.
        '''
        self.image(
                _fig_to_ndarray(fig),
                **kwargs
                )

    def reset_mpl_figure_sequence(self, name):
        # Drop any frames buffered under `name`.
        self.mpl_figure_sequence[name] = []

    def append_mpl_figure_to_sequence(self, name, fig):
        """Render `fig` and buffer it as an (H, W, 3) frame under `name`."""
        data = _fig_to_ndarray(fig)
        data = data.transpose(1, 2, 0)
        if name not in self.mpl_figure_sequence:
            self.reset_mpl_figure_sequence(name)
        self.mpl_figure_sequence[name].append(data)

    def display_mpl_figure_sequence(self, name, **kwargs):
        """Show the buffered frames under `name` as a Visdom video; all
        frames are resized to the first frame's dimensions."""
        data_seq = self.mpl_figure_sequence[name]
        video_rows, video_cols = data_seq[0].shape[:2]
        data_seq = [cv2.resize(f, (video_cols, video_rows)) for f in data_seq]
        data_seq = np.array(data_seq, dtype=np.uint8)
        self.video(
                data_seq,
                **kwargs
                )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment