Commit 4a543082 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

utils

parent b596b107
...@@ -5,8 +5,7 @@ ...@@ -5,8 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import numpy as np import numpy as np
import torch import torch, torch.utils.data
import torchnet
import glob, math, os import glob, math, os
import scipy, scipy.ndimage import scipy, scipy.ndimage
import sparseconvnet as scn import sparseconvnet as scn
......
...@@ -14,9 +14,9 @@ import os, sys ...@@ -14,9 +14,9 @@ import os, sys
import math import math
import numpy as np import numpy as np
data.init(-1,24,24*8+15,16) data.init(-1,24,24*8,16)
dimension = 3 dimension = 3
reps = 2 #Conv block repetition factor reps = 1 #Conv block repetition factor
m = 32 #Unet number of features m = 32 #Unet number of features
nPlanes = [m, 2*m, 3*m, 4*m, 5*m] #UNet number of features per level nPlanes = [m, 2*m, 3*m, 4*m, 5*m] #UNet number of features per level
......
# Copyright 2016-present, Facebook, Inc. # Copyright 2016-present, Facebook, Inc.
# All rights reserved. # All rights reserved.
# #
...@@ -19,7 +20,7 @@ if not os.path.exists('pickle/'): ...@@ -19,7 +20,7 @@ if not os.path.exists('pickle/'):
'wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/OLHWDB1.1trn_pot.zip') 'wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/OLHWDB1.1trn_pot.zip')
os.system( os.system(
'wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/OLHWDB1.1tst_pot.zip') 'wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/OLHWDB1.1tst_pot.zip')
os.system('mkdir -p t7/train/ t7/test/ POT/ pickle/') os.system('mkdir -p POT/ pickle/')
os.system('unzip OLHWDB1.1trn_pot.zip -d POT/') os.system('unzip OLHWDB1.1trn_pot.zip -d POT/')
os.system('unzip OLHWDB1.1tst_pot.zip -d POT/') os.system('unzip OLHWDB1.1tst_pot.zip -d POT/')
os.system('python readPotFiles.py') os.system('python readPotFiles.py')
......
...@@ -260,7 +260,7 @@ Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference, ...@@ -260,7 +260,7 @@ Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference,
Metadata<dimension> &mSparsified, Metadata<dimension> &mSparsified,
/*long*/ at::Tensor spatialSize) { /*long*/ at::Tensor spatialSize) {
auto p = LongTensorToPoint<dimension>(spatialSize); auto p = LongTensorToPoint<dimension>(spatialSize);
at::Tensor delta = at::zeros({nActive[p]}, torch::CPU(at::kFloat)); at::Tensor delta = at::zeros({nActive[p]}, at::kFloat);
float *deltaPtr = delta.data<float>(); float *deltaPtr = delta.data<float>();
auto &sgsReference = mReference.grids[p]; auto &sgsReference = mReference.grids[p];
auto &sgsFull = grids[p]; auto &sgsFull = grids[p];
......
...@@ -18,8 +18,8 @@ public: ...@@ -18,8 +18,8 @@ public:
RSRTicks(Int input_spatialSize, Int output_spatialSize, Int size, Int stride, RSRTicks(Int input_spatialSize, Int output_spatialSize, Int size, Int stride,
std::default_random_engine re) { std::default_random_engine re) {
std::vector<Int> steps; std::vector<Int> steps;
// steps.resize(output_spatialSize/3,stride-1); steps.resize(output_spatialSize / 3, stride - 1);
// steps.resize(output_spatialSize/3*2,stride+1); steps.resize(output_spatialSize / 3 * 2, stride + 1);
steps.resize(output_spatialSize - 1, stride); steps.resize(output_spatialSize - 1, stride);
std::shuffle(steps.begin(), steps.end(), re); std::shuffle(steps.begin(), steps.end(), re);
inputL.push_back(0); inputL.push_back(0);
......
...@@ -34,3 +34,4 @@ from .submanifoldConvolution import SubmanifoldConvolution, ValidConvolution ...@@ -34,3 +34,4 @@ from .submanifoldConvolution import SubmanifoldConvolution, ValidConvolution
from .tables import * from .tables import *
from .unPooling import UnPooling from .unPooling import UnPooling
from .utils import append_tensors, AddCoords, add_feature_planes, concatenate_feature_planes, compare_sparse from .utils import append_tensors, AddCoords, add_feature_planes, concatenate_feature_planes, compare_sparse
from .shapeContext import ShapeContext, MultiscaleShapeContext
...@@ -41,7 +41,7 @@ class BatchNormalization(Module): ...@@ -41,7 +41,7 @@ class BatchNormalization(Module):
self.bias = Parameter(torch.Tensor(nPlanes).fill_(0)) self.bias = Parameter(torch.Tensor(nPlanes).fill_(0))
def forward(self, input): def forward(self, input):
assert input.features.nelement() == 0 or input.features.size(1) == self.nPlanes assert input.features.nelement() == 0 or input.features.size(1) == self.nPlanes, (self.nPlanes, input.features.shape)
output = SparseConvNetTensor() output = SparseConvNetTensor()
output.metadata = input.metadata output.metadata = input.metadata
output.spatial_size = input.spatial_size output.spatial_size = input.spatial_size
......
...@@ -34,7 +34,7 @@ class Convolution(Module): ...@@ -34,7 +34,7 @@ class Convolution(Module):
output.spatial_size =\ output.spatial_size =\
(input.spatial_size - self.filter_size) / self.filter_stride + 1 (input.spatial_size - self.filter_size) / self.filter_stride + 1
assert ((output.spatial_size - 1) * self.filter_stride + assert ((output.spatial_size - 1) * self.filter_stride +
self.filter_size == input.spatial_size).all() self.filter_size == input.spatial_size).all(), (input.spatial_size,output.spatial_size,self.filter_size,self.filter_stride)
output.features = ConvolutionFunction.apply( output.features = ConvolutionFunction.apply(
input.features, input.features,
self.weight, self.weight,
......
...@@ -70,7 +70,7 @@ class NetworkInNetwork(Module): ...@@ -70,7 +70,7 @@ class NetworkInNetwork(Module):
self.bias = Parameter(torch.Tensor(nOut).zero_()) self.bias = Parameter(torch.Tensor(nOut).zero_())
def forward(self, input): def forward(self, input):
assert input.features.nelement() == 0 or input.features.size(1) == self.nIn assert input.features.nelement() == 0 or input.features.size(1) == self.nIn, (self.nIn, input.features.shape)
output = SparseConvNetTensor() output = SparseConvNetTensor()
output.metadata = input.metadata output.metadata = input.metadata
output.spatial_size = input.spatial_size output.spatial_size = input.spatial_size
......
# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Fixed weight submanifold convolution - ineffcieit implementation
# prod(filter_size)* nIn outputs
# weight format locations x nInput x nOutput
import sparseconvnet
import sparseconvnet.SCN
from torch.autograd import Function
from torch.nn import Module, Parameter
from .utils import *
from .sparseConvNetTensor import SparseConvNetTensor
class ShapeContext(Module):
def __init__(self, dimension, nIn, filter_size=3):
Module.__init__(self)
self.dimension = dimension
self.filter_size = toLongTensor(dimension, filter_size)
self.filter_volume = self.filter_size.prod().item()
self.nIn = nIn
self.nOut = nIn * self.filter_volume
self.register_buffer("weight",
torch.eye(self.nOut).view(self.filter_volume, self.nIn, self.nOut))
def forward(self, input):
assert input.features.nelement() == 0 or input.features.size(1) == self.nIn, (self.nIn, self.nOut, input)
output = SparseConvNetTensor()
output.metadata = input.metadata
output.spatial_size = input.spatial_size
output.features = ShapeContextFunction.apply(
input.features,
self.weight,
optionalTensor(self, 'bias'),
input.metadata,
input.spatial_size,
self.dimension,
self.filter_size)
return output
def __repr__(self):
s = 'ShapeContext ' + \
str(self.nIn) + '->' + str(self.nOut) + ' C'
if self.filter_size.max() == self.filter_size.min():
s = s + str(self.filter_size[0].item())
else:
s = s + '(' + str(self.filter_size[0].item())
for i in self.filter_size[1:]:
s = s + ',' + str(i.item())
s = s + ')'
return s
def input_spatial_size(self, out_size):
return out_size
class ShapeContextFunction(Function):
@staticmethod
def forward(
ctx,
input_features,
weight,
bias,
input_metadata,
spatial_size,
dimension,
filter_size):
ctx.input_metadata = input_metadata
ctx.dimension = dimension
output_features = input_features.new()
ctx.save_for_backward(
input_features,
spatial_size,
weight,
bias,
filter_size)
sparseconvnet.SCN.SubmanifoldConvolution_updateOutput(
spatial_size,
filter_size,
input_metadata,
input_features,
output_features,
weight,
bias)
return output_features
@staticmethod
def backward(ctx, grad_output):
assert False, "Don't backprop through ShapeContext!"
input_features, spatial_size, weight, bias, filter_size = ctx.saved_tensors
grad_input = grad_output.new()
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
sparseconvnet.SCN.SubmanifoldConvolution_backward(
spatial_size,
filter_size,
ctx.input_metadata,
input_features,
grad_input,
grad_output.contiguous(),
weight,
grad_weight,
grad_bias)
return grad_input, grad_weight, optionalTensorReturn(grad_bias), None, None, None, None
def MultiscaleShapeContext(dimension,n_features=1,n_layers=3,shape_context_size=3,downsample_size=2,downsample_stride=2,bn=True):
m=sparseconvnet.Sequential()
if n_layers==1:
m.add(sparseconvnet.ShapeContext(dimension,n_features,shape_context_size))
else:
m.add(
sparseconvnet.ConcatTable().add(
sparseconvnet.ShapeContext(dimension, n_features, shape_context_size)).add(
sparseconvnet.Sequential(
sparseconvnet.AveragePooling(dimension,downsample_size,downsample_stride),
MultiscaleShapeContext(dimension,n_features,n_layers-1,shape_context_size,downsample_size,downsample_stride,False),
sparseconvnet.UnPooling(dimension,downsample_size,downsample_stride)))).add(
sparseconvnet.JoinTable())
if bn:
m.add(sparseconvnet.BatchNormalization(shape_context_size**dimension*n_features*n_layers))
return m
...@@ -29,7 +29,7 @@ class SubmanifoldConvolution(Module): ...@@ -29,7 +29,7 @@ class SubmanifoldConvolution(Module):
self.bias = Parameter(torch.Tensor(nOut).zero_()) self.bias = Parameter(torch.Tensor(nOut).zero_())
def forward(self, input): def forward(self, input):
assert input.features.nelement() == 0 or input.features.size(1) == self.nIn assert input.features.nelement() == 0 or input.features.size(1) == self.nIn, (self.nIn, self.nOut, input)
output = SparseConvNetTensor() output = SparseConvNetTensor()
output.metadata = input.metadata output.metadata = input.metadata
output.spatial_size = input.spatial_size output.spatial_size = input.spatial_size
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch, glob, os
from .sparseConvNetTensor import SparseConvNetTensor from .sparseConvNetTensor import SparseConvNetTensor
from .metadata import Metadata from .metadata import Metadata
...@@ -113,3 +113,47 @@ def spectral_norm_svd(module): ...@@ -113,3 +113,47 @@ def spectral_norm_svd(module):
w=w.view(-1,w.size(2)) w=w.view(-1,w.size(2))
_,s,_=torch.svd(w) _,s,_=torch.svd(w)
return s[0] return s[0]
def pad_with_batch_idx(x,idx): #add a batch index to the list of coordinates
return torch.cat([x,torch.LongTensor(x.size(0),1).fill_(idx)],1)
def batch_location_tensors(location_tensors):
a=[]
for batch_idx, lt in enumerate(location_tensors):
if lt.numel():
a.append(pad_with_batch_idx(lt,batch_idx))
return torch.cat(a,0)
def checkpoint_restore(model,exp_name,name2,use_cuda=True,epoch=0):
if use_cuda:
model.cpu()
if epoch>0:
f=exp_name+'-%09d-'%epoch+name2+'.pth'
assert os.path.isfile(f)
print('Restore from ' + f)
model.load_state_dict(torch.load(f))
else:
f=sorted(glob.glob(exp_name+'-*-'+name2+'.pth'))
if len(f)>0:
f=f[-1]
print('Restore from ' + f)
model.load_state_dict(torch.load(f))
epoch=int(f[len(exp_name)+1:-len(name2)-5])
if use_cuda:
model.cuda()
return epoch+1
def is_power2(num):
return num != 0 and ((num & (num - 1)) == 0)
def checkpoint_save(model,exp_name,name2,epoch, use_cuda=True):
f=exp_name+'-%09d-'%epoch+name2+'.pth'
model.cpu()
torch.save(model.state_dict(),f)
if use_cuda:
model.cuda()
#remove previous checkpoints unless they are a power of 2 to save disk space
epoch=epoch-1
f=exp_name+'-%09d-'%epoch+name2+'.pth'
if os.path.isfile(f):
if not is_power2(epoch):
os.remove(f)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment