Commit cd155261 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

unpool

parent 40fcc1e1
......@@ -84,3 +84,38 @@ void cpu_AveragePooling_updateGradInput(
_rules.size());
}
}
template <typename T>
void cpu_CopyFeaturesHelper_updateOutput(at::Tensor rules, at::Tensor context,
                                         at::Tensor Context) {
  // Copy whole feature rows of `context` into `Context` following the flat
  // (destination, source) index pairs stored in `rules`.
  Int nHot = rules.size(0) / 2;
  Int nPlanes = context.size(1);
  auto src = context.data<T>();
  auto dst = Context.data<T>();
  auto rb = rules.data<Int>();
#pragma omp parallel for
  for (Int row = 0; row < nHot; row++) {
    Int from = rb[2 * row + 1] * nPlanes;  // source row offset
    Int to = rb[2 * row] * nPlanes;        // destination row offset
    std::memcpy(dst + to, src + from, nPlanes * sizeof(T));
  }
}
template <typename T>
void cpu_CopyFeaturesHelper_updateGradInput(at::Tensor rules,
                                            at::Tensor dcontext,
                                            at::Tensor dContext) {
  // Backward of cpu_CopyFeaturesHelper_updateOutput: route gradient rows from
  // dContext back into dcontext following the same (destination, source) rule
  // pairs.  NOTE(review): rows are assigned, not accumulated — assumes each
  // source row index appears at most once in `rules`; confirm with callers.
  Int nHot = rules.size(0) / 2;
  Int nPlanes = dcontext.size(1);
  auto iF = dcontext.data<T>();
  auto oF = dContext.data<T>();
  auto r = rules.data<Int>();
  Int outSite;
#pragma omp parallel for private(outSite)
  for (outSite = 0; outSite < nHot; outSite++) {
    Int i = r[2 * outSite + 1] * nPlanes;
    Int o = r[2 * outSite] * nPlanes;
    // Contiguous row copy: matches the std::memcpy used by the forward pass
    // and replaces the slower element-by-element loop.
    std::memcpy(iF + i, oF + o, nPlanes * sizeof(T));
  }
}
......@@ -60,22 +60,19 @@ void cpu_UnPooling_updateGradInput(
/*long*/ at::Tensor inputSize, /*long*/ at::Tensor outputSize,
/*long*/ at::Tensor poolSize,
/*long*/ at::Tensor poolStride, Metadata<Dimension> &m,
/*float*/ at::Tensor input_features,
/*float*/ at::Tensor d_input_features,
/*float*/ at::Tensor d_output_features, long nFeaturesToDrop) {
Int nPlanes = input_features.size(1) - nFeaturesToDrop;
Int nPlanes = d_input_features.size(1) - nFeaturesToDrop;
auto _rules =
m.getRuleBook(outputSize, inputSize, poolSize, poolStride, true);
d_input_features.resize_as_(input_features);
d_input_features.zero_();
auto diF = d_input_features.data<T>() + nFeaturesToDrop;
auto doF = d_output_features.data<T>();
for (auto &r : _rules) {
Int nHot = r.size() / 2;
UnPooling_BackwardPass<T>(diF, doF, nPlanes, input_features.size(1),
UnPooling_BackwardPass<T>(diF, doF, nPlanes, d_input_features.size(1),
d_output_features.size(1), &r[0], nHot);
}
}
......@@ -59,3 +59,33 @@ void cuda_AveragePooling_updateGradInput(
d_output_features.size(1), _rules,
_rules.size());
}
template <typename T>
void cuda_CopyFeaturesHelper_ForwardPass(T *input_features, T *output_features,
Int *rules, Int nPlanes, Int nHot);
template <typename T>
void cuda_CopyFeaturesHelper_BackwardPass(T *d_input_features,
T *d_output_features, Int *rules,
Int nPlanes, Int nHot);
template <typename T>
void cuda_CopyFeaturesHelper_updateOutput(at::Tensor rules, at::Tensor context,
                                          at::Tensor Context) {
  // GPU forward feature copy: gather rows of `context` into `Context`
  // according to the flat index pairs in `rules`.
  Int nHot = rules.size(0) / 2;
  Int nPlanes = context.size(1);
  auto in = context.data<T>();
  auto out = Context.data<T>();
  cuda_CopyFeaturesHelper_ForwardPass<T>(in, out, rules.data<Int>(), nPlanes,
                                         nHot);
}
template <typename T>
void cuda_CopyFeaturesHelper_updateGradInput(at::Tensor rules,
                                             at::Tensor dcontext,
                                             at::Tensor dContext) {
  // GPU backward feature copy: scatter gradient rows of `dContext` back into
  // `dcontext` using the same rule pairs as the forward pass.
  Int nHot = rules.size(0) / 2;
  Int nPlanes = dcontext.size(1);
  auto dIn = dcontext.data<T>();
  auto dOut = dContext.data<T>();
  cuda_CopyFeaturesHelper_BackwardPass<T>(dIn, dOut, rules.data<Int>(),
                                          nPlanes, nHot);
}
......@@ -75,3 +75,71 @@ void cuda_AveragePooling_BackwardPass(T *d_input_features, T *d_output_features,
rbB, nHotB, 1.0 / filterVolume));
, )
}
// NTX must be >=2 so r is filled properly
// Forward copy kernel: for each rule pair (rules[2k], rules[2k+1]) copy the
// nPlanes values of input_features row rules[2k+1] into output_features row
// rules[2k].
// Launch layout: blockDim = (NTX, NTY); each block processes NTY rule pairs
// per outer-loop iteration, and the NTX threads sharing a threadIdx.y stride
// together over the planes of one feature row.
template <typename T, Int NTX, Int NTY>
__global__ void CopyFeaturesHelper_fp(T *input_features, T *output_features, Int * rules,
                                 Int nPlanes, Int nHot) {
  // Shared staging buffer for this iteration's (up to NTY) rule pairs.
  __shared__ Int r[NTY * 2];
  // Grid-stride outer loop over rule pairs, in chunks of gridDim.x * NTY.
  for (Int n = blockIdx.x * NTY; n < nHot; n += gridDim.x * NTY) {
    {
      // First 2*NTY threads (flattened order) each load one rule entry;
      // the second condition guards the final partial chunk.
      Int i = threadIdx.x + NTX * threadIdx.y;
      if (i < NTY * 2 and i < 2 * (nHot - n))
        r[i] = rules[2 * n + i];
    }
    __syncthreads(); // all rule entries staged before any thread reads them
    if (n + threadIdx.y < nHot) {
      // r[2y] is the destination row, r[2y+1] the source row.
      Int i = r[2 * threadIdx.y+1] * nPlanes;
      Int o = r[2 * threadIdx.y ] * nPlanes;
      for (Int plane = threadIdx.x; plane < nPlanes; plane += NTX)
        output_features[o + plane]= input_features[i + plane];
    }
    __syncthreads(); // finish reading r before the next iteration overwrites it
  }
}
template <typename T>
void cuda_CopyFeaturesHelper_ForwardPass(T *input_features, T *output_features,
                                         Int *rules, Int nPlanes, Int nHot) {
  // Fixed launch shape: 32 blocks of 32x32 threads; the kernel's outer
  // grid-stride loop covers arbitrary nHot.
  dim3 block(32, 32);
  CopyFeaturesHelper_fp<T, 32, 32><<<32, block>>>(
      input_features, output_features, rules, nPlanes, nHot);
}
// Backward copy kernel: for each rule pair (rules[2k], rules[2k+1]) copy the
// nPlanes gradient values of d_output_features row rules[2k] back into
// d_input_features row rules[2k+1] — the reverse of CopyFeaturesHelper_fp.
// NOTE(review): rows are assigned, not accumulated; assumes each input row
// index appears at most once in `rules` — confirm with callers.
// Launch layout matches the forward kernel: blockDim = (NTX, NTY), and
// NTX must be >=2 so the shared rule buffer is filled properly.
template <typename T, Int NTX, Int NTY>
__global__ void CopyFeaturesHelper_bp(T *d_input_features, T *d_output_features, Int* rules,
                                 Int nPlanes,Int nHot) {
  // Shared staging buffer for this iteration's (up to NTY) rule pairs.
  __shared__ Int r[NTY * 2];
  // Grid-stride outer loop over rule pairs, in chunks of gridDim.x * NTY.
  for (Int n = blockIdx.x * NTY; n < nHot; n += gridDim.x * NTY) {
    {
      // First 2*NTY threads (flattened order) each load one rule entry;
      // the second condition guards the final partial chunk.
      Int i = threadIdx.x + NTX * threadIdx.y;
      if (i < NTY * 2 and i < 2 * (nHot - n))
        r[i] = rules[2 * n + i];
    }
    __syncthreads(); // all rule entries staged before any thread reads them
    if (n + threadIdx.y < nHot) {
      // r[2y] indexes the output-gradient row, r[2y+1] the input-gradient row.
      Int i = r[2 * threadIdx.y+1] * nPlanes;
      Int o = r[2 * threadIdx.y] * nPlanes;
      for (Int plane = threadIdx.x; plane < nPlanes; plane += NTX)
        d_input_features[i + plane] = d_output_features[o + plane];
    }
    __syncthreads(); // finish reading r before the next iteration overwrites it
  }
}
template <typename T>
void cuda_CopyFeaturesHelper_BackwardPass(T *d_input_features,
                                          T *d_output_features, Int *rules,
                                          Int nPlanes, Int nHot) {
  // Same fixed launch shape as the forward pass: 32 blocks of 32x32 threads.
  dim3 block(32, 32);
  CopyFeaturesHelper_bp<T, 32, 32><<<32, block>>>(
      d_input_features, d_output_features, rules, nPlanes, nHot);
}
......@@ -40,19 +40,16 @@ void cuda_UnPooling_updateGradInput(
/*long*/ at::Tensor inputSize, /*long*/ at::Tensor outputSize,
/*long*/ at::Tensor poolSize,
/*long*/ at::Tensor poolStride, Metadata<Dimension> &m,
/*cuda float*/ at::Tensor input_features,
/*cuda float*/ at::Tensor d_input_features,
/*cuda float*/ at::Tensor d_output_features, long nFeaturesToDrop) {
Int nPlanes = input_features.size(1) - nFeaturesToDrop;
Int nPlanes = d_input_features.size(1) - nFeaturesToDrop;
auto _rules =
m.getRuleBook(outputSize, inputSize, poolSize, poolStride, true);
d_input_features.resize_as_(input_features);
d_input_features.zero_();
auto diF = d_input_features.data<T>() + nFeaturesToDrop;
auto doF = d_output_features.data<T>();
cuda_UnPooling_BackwardPass<T>(diF, doF, nPlanes, input_features.size(1),
cuda_UnPooling_BackwardPass<T>(diF, doF, nPlanes, d_input_features.size(1),
d_output_features.size(1), _rules);
}
......@@ -620,6 +620,29 @@ at::Tensor vvl2t_(std::vector<std::vector<Int>> v) {
return t;
}
template <Int dimension>
at::Tensor
Metadata<dimension>::copyFeaturesHelper(Metadata<dimension> &mR,
                                        /* long */ at::Tensor spatialSize) {
  // For every batch sample, collect the row-index pairs (left, right) of
  // active sites present in BOTH metadata objects at this spatial size, and
  // return them flattened into one tensor via vvl2t_.
  auto p = LongTensorToPoint<dimension>(spatialSize);
  auto &sgsL = grids[p];
  auto &sgsR = mR.grids[p];
  Int bs = sgsL.size(), sample;
  std::vector<std::vector<Int>> r(bs);
#pragma omp parallel for private(sample)
  for (sample = 0; sample < bs; ++sample) {
    auto &sgL = sgsL[sample];
    auto &sgR = sgsR[sample];
    auto &rs = r[sample];
    for (auto const &iter : sgL.mp) {
      // Single lookup in the right-hand grid; reuse the iterator instead of
      // re-querying with operator[], which is non-const and does the work
      // twice per matched site.
      auto itR = sgR.mp.find(iter.first);
      if (itR != sgR.mp.end()) {
        rs.push_back(iter.second + sgL.ctr);
        rs.push_back(itR->second + sgR.ctr);
      }
    }
  }
  return vvl2t_(r);
}
template <Int dimension> Int volume(long *point) {
Int v = 1;
for (Int i = 0; i < dimension; i++)
......
......@@ -158,6 +158,8 @@ public:
std::vector<at::Tensor>
compareSparseHelper(Metadata<dimension> &mR,
/* long */ at::Tensor spatialSize);
at::Tensor copyFeaturesHelper(Metadata<dimension> &mR,
/* long */ at::Tensor spatialSize);
};
template <typename T> T *OptionalTensorData(at::Tensor tensor);
......
......@@ -134,3 +134,8 @@ template void bmd_f<float>(float *input_features, float *output_features,
template void bmd_b<float>(float *input_features, float *d_input_features,
float *d_output_features, float *noise, Int nActive,
Int nPlanes, float alpha);
// Explicit float instantiations of the CopyFeaturesHelper CUDA entry points so
// translation units that only see the declarations can link against them.
template void cuda_CopyFeaturesHelper_ForwardPass<float>(
    float* context, float* Context,Int* rules, Int nPlanes, Int nHot);
template void cuda_CopyFeaturesHelper_BackwardPass<float>(
    float* dcontext, float* dContext,Int* rules, Int nPlanes, Int nHot);
\ No newline at end of file
......@@ -28,7 +28,8 @@ template <Int Dimension> void dimension(py::module &m, const char *name) {
&Metadata<Dimension>::addSampleFromThresholdedTensor)
.def("generateRuleBooks3s2", &Metadata<Dimension>::generateRuleBooks3s2)
.def("generateRuleBooks2s2", &Metadata<Dimension>::generateRuleBooks2s2)
.def("compareSparseHelper", &Metadata<Dimension>::compareSparseHelper);
.def("compareSparseHelper", &Metadata<Dimension>::compareSparseHelper)
.def("copyFeaturesHelper", &Metadata<Dimension>::copyFeaturesHelper);
m.def("ActivePooling_updateOutput",
(void (*)(at::Tensor, Metadata<Dimension> &, at::Tensor, at::Tensor,
bool)) &
......@@ -191,8 +192,7 @@ template <Int Dimension> void dimension(py::module &m, const char *name) {
"");
m.def("UnPooling_updateGradInput",
(void (*)(at::Tensor, at::Tensor, at::Tensor, at::Tensor,
Metadata<Dimension> &, at::Tensor, at::Tensor, at::Tensor,
long)) &
Metadata<Dimension> &, at::Tensor, at::Tensor, long)) &
UnPooling_updateGradInput,
"");
}
......@@ -226,6 +226,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"");
m.def("NetworkInNetwork_accGradParameters",
&NetworkInNetwork_accGradParameters, "");
m.def("CopyFeaturesHelper_updateOutput", &CopyFeaturesHelper_updateOutput,
"");
m.def("CopyFeaturesHelper_updateGradInput",
&CopyFeaturesHelper_updateGradInput, "");
m.def("n_rulebook_bits", []() { return 8 * sizeof(Int); }, "");
}
......@@ -229,7 +229,11 @@ template <Int Dimension>
void UnPooling_updateGradInput(at::Tensor inputSize, at::Tensor outputSize,
at::Tensor poolSize, at::Tensor poolStride,
Metadata<Dimension> &m,
at::Tensor input_features,
at::Tensor d_input_features,
at::Tensor d_output_features,
long nFeaturesToDrop);
void CopyFeaturesHelper_updateOutput(at::Tensor rules, at::Tensor context,
at::Tensor Context);
void CopyFeaturesHelper_updateGradInput(at::Tensor rules, at::Tensor dcontext,
at::Tensor dContext);
......@@ -408,13 +408,12 @@ template <Int Dimension>
void UnPooling_updateGradInput(at::Tensor inputSize, at::Tensor outputSize,
at::Tensor poolSize, at::Tensor poolStride,
Metadata<Dimension> &m,
at::Tensor input_features,
at::Tensor d_input_features,
at::Tensor d_output_features,
long nFeaturesToDrop) {
cpu_UnPooling_updateGradInput<float, Dimension>(
inputSize, outputSize, poolSize, poolStride, m, input_features,
d_input_features, d_output_features, nFeaturesToDrop);
inputSize, outputSize, poolSize, poolStride, m, d_input_features,
d_output_features, nFeaturesToDrop);
}
#define FOO \
......@@ -560,8 +559,8 @@ void UnPooling_updateGradInput(at::Tensor inputSize, at::Tensor outputSize,
template void UnPooling_updateGradInput<DIMENSION>( \
at::Tensor inputSize, at::Tensor outputSize, at::Tensor poolSize, \
at::Tensor poolStride, Metadata<DIMENSION> & m, \
at::Tensor input_features, at::Tensor d_input_features, \
at::Tensor d_output_features, long nFeaturesToDrop);
at::Tensor d_input_features, at::Tensor d_output_features, \
long nFeaturesToDrop);
#define DIMENSION 1
FOO;
......@@ -581,3 +580,14 @@ FOO;
#define DIMENSION 6
FOO;
#undef DIMENSION
at::Tensor CopyFeaturesHelper_updateOutput(at::Tensor rules, at::Tensor context,
at::Tensor Context) {
return cpu_CopyFeaturesHelper_updateOutput<float>(rules, context, Context);
}
at::Tensor CopyFeaturesHelper_updateGradInput(at::Tensor rules,
at::Tensor dcontext,
at::Tensor dContext) {
return cpu_CopyFeaturesHelper_updateGradInput<float>(rules, dcontext,
dContext);
}
......@@ -619,18 +619,17 @@ template <Int Dimension>
void UnPooling_updateGradInput(at::Tensor inputSize, at::Tensor outputSize,
at::Tensor poolSize, at::Tensor poolStride,
Metadata<Dimension> &m,
at::Tensor input_features,
at::Tensor d_input_features,
at::Tensor d_output_features,
long nFeaturesToDrop) {
if (d_output_features.type().is_cuda())
cuda_UnPooling_updateGradInput<float, Dimension>(
inputSize, outputSize, poolSize, poolStride, m, input_features,
d_input_features, d_output_features, nFeaturesToDrop);
inputSize, outputSize, poolSize, poolStride, m, d_input_features,
d_output_features, nFeaturesToDrop);
else
cpu_UnPooling_updateGradInput<float, Dimension>(
inputSize, outputSize, poolSize, poolStride, m, input_features,
d_input_features, d_output_features, nFeaturesToDrop);
inputSize, outputSize, poolSize, poolStride, m, d_input_features,
d_output_features, nFeaturesToDrop);
}
#define FOO \
......@@ -776,8 +775,8 @@ void UnPooling_updateGradInput(at::Tensor inputSize, at::Tensor outputSize,
template void UnPooling_updateGradInput<DIMENSION>( \
at::Tensor inputSize, at::Tensor outputSize, at::Tensor poolSize, \
at::Tensor poolStride, Metadata<DIMENSION> & m, \
at::Tensor input_features, at::Tensor d_input_features, \
at::Tensor d_output_features, long nFeaturesToDrop);
at::Tensor d_input_features, at::Tensor d_output_features, \
long nFeaturesToDrop);
#define DIMENSION 1
FOO;
......@@ -797,3 +796,18 @@ FOO;
#define DIMENSION 6
FOO;
#undef DIMENSION
// Route the CopyFeaturesHelper forward pass to the CUDA or CPU backend,
// depending on where the input feature tensor lives.
void CopyFeaturesHelper_updateOutput(at::Tensor rules, at::Tensor context,
                                     at::Tensor Context) {
  if (not context.is_cuda())
    cpu_CopyFeaturesHelper_updateOutput<float>(rules, context, Context);
  else
    cuda_CopyFeaturesHelper_updateOutput<float>(rules, context, Context);
}
// Route the CopyFeaturesHelper backward pass to the CUDA or CPU backend,
// depending on where the incoming gradient tensor lives.
void CopyFeaturesHelper_updateGradInput(at::Tensor rules, at::Tensor dcontext,
                                        at::Tensor dContext) {
  if (not dContext.is_cuda())
    cpu_CopyFeaturesHelper_updateGradInput<float>(rules, dcontext, dContext);
  else
    cuda_CopyFeaturesHelper_updateGradInput<float>(rules, dcontext, dContext);
}
......@@ -22,14 +22,15 @@ class UnPoolingFunction(Function):
pool_size,
pool_stride,
nFeaturesToDrop):
ctx.input_features=input_features
ctx.save_for_backward(
input_spatial_size,
output_spatial_size)
ctx.input_metadata=input_metadata
ctx.input_spatial_size = input_spatial_size
ctx.output_spatial_size = output_spatial_size
ctx.dimension = dimension
ctx.pool_size = pool_size
ctx.pool_stride = pool_stride
ctx.nFeaturesToDrop = nFeaturesToDrop
ctx.input_features_shape=input_features.shape
output_features = input_features.new()
sparseconvnet.SCN.UnPooling_updateOutput(
input_spatial_size,
......@@ -44,16 +45,16 @@ class UnPoolingFunction(Function):
@staticmethod
def backward(ctx, grad_output):
grad_input=Variable(grad_output.data.new())
input_spatial_size,output_spatial_size=ctx.saved_tensors
grad_input=torch.zeros(ctx.input_features_shape,device=grad_output.device,dtype=grad_output.dtype)
sparseconvnet.SCN.UnPooling_updateGradInput(
ctx.input_spatial_size,
ctx.output_spatial_size,
input_spatial_size,
output_spatial_size,
ctx.pool_size,
ctx.pool_stride,
ctx.input_metadata,
ctx.input_features,
grad_input.data,
grad_output.data.contiguous(),
grad_input,
grad_output.contiguous(),
ctx.nFeaturesToDrop)
return grad_input, None, None, None, None, None, None, None
......
......@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch, glob, os
import torch, glob, os, numpy as np
from .sparseConvNetTensor import SparseConvNetTensor
from .metadata import Metadata
......@@ -124,6 +124,16 @@ def batch_location_tensors(location_tensors):
a.append(pad_with_batch_idx(lt,batch_idx))
return torch.cat(a,0)
def prepare_BLInput(l, f):
    """Batch variable-length location/feature tensor lists into dense tensors.

    Args:
        l: list of (n_i, d_loc) location tensors.
        f: list of (n_i, d_feat) feature tensors, parallel to ``l``.

    Returns:
        (L, F): tensors of shape (len(l), n, d_loc) and (len(l), n, d_feat)
        where n = max_i n_i.  Location rows are padded with -1, feature rows
        with 0.
    """
    with torch.no_grad():
        n = max(x.size(0) for x in l)  # longest sample sets the padded length
        # torch.full replaces the former empty().fill_(-1) two-step.
        L = torch.full((len(l), n, l[0].size(1)), -1.0)
        F = torch.zeros(len(l), n, f[0].size(1))
        for i, (ll, ff) in enumerate(zip(l, f)):
            L[i, :ll.size(0), :].copy_(ll)
            F[i, :ff.size(0), :].copy_(ff)
        return (L, F)
def checkpoint_restore(model,exp_name,name2,use_cuda=True,epoch=0):
if use_cuda:
model.cpu()
......@@ -162,3 +172,22 @@ def checkpoint_save(model,exp_name,name2,epoch, use_cuda=True):
def random_rotation(dimension=3):
    """Return a random (dimension, dimension) orthogonal matrix.

    Built by QR-decomposing a Gaussian matrix.  NOTE(review): the result is
    orthogonal but its determinant may be -1 (a reflection); confirm callers
    do not require a proper rotation.
    """
    # torch.linalg.qr replaces the deprecated torch.qr; same reduced-mode Q.
    return torch.linalg.qr(torch.randn(dimension, dimension))[0]
class LayerNormLeakyReLU(torch.nn.Module):
    """LayerNorm followed by an in-place leaky ReLU, applied to the
    ``.features`` tensor of the wrapped input.

    The normalization step is skipped when the feature tensor is empty.
    """
    def __init__(self, num_features, leakiness):
        torch.nn.Module.__init__(self)
        # negative-slope coefficient for the leaky ReLU
        self.leakiness = leakiness
        self.in1d = torch.nn.LayerNorm(num_features)

    def forward(self, x):
        feats = x.features
        if feats.numel():
            feats = self.in1d(feats)
        x.features = torch.nn.functional.leaky_relu(
            feats, self.leakiness, inplace=True)
        return x
def voxelize_pointcloud(xyz, rgb):
    """Merge points that fall into the same integer voxel, averaging colors.

    Returns:
        (coords, mean_rgb): unique integer voxel coordinates (long tensor,
        lexicographically sorted by numpy.unique) and the per-voxel mean of
        ``rgb`` (float tensor).
    """
    coords, inverse, counts = np.unique(
        xyz.long().numpy(), axis=0, return_inverse=True, return_counts=True)
    coords = torch.from_numpy(coords)
    inverse = torch.from_numpy(inverse)
    # Sum the colors of every point landing in each voxel, then divide by the
    # voxel's point count to get the mean.
    summed = torch.zeros(coords.size(0), rgb.size(1), dtype=torch.float32)
    summed.index_add_(0, inverse, rgb)
    return coords, summed / torch.from_numpy(counts[:, None]).float()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment