Commit 879d0b68 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

ScanNet example

parent bd9f2c46
......@@ -9,4 +9,12 @@ To train a small U-Net with 5cm-cubed sparse voxels:
4. Run 'python prepare_data.py'
5. Run 'python unet.py'
You can reduce the computational cost (and hopefully accuracy too) by changing m / block_reps / residual_blocks / scale / val_reps in unet.py / data.py.
You can train a bigger/more accurate network by changing `m` / `block_reps` / `residual_blocks` / `scale` / `val_reps` in unet.py / data.py, e.g.
```
m=32 # Wider network
block_reps=2 # Deeper network
residual_blocks=True # ResNet style basic blocks
scale=50 # 1/50 m = 2cm voxels
val_reps=3 # Multiple views at test time
batch_size=5 # Fit in 16GB of GPU memory
```
......@@ -254,12 +254,13 @@ void Metadata<dimension>::appendMetadata(Metadata<dimension> &mAdd,
}
template <Int dimension>
at::Tensor
std::vector<at::Tensor>
Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference,
Metadata<dimension> &mSparsified,
/*long*/ at::Tensor spatialSize) {
auto p = LongTensorToPoint<dimension>(spatialSize);
at::Tensor delta = torch::zeros({nActive[p]}, at::kFloat);
at::Tensor ref_map = torch::empty({mReference.nActive[p]}, at::kLong);
float *deltaPtr = delta.data<float>();
auto &sgsReference = mReference.grids[p];
auto &sgsFull = grids[p];
......@@ -275,13 +276,16 @@ Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference,
for (auto const &iter : sgFull.mp) {
bool gt = sgReference.mp.find(iter.first) != sgReference.mp.end();
bool hot = sgSparsified.mp.find(iter.first) != sgSparsified.mp.end();
if (gt)
ref_map[sgReference.mp[iter.first] + sgReference.ctr] =
iter.second + sgFull.ctr;
if (gt and not hot)
deltaPtr[iter.second + sgFull.ctr] = -1;
if (hot and not gt)
deltaPtr[iter.second + sgFull.ctr] = +1;
}
}
return delta;
return {delta, ref_map};
}
// tensor is size[0] x .. x size[dimension-1] x size[dimension]
......
......@@ -104,7 +104,7 @@ public:
void appendMetadata(Metadata<dimension> &mAdd,
/*long*/ at::Tensor spatialSize);
at::Tensor sparsifyCompare(Metadata<dimension> &mReference,
std::vector<at::Tensor> sparsifyCompare(Metadata<dimension> &mReference,
Metadata<dimension> &mSparsified,
/*long*/ at::Tensor spatialSize);
......
......@@ -57,7 +57,7 @@ class NetworkInNetworkFunction(Function):
class NetworkInNetwork(Module):
def __init__(self, nIn, nOut, bias=False):
def __init__(self, nIn, nOut, bias):
Module.__init__(self)
self.nIn = nIn
self.nOut = nOut
......
......@@ -10,7 +10,10 @@ from .utils import *
from .sparseConvNetTensor import SparseConvNetTensor
class JoinTable(Module):
class JoinTable(torch.nn.Sequential):
def __init__(self, *args):
torch.nn.Sequential.__init__(self, *args)
def forward(self, input):
output = SparseConvNetTensor()
output.metadata = input[0].metadata
......@@ -22,7 +25,10 @@ class JoinTable(Module):
return out_size
class AddTable(Module):
class AddTable(torch.nn.Sequential):
def __init__(self, *args):
torch.nn.Sequential.__init__(self, *args)
def forward(self, input):
output = SparseConvNetTensor()
output.metadata = input[0].metadata
......@@ -34,7 +40,10 @@ class AddTable(Module):
return out_size
class ConcatTable(Module):
class ConcatTable(torch.nn.Sequential):
def __init__(self, *args):
torch.nn.Sequential.__init__(self, *args)
def forward(self, input):
return [module(input) for module in self._modules.values()]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment