Commit 2c4ed608 authored by Benjamin Thomas Graham

Goodbye THNN. Hello ATen!
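This ports the backend from the CFFI-generated THNN-style bindings to an ATen-based C++ extension: bound entry points now take torch.Tensor and Python objects directly, metadata no longer needs to be unwrapped via .ffi, and an empty tensor stands in for an absent optional argument such as a missing bias. A minimal, self-contained sketch of that binding style (illustrative only; the inline module and add_bias function below are not part of this commit):

    # Sketch of the ATen/pybind11 binding style this commit adopts.
    import torch
    from torch.utils.cpp_extension import load_inline

    cpp_source = """
    // ATen entry points take at::Tensor directly -- no THNN state pointers,
    // no CFFI handles. An empty tensor signals 'no bias'.
    at::Tensor add_bias(at::Tensor x, at::Tensor bias) {
        return bias.numel() ? x + bias : x;
    }
    """

    ext = load_inline(name='aten_sketch', cpp_sources=cpp_source,
                      functions=['add_bias'])
    y = ext.add_bias(torch.ones(2, 3), torch.Tensor())  # empty tensor == no bias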

parent 6d4475db
@@ -39,7 +39,7 @@ class SparseToDenseFunction(Function):
             input_features,
             'SparseToDense_updateOutput')(
             spatial_size,
-            input_metadata.ffi,
+            input_metadata,
             input_features,
             output,
             nPlanes)
@@ -54,7 +54,7 @@ class SparseToDenseFunction(Function):
             input_features.contiguous(),
             'SparseToDense_updateGradInput')(
             spatial_size,
-            ctx.input_metadata.ffi,
+            ctx.input_metadata,
             input_features,
             grad_input,
             grad_output.contiguous())
...
@@ -29,8 +29,8 @@ class Sparsify(Module):
             output.features=input.features[active]
             active=active.type('torch.LongTensor')
             dim_fn(self.dimension, 'sparsifyMetadata')(
-                input.metadata.ffi,
-                output.metadata.ffi,
+                input.metadata,
+                output.metadata,
                 input.spatial_size,
                 active.byte(),
                 active.cumsum(0))
...
@@ -22,12 +22,10 @@ class SubmanifoldConvolution(Module):
         self.filter_volume = self.filter_size.prod().item()
         std = (2.0 / nIn / self.filter_volume)**0.5
         self.weight = Parameter(torch.Tensor(
-            nIn * self.filter_volume, nOut
+            self.filter_volume, nIn, nOut
         ).normal_(0, std))
         if bias:
             self.bias = Parameter(torch.Tensor(nOut).zero_())
-        else:
-            self.bias = None

     def forward(self, input):
         assert input.features.nelement() == 0 or input.features.size(1) == self.nIn
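The hunk above re-lays-out the weight from one flat (nIn * filter_volume) x nOut matrix to filter_volume separate nIn x nOut matrices, one per spatial offset of the filter; the total parameter count is unchanged. A sketch of the correspondence (the row ordering of the old layout is an assumption here):

    # Sketch of the weight re-layout (shapes taken from the hunk above).
    import torch

    nIn, nOut, filter_volume = 8, 16, 27             # e.g. a 3x3x3 filter in 3D

    w_old = torch.randn(nIn * filter_volume, nOut)   # old: flat 2D layout
    w_new = w_old.view(filter_volume, nIn, nOut)     # new: per-offset matrices

    assert w_new.numel() == w_old.numel()            # same parameter count
    assert torch.equal(w_new[0], w_old[:nIn])        # offset 0 = first nIn rows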
@@ -37,7 +35,7 @@ class SubmanifoldConvolution(Module):
         output.features = SubmanifoldConvolutionFunction.apply(
             input.features,
             self.weight,
-            self.bias,
+            optionalTensor(self, 'bias'),
             input.metadata,
             input.spatial_size,
             self.dimension,
...@@ -83,18 +81,17 @@ class SubmanifoldConvolutionFunction(Function): ...@@ -83,18 +81,17 @@ class SubmanifoldConvolutionFunction(Function):
weight, weight,
bias, bias,
filter_size) filter_size)
sparseconvnet.forward_pass_multiplyAdd_count +=\ sparseconvnet.forward_pass_multiplyAdd_count +=\
dim_typed_fn( dim_typed_fn(
dimension, input_features, 'SubmanifoldConvolution_updateOutput')( dimension, input_features, 'SubmanifoldConvolution_updateOutput')(
spatial_size, spatial_size,
filter_size, filter_size,
input_metadata.ffi, input_metadata,
input_features, input_features,
output_features, output_features,
weight, weight,
bias if bias is not None else nullptr, bias)
0, # remove this parameter!!
)
sparseconvnet.forward_pass_hidden_states += output_features.nelement() sparseconvnet.forward_pass_hidden_states += output_features.nelement()
return output_features return output_features
@@ -102,22 +99,17 @@ class SubmanifoldConvolutionFunction(Function):
     def backward(ctx, grad_output):
         input_features, spatial_size, weight, bias, filter_size = ctx.saved_tensors
         grad_input = grad_output.new()
-        grad_weight = grad_output.new().resize_as_(weight).zero_()
-        if bias is None:
-            grad_bias = None
-        else:
-            grad_bias = grad_output.new().resize_as_(bias).zero_()
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias)
         dim_typed_fn(
             ctx.dimension, input_features, 'SubmanifoldConvolution_backward')(
             spatial_size,
             filter_size,
-            ctx.input_metadata.ffi,
+            ctx.input_metadata,
             input_features,
             grad_input,
             grad_output.contiguous(),
             weight,
             grad_weight,
-            grad_bias if grad_bias is not None else nullptr,
-            0, # remove this parameter
-        )
-        return grad_input, grad_weight, grad_bias, None, None, None, None
+            grad_bias)
+        return grad_input, grad_weight, optionalTensorReturn(grad_bias), None, None, None, None
...
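autograd.Function results may be None, but the ATen entry points expect real tensors. The backward above therefore keeps an empty tensor flowing through C++ (torch.zeros_like of an empty bias is still empty) and converts back to None only at the autograd boundary via optionalTensorReturn. A small sketch of the round trip, using the two helpers defined in the utils hunk further down (the bias-free class here is illustrative):

    import torch

    def optionalTensor(a, b):                       # from the utils hunk below
        return getattr(a, b) if hasattr(a, b) else torch.Tensor()

    def optionalTensorReturn(a):                    # from the utils hunk below
        return a if a.numel() else None

    class BiasFreeLayer:                            # no .bias attribute
        pass

    bias = optionalTensor(BiasFreeLayer(), 'bias')  # empty tensor enters C++
    grad_bias = torch.zeros_like(bias)              # zeros_like keeps it empty
    print(optionalTensorReturn(grad_bias))          # None: "no bias gradient"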
@@ -35,7 +35,7 @@ class UnPoolingFunction(Function):
             output_spatial_size,
             pool_size,
             pool_stride,
-            input_metadata.ffi,
+            input_metadata,
             input_features,
             output_features,
             nFeaturesToDrop)
@@ -50,7 +50,7 @@ class UnPoolingFunction(Function):
             ctx.output_spatial_size,
             ctx.pool_size,
             ctx.pool_stride,
-            ctx.input_metadata.ffi,
+            ctx.input_metadata,
             ctx.input_features,
             grad_input.data,
             grad_output.data.contiguous(),
...
@@ -5,9 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-import sparseconvnet.SCN as scn
-from cffi import FFI
+import sparseconvnet_SCN as scn

 def toLongTensor(dimension, x):
     if hasattr(x, 'type') and x.type() == 'torch.LongTensor':
@@ -20,36 +18,30 @@ def toLongTensor(dimension, x):

 typeTable = {
-    'torch.FloatTensor': 'cpu_float',
-    'torch.DoubleTensor': 'cpu_double',
-    'torch.cuda.FloatTensor': 'gpu_float'}
+    'torch.FloatTensor': 'cpu_float_',
+    'torch.DoubleTensor': 'cpu_double_',
+    'torch.cuda.FloatTensor': 'cuda_float_'}

 def dim_fn(dimension, name):
-    # print('dim_fn',dimension,name)
-    return getattr(scn, 'scn_' + str(dimension) + '_' + name)
+    f=getattr(scn, name + '_' + str(dimension))
+    return f

 def typed_fn(t, name):
-    # print('typed_fn',t.type(),name)
-    return getattr(scn, 'scn_' + typeTable[t.type()] + '_' + name)
+    f=getattr(scn, typeTable[t.type()] + name)
+    return f

 def dim_typed_fn(dimension, t, name):
-    # print('dim_typed_fn',dimension,t.type(),name)
-    return getattr(scn, 'scn_' +
-                   typeTable[t.type()] +
-                   str(dimension) +
-                   name)
+    f=getattr(scn, typeTable[t.type()] + name + '_' + str(dimension))
+    return f

-ffi = FFI()
-nullptr = ffi.NULL
-
 def optionalTensor(a, b):
-    return getattr(a, b) if hasattr(a, b) else nullptr
+    return getattr(a, b) if hasattr(a, b) else torch.Tensor()

+def optionalTensorReturn(a):
+    return a if a.numel() else None
+
 def threadDatasetIterator(d):
     try:
@@ -72,18 +64,3 @@ def threadDatasetIterator(d):
             q.task_done()
         q.join()
     return iterator
-
-# def threadDatasetIterator(d):
-#     print('not threads!!!')
-#     def iterator():
-#         for x in d:
-#             yield x
-#     return iterator
-
-def set(obj):
-    if hasattr(obj, 'storage_type'):
-        obj.set_(obj.storage_type()())
-    else:
-        obj.set_()
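Under the new naming scheme the typeTable prefixes already end in an underscore and the dimension moves to a suffix, so dim_typed_fn resolves names such as cpu_float_SubmanifoldConvolution_updateOutput_2 for a 2D CPU float call (the old scheme prefixed everything with scn_). A sketch of the lookup, with the name resolution reimplemented locally so it runs without the compiled sparseconvnet_SCN module:

    import torch

    typeTable = {
        'torch.FloatTensor': 'cpu_float_',
        'torch.DoubleTensor': 'cpu_double_',
        'torch.cuda.FloatTensor': 'cuda_float_'}

    def symbol_name(dimension, t, name):
        # mirrors dim_typed_fn above, returning the name it would look up
        return typeTable[t.type()] + name + '_' + str(dimension)

    x = torch.randn(10, 8)                           # CPU float features
    print(symbol_name(2, x, 'SubmanifoldConvolution_updateOutput'))
    # -> cpu_float_SubmanifoldConvolution_updateOutput_2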