Commit 1df7b845 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

3d segmentation

parent f2e3800b
......@@ -30,7 +30,7 @@ class SubmanifoldConvolution(Module):
self.bias = None
def forward(self, input):
assert input.features.ndimension() == 0 or input.features.size(1) == self.nIn
assert input.features.nelement() == 0 or input.features.size(1) == self.nIn
output = SparseConvNetTensor()
output.metadata = input.metadata
output.spatial_size = input.spatial_size
......@@ -89,12 +89,12 @@ class SubmanifoldConvolutionFunction(Function):
spatial_size,
filter_size,
input_metadata.ffi,
input_features.data,
output_features.data,
weight.data,
bias.data if bias is not None else nullptr,
input_features,
output_features,
weight,
bias if bias is not None else nullptr,
0, # remove this parameter!!
torch.cuda.IntTensor() if input_features.is_cuda else nullptr)
)
sparseconvnet.forward_pass_hidden_states += output_features.nelement()
return output_features
......@@ -117,7 +117,7 @@ class SubmanifoldConvolutionFunction(Function):
grad_output.contiguous(),
weight,
grad_weight,
grad_bias.data if grad_bias is not None else nullptr,
grad_bias if grad_bias is not None else nullptr,
0, # remove this parameter
torch.cuda.IntTensor() if input_features.is_cuda else nullptr)
)
return grad_input, grad_weight, grad_bias, None, None, None, None
# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from torch.autograd import Function, Variable
from torch.nn import Module
from .utils import *
from .sparseConvNetTensor import SparseConvNetTensor
class UnPoolingFunction(Function):
    """Autograd function for sparse unpooling.

    Forward scatters each active input site's features to the corresponding
    (larger) set of output sites; backward gathers gradients back.  Both
    directions are implemented by C/CUDA kernels reached through the
    metadata FFI handle, so the positional argument order of each call is
    fixed by the backend signature.
    """
    @staticmethod
    def forward(
        ctx,
        input_features,
        input_metadata,
        input_spatial_size,
        output_spatial_size,
        dimension,
        pool_size,
        pool_stride,
        nFeaturesToDrop):
        # Stash everything backward() will need as plain ctx attributes
        # (metadata and sizes are not differentiable tensors, so
        # save_for_backward is not used here).
        ctx.input_features=input_features
        ctx.input_metadata=input_metadata
        ctx.input_spatial_size = input_spatial_size
        ctx.output_spatial_size = output_spatial_size
        ctx.dimension = dimension
        ctx.pool_size = pool_size
        ctx.pool_stride = pool_stride
        ctx.nFeaturesToDrop = nFeaturesToDrop
        # Empty tensor on the same device/dtype as the input; the backend
        # kernel resizes and fills it in place.
        output_features = input_features.new()
        # Dispatch to the dimension- and dtype-specialized backend function.
        dim_typed_fn(dimension, input_features, 'UnPooling_updateOutput')(
            input_spatial_size,
            output_spatial_size,
            pool_size,
            pool_stride,
            input_metadata.ffi,
            input_features,
            output_features,
            nFeaturesToDrop)
        return output_features

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient buffer; resized and filled in place by the backend kernel.
        grad_input=Variable(grad_output.data.new())
        dim_typed_fn(
            ctx.dimension, ctx.input_features, 'UnPooling_updateGradInput')(
            ctx.input_spatial_size,
            ctx.output_spatial_size,
            ctx.pool_size,
            ctx.pool_stride,
            ctx.input_metadata.ffi,
            ctx.input_features,
            grad_input.data,
            # Backend requires a contiguous gradient tensor.
            grad_output.data.contiguous(),
            ctx.nFeaturesToDrop)
        # One return slot per forward() argument (after ctx); only
        # input_features is differentiable.
        return grad_input, None, None, None, None, None, None, None
class UnPooling(Module):
    """Sparse unpooling layer: the inverse of a pooling layer with the same
    pool_size / pool_stride.

    Output spatial size is ``(input_size - 1) * pool_stride + pool_size``
    per dimension.

    Args:
        dimension: number of spatial dimensions of the input.
        pool_size: int or per-dimension sequence, receptive field size.
        pool_stride: int or per-dimension sequence, stride.
        nFeaturesToDrop: number of leading feature planes to discard
            (default 0).
    """
    def __init__(self, dimension, pool_size, pool_stride, nFeaturesToDrop=0):
        super(UnPooling, self).__init__()
        self.dimension = dimension
        # Normalize scalar or sequence arguments to per-dimension LongTensors.
        self.pool_size = toLongTensor(dimension, pool_size)
        self.pool_stride = toLongTensor(dimension, pool_stride)
        self.nFeaturesToDrop = nFeaturesToDrop

    def forward(self, input):
        """Unpool a SparseConvNetTensor; output shares the input's metadata."""
        output = SparseConvNetTensor()
        output.metadata = input.metadata
        output.spatial_size = \
            (input.spatial_size - 1) * self.pool_stride + self.pool_size
        output.features = UnPoolingFunction().apply(
            input.features, input.metadata, input.spatial_size,
            output.spatial_size, self.dimension, self.pool_size,
            self.pool_stride, self.nFeaturesToDrop)
        return output

    def input_spatial_size(self, out_size):
        """Return the input spatial size needed to produce ``out_size``."""
        return (out_size - 1) * self.pool_stride + self.pool_size

    def __repr__(self):
        s = 'UnPooling'
        if self.pool_size.max() == self.pool_size.min() and \
           self.pool_stride.max() == self.pool_stride.min():
            # Isotropic case: compact 'UnPoolingK/S' form.
            s = s + str(self.pool_size[0].item()) + '/' + \
                str(self.pool_stride[0].item())
        else:
            # Anisotropic case: list every dimension explicitly.
            # Fix: call .item() on each tensor element so it renders as a
            # plain integer instead of 'tensor(k)', matching the first
            # element's formatting above.
            s = s + '(' + str(self.pool_size[0].item())
            for i in self.pool_size[1:]:
                s = s + ',' + str(i.item())
            s = s + ')/(' + str(self.pool_stride[0].item())
            for i in self.pool_stride[1:]:
                s = s + ',' + str(i.item())
            s = s + ')'
        if self.nFeaturesToDrop > 0:
            # Fix: nFeaturesToDrop is an int; the original concatenated it
            # directly onto a str, raising TypeError whenever it was > 0.
            s = s + ' nFeaturesToDrop = ' + str(self.nFeaturesToDrop)
        return s
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment