Commit 2c4ed608 authored by Benjamin Thomas Graham

Goodbye THNN. Hello ATen!

parent 6d4475db
@@ -39,7 +39,7 @@ class SparseToDenseFunction(Function):
             input_features,
             'SparseToDense_updateOutput')(
             spatial_size,
-            input_metadata.ffi,
+            input_metadata,
             input_features,
             output,
             nPlanes)
@@ -54,7 +54,7 @@ class SparseToDenseFunction(Function):
             input_features.contiguous(),
             'SparseToDense_updateGradInput')(
             spatial_size,
-            ctx.input_metadata.ffi,
+            ctx.input_metadata,
             input_features,
             grad_input,
             grad_output.contiguous())
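Annotation: this one-line change recurs at every binding call site in the commit. The old cffi bridge needed the raw FFI handle (`input_metadata.ffi`); the new ATen-based extension is bound so that it accepts the `Metadata` wrapper object itself, so each call site simply drops the `.ffi` indirection.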
@@ -29,8 +29,8 @@ class Sparsify(Module):
         output.features=input.features[active]
         active=active.type('torch.LongTensor')
         dim_fn(self.dimension, 'sparsifyMetadata')(
-            input.metadata.ffi,
-            output.metadata.ffi,
+            input.metadata,
+            output.metadata,
             input.spatial_size,
             active.byte(),
             active.cumsum(0))
@@ -22,12 +22,10 @@ class SubmanifoldConvolution(Module):
         self.filter_volume = self.filter_size.prod().item()
         std = (2.0 / nIn / self.filter_volume)**0.5
         self.weight = Parameter(torch.Tensor(
-            nIn * self.filter_volume, nOut
+            self.filter_volume, nIn, nOut
         ).normal_(0, std))
         if bias:
             self.bias = Parameter(torch.Tensor(nOut).zero_())
-        else:
-            self.bias = None

     def forward(self, input):
         assert input.features.nelement() == 0 or input.features.size(1) == self.nIn
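Annotation: the weight moves from a flattened 2-D layout to an explicit 3-D one, so each spatial offset of the filter owns its own (nIn, nOut) matrix. A minimal runnable sketch of the new layout, with made-up sizes (the `dimension`, `nIn`, `nOut`, and `x_active` values here are hypothetical, not from the commit):

```python
import torch

# Hypothetical sizes for illustration only.
dimension, nIn, nOut = 3, 8, 16
filter_volume = 3 ** dimension  # 27 offsets for a 3x3x3 filter

std = (2.0 / nIn / filter_volume) ** 0.5
weight = torch.Tensor(filter_volume, nIn, nOut).normal_(0, std)

# weight[k] is the (nIn, nOut) matrix applied at the k-th filter offset:
x_active = torch.randn(100, nIn)  # stand-in for gathered active input sites
out_k = x_active @ weight[0]      # shape (100, nOut)
```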
@@ -37,7 +35,7 @@ class SubmanifoldConvolution(Module):
         output.features = SubmanifoldConvolutionFunction.apply(
             input.features,
             self.weight,
-            self.bias,
+            optionalTensor(self, 'bias'),
             input.metadata,
             input.spatial_size,
             self.dimension,
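Annotation: with the `else: self.bias = None` branch gone, a missing bias is now encoded as a 0-element tensor rather than a null pointer, which the ATen bindings can accept uniformly. A runnable sketch of the idiom, using `optionalTensor` exactly as defined in the new utils.py further down (the `Dummy` class is a hypothetical stand-in for a module):

```python
import torch

def optionalTensor(a, b):
    # As in the new utils.py: an absent attribute becomes an empty tensor.
    return getattr(a, b) if hasattr(a, b) else torch.Tensor()

class Dummy:
    pass

m = Dummy()
assert optionalTensor(m, 'bias').numel() == 0   # no bias -> empty tensor
m.bias = torch.zeros(16)
assert optionalTensor(m, 'bias').numel() == 16  # bias present -> real tensor
```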
@@ -83,18 +81,17 @@ class SubmanifoldConvolutionFunction(Function):
             weight,
             bias,
             filter_size)
         sparseconvnet.forward_pass_multiplyAdd_count +=\
             dim_typed_fn(
                 dimension, input_features, 'SubmanifoldConvolution_updateOutput')(
                 spatial_size,
                 filter_size,
-                input_metadata.ffi,
+                input_metadata,
                 input_features,
                 output_features,
                 weight,
-                bias if bias is not None else nullptr,
-                0, # remove this parameter!!
-                )
+                bias)
         sparseconvnet.forward_pass_hidden_states += output_features.nelement()
         return output_features
@@ -102,22 +99,17 @@ class SubmanifoldConvolutionFunction(Function):
     def backward(ctx, grad_output):
         input_features, spatial_size, weight, bias, filter_size = ctx.saved_tensors
         grad_input = grad_output.new()
-        grad_weight = grad_output.new().resize_as_(weight).zero_()
-        if bias is None:
-            grad_bias = None
-        else:
-            grad_bias = grad_output.new().resize_as_(bias).zero_()
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias)
         dim_typed_fn(
             ctx.dimension, input_features, 'SubmanifoldConvolution_backward')(
             spatial_size,
             filter_size,
-            ctx.input_metadata.ffi,
+            ctx.input_metadata,
             input_features,
             grad_input,
             grad_output.contiguous(),
             weight,
             grad_weight,
-            grad_bias if grad_bias is not None else nullptr,
-            0, # remove this parameter
-            )
-        return grad_input, grad_weight, grad_bias, None, None, None, None
+            grad_bias)
+        return grad_input, grad_weight, optionalTensorReturn(grad_bias), None, None, None, None
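Annotation: because `bias` is now always a tensor (possibly empty), `torch.zeros_like` works unconditionally, and `optionalTensorReturn` converts an empty gradient back to `None` on the way out, so autograd sees "no gradient for bias". A runnable sketch under that assumption:

```python
import torch

def optionalTensorReturn(a):
    # As in the new utils.py: an empty tensor round-trips back to None.
    return a if a.numel() else None

bias = torch.Tensor()               # "no bias" placeholder
grad_bias = torch.zeros_like(bias)  # still valid: a 0-element tensor
assert optionalTensorReturn(grad_bias) is None

bias = torch.zeros(16)
grad_bias = torch.zeros_like(bias)
assert optionalTensorReturn(grad_bias) is not None
```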
@@ -35,7 +35,7 @@ class UnPoolingFunction(Function):
             output_spatial_size,
             pool_size,
             pool_stride,
-            input_metadata.ffi,
+            input_metadata,
             input_features,
             output_features,
             nFeaturesToDrop)
@@ -50,7 +50,7 @@ class UnPoolingFunction(Function):
             ctx.output_spatial_size,
             ctx.pool_size,
             ctx.pool_stride,
-            ctx.input_metadata.ffi,
+            ctx.input_metadata,
             ctx.input_features,
             grad_input.data,
             grad_output.data.contiguous(),
@@ -5,9 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-import sparseconvnet.SCN as scn
-from cffi import FFI
+import sparseconvnet_SCN as scn

 def toLongTensor(dimension, x):
     if hasattr(x, 'type') and x.type() == 'torch.LongTensor':
@@ -20,36 +18,30 @@ def toLongTensor(dimension, x):

 typeTable = {
-    'torch.FloatTensor': 'cpu_float',
-    'torch.DoubleTensor': 'cpu_double',
-    'torch.cuda.FloatTensor': 'gpu_float'}
+    'torch.FloatTensor': 'cpu_float_',
+    'torch.DoubleTensor': 'cpu_double_',
+    'torch.cuda.FloatTensor': 'cuda_float_'}

 def dim_fn(dimension, name):
-    # print('dim_fn',dimension,name)
-    return getattr(scn, 'scn_' + str(dimension) + '_' + name)
+    f=getattr(scn, name + '_' + str(dimension))
+    return f

 def typed_fn(t, name):
-    # print('typed_fn',t.type(),name)
-    return getattr(scn, 'scn_' + typeTable[t.type()] + '_' + name)
+    f=getattr(scn, typeTable[t.type()] + name)
+    return f

 def dim_typed_fn(dimension, t, name):
-    # print('dim_typed_fn',dimension,t.type(),name)
-    return getattr(scn, 'scn_' +
-                   typeTable[t.type()] +
-                   str(dimension) +
-                   name)
-
-ffi = FFI()
-nullptr = ffi.NULL
+    f=getattr(scn, typeTable[t.type()] + name + '_' + str(dimension))
+    return f

 def optionalTensor(a, b):
-    return getattr(a, b) if hasattr(a, b) else nullptr
+    return getattr(a, b) if hasattr(a, b) else torch.Tensor()
+
+def optionalTensorReturn(a):
+    return a if a.numel() else None

 def threadDatasetIterator(d):
     try:
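Annotation: the exported symbol scheme changes from the old `scn_`-prefixed names to `<type>_<name>_<dim>`, with the type prefix (including its trailing underscore) coming straight from `typeTable`. A runnable sketch of the new lookup against a stand-in namespace (`scn` here is a mock, not the real compiled extension, and the exported symbol name is only an example):

```python
import types
import torch

# Mock namespace with one example exported symbol; the real module is the
# compiled sparseconvnet_SCN extension.
scn = types.SimpleNamespace(
    cpu_float_SubmanifoldConvolution_updateOutput_3=lambda *args: None)

typeTable = {'torch.FloatTensor': 'cpu_float_'}

def dim_typed_fn(dimension, t, name):
    return getattr(scn, typeTable[t.type()] + name + '_' + str(dimension))

features = torch.randn(10, 8)  # a torch.FloatTensor
fn = dim_typed_fn(3, features, 'SubmanifoldConvolution_updateOutput')
# fn resolves to scn.cpu_float_SubmanifoldConvolution_updateOutput_3
```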
@@ -72,18 +64,3 @@ def threadDatasetIterator(d):
             q.task_done()
         q.join()
     return iterator
-
-# def threadDatasetIterator(d):
-#     print('not threads!!!')
-#     def iterator():
-#         for x in d:
-#             yield x
-#     return iterator
-
-def set(obj):
-    if hasattr(obj, 'storage_type'):
-        obj.set_(obj.storage_type()())
-    else:
-        obj.set_()