Commit de3743f6 authored by Benjamin Thomas Graham

Factor out CUDA code

parent f0407b36
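The commit replaces the string-concatenation dispatch helpers (dim_fn, typed_fn, dim_typed_fn, deleted from utils.py below) with direct calls into the compiled sparseconvnet_SCN extension, which now resolves dtype and dimension internally. A minimal runnable sketch of the before/after call styles, using stand-in objects rather than the real extension:

    from types import SimpleNamespace

    # Stand-ins for the compiled extension (hypothetical names; before the
    # change it exposed one symbol per dtype/dimension/operation triple,
    # afterwards one symbol per operation).
    old_scn = SimpleNamespace(
        cpu_float_Deconvolution_updateOutput_3=lambda *a: 'per-dtype, per-dim kernel')
    new_scn = SimpleNamespace(
        Deconvolution_updateOutput=lambda *a: 'one entry point per operation')

    typeTable = {'torch.FloatTensor': 'cpu_float_'}

    class FakeTensor:           # stands in for a torch tensor; only .type() matters
        def type(self):
            return 'torch.FloatTensor'

    # Before: the symbol name is assembled from (dtype prefix, op name, dimension).
    def dim_typed_fn(dimension, t, name):
        return getattr(old_scn, typeTable[t.type()] + name + '_' + str(dimension))

    print(dim_typed_fn(3, FakeTensor(), 'Deconvolution_updateOutput')())
    # After: a plain attribute access, identical at every call site.
    print(new_scn.Deconvolution_updateOutput())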
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import sparseconvnet
+import sparseconvnet, sparseconvnet_SCN
 from torch.autograd import Function
 from torch.nn import Module, Parameter
 from .utils import *
@@ -45,6 +45,26 @@ class Deconvolution(Module):
             self.filter_stride)
         return output
+    def fullForward(self, input):
+        assert input.features.nelement()==0 or input.features.size(1) == self.nIn
+        output = SparseConvNetTensor()
+        output.metadata = Metadata(self.dimension)
+        output.spatial_size =\
+            (input.spatial_size - 1) * self.filter_stride + self.filter_size
+        output.features=FullConvolutionFunction().apply(
+            input.features,
+            self.weight,
+            optionalTensor(self, 'bias'),
+            input.metadata,
+            output.metadata,
+            input.spatial_size,
+            output.spatial_size,
+            self.dimension,
+            self.filter_size,
+            self.filter_stride,
+        )
+        return output
     def __repr__(self):
         s = 'Deconvolution ' + str(self.nIn) + '->' + str(self.nOut) + ' C'
         if self.filter_size.max().item() == self.filter_size.min().item() and\
@@ -85,8 +105,7 @@ class DeconvolutionFunction(Function):
         ctx.dimension = dimension
         sparseconvnet.forward_pass_multiplyAdd_count +=\
-            dim_typed_fn(
-                dimension, input_features, 'Deconvolution_updateOutput')(
+            sparseconvnet_SCN.Deconvolution_updateOutput(
                 input_spatial_size,
                 output_spatial_size,
                 filter_size,
@@ -120,8 +139,7 @@ class DeconvolutionFunction(Function):
         grad_input = grad_output.new()
         grad_weight = torch.zeros_like(weight)
         grad_bias = torch.zeros_like(bias)
-        dim_typed_fn(
-            ctx.dimension, input_features, 'Deconvolution_backward')(
+        sparseconvnet_SCN.Deconvolution_backward(
             input_spatial_size,
             output_spatial_size,
             filter_size,
...
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import sparseconvnet
+import sparseconvnet, sparseconvnet_SCN
 from torch.autograd import Function, Variable
 from torch.nn import Module, Parameter
 from .utils import *
@@ -48,6 +48,24 @@ class FullConvolution(Module):
             )
         return output
+    def deconvolutionForward(self, input):
+        assert input.features.nelement() == 0 or input.features.size(1) == self.nIn
+        output = SparseConvNetTensor()
+        output.metadata = input.metadata
+        output.spatial_size =\
+            (input.spatial_size - 1) * self.filter_stride + self.filter_size
+        output.features = DeconvolutionFunction.apply(
+            input.features,
+            self.weight,
+            optionalTensor(self, 'bias'),
+            input.metadata,
+            input.spatial_size,
+            output.spatial_size,
+            self.dimension,
+            self.filter_size,
+            self.filter_stride)
+        return output
     def __repr__(self):
         s = 'FullConvolution ' + str(self.nIn) + '->' + str(self.nOut) + ' C'
         if self.filter_size.max() == self.filter_size.min() and\
@@ -64,7 +82,11 @@ class FullConvolution(Module):
         return s
     def input_spatial_size(self, out_size):
-        return (out_size - 1) * self.filter_stride + self.filter_size
+        in_size = (out_size - self.filter_size) / self.filter_stride + 1
+        assert ((in_size - 1) * self.filter_stride +
+                self.filter_size == out_size).all()
+        return in_size
 
 class FullConvolutionFunction(Function):
     @staticmethod
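Note on the input_spatial_size change above: the old body applied the forward size map out = (in - 1) * stride + filter to out_size instead of inverting it; the new body solves for the input size, and the assert rejects output sizes that no integer input size can produce. A small worked check in plain torch (// stands in for the / above, which floor-divides LongTensors on the torch versions this code targets):

    import torch

    filter_size   = torch.LongTensor([3, 3])
    filter_stride = torch.LongTensor([2, 2])
    out_size      = torch.LongTensor([9, 9])

    # Invert out = (in - 1) * stride + filter  =>  in = (out - filter) / stride + 1
    in_size = (out_size - filter_size) // filter_stride + 1
    assert ((in_size - 1) * filter_stride + filter_size == out_size).all()
    print(in_size)  # tensor([4, 4]): a 4x4 input grows to a 9x9 output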
@@ -93,8 +115,7 @@ class FullConvolutionFunction(Function):
             filter_size,
             filter_stride)
         sparseconvnet.forward_pass_multiplyAdd_count +=\
-            dim_typed_fn(
-                dimension, input_features, 'FullConvolution_updateOutput')(
+            sparseconvnet_SCN.FullConvolution_updateOutput(
                 input_spatial_size,
                 output_spatial_size,
                 filter_size,
@@ -113,8 +134,7 @@ class FullConvolutionFunction(Function):
         grad_input = grad_output.new()
         grad_weight = torch.zeros_like(weight)
         grad_bias = torch.zeros_like(bias)
-        dim_typed_fn(
-            ctx.dimension, input_features, 'FullConvolution_backward')(
+        sparseconvnet_SCN.FullConvolution_backward(
             input_spatial_size,
             output_spatial_size,
             filter_size,
...
@@ -6,7 +6,7 @@
 
 import torch
 from .metadata import Metadata
-from .utils import toLongTensor, dim_fn
+from .utils import toLongTensor
 from .sparseConvNetTensor import SparseConvNetTensor
...
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import sparseconvnet_SCN
 from torch.autograd import Function
 from torch.nn import Module, Parameter
 from .utils import *
@@ -164,7 +165,7 @@ class InputLayerFunction(Function):
         output_features = input_features.new()
         ctx.dimension = dimension
         ctx.metadata_ = metadata
-        dim_typed_fn(dimension, input_features, 'InputLayer_updateOutput')(
+        sparseconvnet_SCN.InputLayer_updateOutput(
             metadata,
             spatial_size,
             coords,
@@ -178,10 +179,7 @@ class InputLayerFunction(Function):
     @staticmethod
     def backward(ctx, grad_output):
         grad_input = grad_output.new()
-        dim_typed_fn(
-            ctx.dimension,
-            grad_output,
-            'InputLayer_updateGradInput')(
+        sparseconvnet_SCN.InputLayer_updateGradInput(
             ctx.metadata_,
             grad_input,
             grad_output.contiguous())
@@ -198,7 +196,7 @@ class OutputLayerFunction(Function):
         output_features = input_features.new()
         ctx.metadata_ = metadata
         ctx.dimension = dimension
-        dim_typed_fn(dimension, input_features, 'OutputLayer_updateOutput')(
+        sparseconvnet_SCN.OutputLayer_updateOutput(
             metadata,
             input_features.contiguous(),
             output_features
@@ -209,10 +207,7 @@ class OutputLayerFunction(Function):
     def backward(ctx, grad_output):
         grad_input = grad_output.new()
         grad_output=grad_output.contiguous()
-        dim_typed_fn(
-            ctx.dimension,
-            grad_output,
-            'OutputLayer_updateGradInput')(
+        sparseconvnet_SCN.OutputLayer_updateGradInput(
             ctx.metadata_,
             grad_input,
             grad_output.contiguous())
@@ -232,7 +227,7 @@ class BLInputLayerFunction(Function):
         output_features = input_features.new()
         ctx.metadata_ = metadata
         ctx.dimension = dimension
-        dim_typed_fn(dimension, input_features, 'BLInputLayer_updateOutput')(
+        sparseconvnet_SCN.BLInputLayer_updateOutput(
             metadata,
             spatial_size,
             coords,
@@ -245,10 +240,7 @@ class BLInputLayerFunction(Function):
     @staticmethod
     def backward(ctx, grad_output):
         grad_input = grad_output.new()
-        dim_typed_fn(
-            ctx.dimension,
-            grad_output,
-            'BLInputLayer_updateGradInput')(
+        sparseconvnet_SCN.BLInputLayer_updateGradInput(
             ctx.metadata_,
             grad_input,
             grad_output.contiguous())
@@ -265,7 +257,7 @@ class BLOutputLayerFunction(Function):
         output_features = input_features.new()
         ctx.metadata_ = metadata
         ctx.dimension = dimension
-        dim_typed_fn(dimension, input_features, 'BLOutputLayer_updateOutput')(
+        sparseconvnet_SCN.BLOutputLayer_updateOutput(
             metadata,
             input_features.contiguous(),
             output_features
@@ -275,10 +267,7 @@ class BLOutputLayerFunction(Function):
     @staticmethod
     def backward(ctx, grad_output):
         grad_input = grad_output.new()
-        dim_typed_fn(
-            ctx.dimension,
-            grad_output,
-            'BLOutputLayer_updateGradInput')(
+        sparseconvnet_SCN.BLOutputLayer_updateGradInput(
             ctx.metadata_,
             grad_input,
             grad_output.contiguous())
...
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import sparseconvnet_SCN
 from torch.autograd import Function
 from torch.nn import Module
 from .utils import *
@@ -26,7 +27,7 @@ class MaxPoolingFunction(Function):
         ctx.dimension = dimension
         ctx.nFeaturesToDrop = nFeaturesToDrop
         output_features = input_features.new()
-        dim_typed_fn(dimension, input_features, 'MaxPooling_updateOutput')(
+        sparseconvnet_SCN.MaxPooling_updateOutput(
             input_spatial_size,
             output_spatial_size,
             pool_size,
@@ -53,8 +54,7 @@ class MaxPoolingFunction(Function):
         pool_size,\
         pool_stride = ctx.saved_tensors
         grad_input = grad_output.new()
-        dim_typed_fn(
-            ctx.dimension, input_features, 'MaxPooling_updateGradInput')(
+        sparseconvnet_SCN.MaxPooling_updateGradInput(
             input_spatial_size,
             output_spatial_size,
             pool_size,
...
@@ -11,7 +11,7 @@ all coexist within the same MetaData object as long as each spatial size
 only occurs once.
 """
 
-from .utils import dim_fn
+import sparseconvnet_SCN
 
 def Metadata(dim):
-    return dim_fn(dim,'Metadata')()
+    return getattr(sparseconvnet_SCN, 'Metadata_%d'%dim)()
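With the per-dimension wrappers gone, Metadata resolves the compiled class by name at call time, so Metadata(3) is getattr(sparseconvnet_SCN, 'Metadata_3')(). The same pattern, runnable without the extension via a stand-in namespace:

    from types import SimpleNamespace

    # Stand-in: the real sparseconvnet_SCN exposes Metadata_1, Metadata_2, ... classes.
    fake_SCN = SimpleNamespace(Metadata_2=dict, Metadata_3=dict)

    def Metadata(dim):
        return getattr(fake_SCN, 'Metadata_%d' % dim)()

    m = Metadata(3)   # instantiates the stand-in registered as Metadata_3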
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import sparseconvnet
+import sparseconvnet, sparseconvnet_SCN
 from torch.autograd import Function
 from torch.nn import Module, Parameter
 from .utils import *
@@ -24,7 +24,7 @@ class NetworkInNetworkFunction(Function):
             weight,
             bias)
         sparseconvnet.forward_pass_multiplyAdd_count +=\
-            typed_fn(input_features, 'NetworkInNetwork_updateOutput')(
+            sparseconvnet_SCN.NetworkInNetwork_updateOutput(
                 input_features,
                 output_features,
                 weight,
@@ -45,11 +45,11 @@ class NetworkInNetworkFunction(Function):
             grad_bias = None
         else:
             grad_bias = grad_output.new().resize_as_(bias)
-        typed_fn(input_features, 'NetworkInNetwork_updateGradInput')(
+        sparseconvnet_SCN.NetworkInNetwork_updateGradInput(
             grad_input,
             grad_output,
             weight)
-        typed_fn(input_features, 'NetworkInNetwork_accGradParameters')(
+        sparseconvnet_SCN.NetworkInNetwork_accGradParameters(
             input_features,
             grad_output,
             grad_weight,
...
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import sparseconvnet
+import sparseconvnet, sparseconvnet_SCN
 from torch.autograd import Function
 from torch.nn import Module, Parameter
 from .utils import *
@@ -104,8 +104,7 @@ class RandomizedStrideConvolutionFunction(Function):
             filter_size,
             filter_stride)
         sparseconvnet.forward_pass_multiplyAdd_count +=\
-            dim_typed_fn(
-                dimension, input_features, 'RandomizedStrideConvolution_updateOutput')(
+            sparseconvnet_SCN.RandomizedStrideConvolution_updateOutput(
                 input_spatial_size,
                 output_spatial_size,
                 filter_size,
@@ -125,8 +124,7 @@ class RandomizedStrideConvolutionFunction(Function):
         grad_input = grad_output.new()
         grad_weight = torch.zeros_like(weight)
         grad_bias = torch.zeros_like(bias)
-        dim_typed_fn(
-            ctx.dimension, input_features, 'RandomizedStrideConvolution_backward')(
+        sparseconvnet_SCN.RandomizedStrideConvolution_backward(
             input_spatial_size,
             output_spatial_size,
             filter_size,
...
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import sparseconvnet_SCN
 from torch.autograd import Function
 from torch.nn import Module
 from .utils import *
@@ -26,7 +27,7 @@ class RandomizedStrideMaxPoolingFunction(Function):
         ctx.dimension = dimension
         ctx.nFeaturesToDrop = nFeaturesToDrop
         output_features = input_features.new()
-        dim_typed_fn(dimension, input_features, 'RandomizedStrideMaxPooling_updateOutput')(
+        sparseconvnet_SCN.RandomizedStrideMaxPooling_updateOutput(
             input_spatial_size,
             output_spatial_size,
             pool_size,
@@ -53,8 +54,7 @@ class RandomizedStrideMaxPoolingFunction(Function):
         pool_size,\
         pool_stride = ctx.saved_tensors
         grad_input = grad_output.new()
-        dim_typed_fn(
-            ctx.dimension, input_features, 'RandomizedStrideMaxPooling_updateGradInput')(
+        sparseconvnet_SCN.RandomizedStrideMaxPooling_updateGradInput(
             input_spatial_size,
             output_spatial_size,
             pool_size,
...
@@ -6,7 +6,6 @@
 
 import torch
-from .utils import dim_fn
 from torch.autograd import Variable
...
@@ -15,6 +15,7 @@ Parameters:
 dimension : of the input field,
 """
 
+import sparseconvnet_SCN
 from torch.autograd import Function
 from torch.nn import Module
 from .utils import *
@@ -34,10 +35,7 @@ class SparseToDenseFunction(Function):
         ctx.dimension = dimension
         ctx.save_for_backward(input_features, spatial_size)
         output = input_features.new()
-        dim_typed_fn(
-            ctx.dimension,
-            input_features,
-            'SparseToDense_updateOutput')(
+        sparseconvnet_SCN.SparseToDense_updateOutput(
             spatial_size,
             input_metadata,
             input_features,
@@ -49,10 +47,7 @@ class SparseToDenseFunction(Function):
     def backward(ctx, grad_output):
         grad_input = grad_output.new()
         input_features, spatial_size = ctx.saved_tensors
-        dim_typed_fn(
-            ctx.dimension,
-            input_features.contiguous(),
-            'SparseToDense_updateGradInput')(
+        sparseconvnet_SCN.SparseToDense_updateGradInput(
             spatial_size,
             ctx.input_metadata,
             input_features,
...
@@ -22,15 +22,18 @@ class Sparsify(Module):
         Module.__init__(self)
         self.dimension = dimension
     def forward(self, input):
-        output = SparseConvNetTensor()
-        output.metadata = Metadata(self.dimension)
-        output.spatial_size = input.spatial_size
-        active = input.features[:,0]>0
-        output.features=input.features[active]
-        active=active.type('torch.LongTensor')
-        input.metadata.sparsifyMetadata(
-            output.metadata,
-            input.spatial_size,
-            active.byte(),
-            active.cumsum(0))
-        return output
+        if input.features.numel():
+            output = SparseConvNetTensor()
+            output.metadata = Metadata(self.dimension)
+            output.spatial_size = input.spatial_size
+            active = input.features[:,0]>0
+            output.features=input.features[active]
+            active=active.type('torch.LongTensor')
+            input.metadata.sparsifyMetadata(
+                output.metadata,
+                input.spatial_size,
+                active.byte(),
+                active.cumsum(0))
+            return output
+        else:
+            return input
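The Sparsify forward pass above is now guarded by input.features.numel(), so an empty SparseConvNetTensor passes through unchanged rather than building a mask over an empty feature matrix. The row-selection logic itself, shown on a plain tensor:

    import torch

    features = torch.tensor([[ 0.5, 1.0],
                             [-0.2, 2.0],
                             [ 1.5, 3.0]])
    active = features[:, 0] > 0        # keep rows whose first feature is positive
    print(features[active])            # rows 0 and 2 survive
    print(active.long().cumsum(0))     # running count, used to renumber active sites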
+import torch
+
+def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12):
+    """
+    https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/spectral_norm.py
+    """
+    dim=1
+    torch.nn.utils.SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+    return module
...
@@ -7,6 +7,7 @@
 # 'SubmanifoldConvolution == SubmanifoldConvolution'
 
 import sparseconvnet
+import sparseconvnet_SCN
 from torch.autograd import Function
 from torch.nn import Module, Parameter
 from .utils import *
@@ -83,8 +84,7 @@ class SubmanifoldConvolutionFunction(Function):
             filter_size)
         sparseconvnet.forward_pass_multiplyAdd_count +=\
-            dim_typed_fn(
-                dimension, input_features, 'SubmanifoldConvolution_updateOutput')(
+            sparseconvnet_SCN.SubmanifoldConvolution_updateOutput(
                 spatial_size,
                 filter_size,
                 input_metadata,
@@ -101,8 +101,7 @@ class SubmanifoldConvolutionFunction(Function):
         grad_input = grad_output.new()
         grad_weight = torch.zeros_like(weight)
         grad_bias = torch.zeros_like(bias)
-        dim_typed_fn(
-            ctx.dimension, input_features, 'SubmanifoldConvolution_backward')(
+        sparseconvnet_SCN.SubmanifoldConvolution_backward(
             spatial_size,
             filter_size,
             ctx.input_metadata,
...
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import sparseconvnet_SCN
 from torch.autograd import Function, Variable
 from torch.nn import Module
 from .utils import *
@@ -30,7 +31,7 @@ class UnPoolingFunction(Function):
         ctx.pool_stride = pool_stride
         ctx.nFeaturesToDrop = nFeaturesToDrop
         output_features = input_features.new()
-        dim_typed_fn(dimension, input_features, 'UnPooling_updateOutput')(
+        sparseconvnet_SCN.UnPooling_updateOutput(
             input_spatial_size,
             output_spatial_size,
             pool_size,
@@ -44,8 +45,7 @@ class UnPoolingFunction(Function):
     @staticmethod
     def backward(ctx, grad_output):
         grad_input=Variable(grad_output.data.new())
-        dim_typed_fn(
-            ctx.dimension, ctx.input_features, 'UnPooling_updateGradInput')(
+        sparseconvnet_SCN.UnPooling_updateGradInput(
             ctx.input_spatial_size,
             ctx.output_spatial_size,
             ctx.pool_size,
...
@@ -5,7 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-import sparseconvnet_SCN as scn
+from .sparseConvNetTensor import SparseConvNetTensor
+from .metadata import Metadata
 
 def toLongTensor(dimension, x):
     if hasattr(x, 'type') and x.type() == 'torch.LongTensor':
@@ -17,32 +18,14 @@ def toLongTensor(dimension, x):
         return torch.LongTensor(dimension).fill_(x)
 
-typeTable = {
-    'torch.FloatTensor': 'cpu_float_',
-    'torch.DoubleTensor': 'cpu_double_',
-    'torch.cuda.FloatTensor': 'cuda_float_'}
-
-def dim_fn(dimension, name):
-    f=getattr(scn, name + '_' + str(dimension))
-    return f
-
-def typed_fn(t, name):
-    f=getattr(scn, typeTable[t.type()] + name)
-    return f
-
-def dim_typed_fn(dimension, t, name):
-    f=getattr(scn, typeTable[t.type()] + name + '_' + str(dimension))
-    return f
-
 def optionalTensor(a, b):
     return getattr(a, b) if hasattr(a, b) else torch.Tensor()
 
 def optionalTensorReturn(a):
     return a if a.numel() else None
 
 def threadDatasetIterator(d):
     try:
         import queue
@@ -58,9 +41,21 @@ def threadDatasetIterator(d):
         for i in range(8):
             t = threading.Thread(target=worker, args=(i,))
             t.start()
-        for i in range(len(d)):
+        for _ in range(len(d)):
             item = q.get()
             yield item
             q.task_done()
         q.join()
     return iterator
+
+def appendSparseConvTensors(tensors):
+    spatial_size=tensors[0].spatial_size
+    dimension=len(spatial_size)
+    x=SparseConvNetTensor(
+        features=torch.cat([t.features for t in tensors],0),
+        metadata=Metadata(dimension),
+        spatial_size=spatial_size)
+    for t in tensors:
+        x.metadata.appendMetadata(t.metadata,spatial_size)
+    return x
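appendSparseConvTensors merges SparseConvNetTensors that share a spatial size: feature matrices are concatenated along the active-site axis and each input's metadata is folded into a fresh Metadata object via appendMetadata. The feature-side arithmetic on plain tensors (the metadata merge needs the compiled extension, so it is elided here):

    import torch

    f1 = torch.randn(4, 16)   # 4 active sites, 16 features each
    f2 = torch.randn(7, 16)   # 7 active sites from another sample
    merged = torch.cat([f1, f2], 0)
    assert merged.shape == (4 + 7, 16)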