Commit 879d0b68 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

ScanNet example

parent bd9f2c46
...@@ -9,4 +9,12 @@ To train a small U-Net with 5cm-cubed sparse voxels: ...@@ -9,4 +9,12 @@ To train a small U-Net with 5cm-cubed sparse voxels:
4. Run 'python prepare_data.py' 4. Run 'python prepare_data.py'
5. Run 'python unet.py' 5. Run 'python unet.py'
You can the computational cost (and hopefully accuracy too) by changing m / block_reps / residual_blocks / scale / val_reps in unet.py / data.py. You can train a bigger/more accurate network by changing `m` / `block_reps` / `residual_blocks` / `scale` / `val_reps` in unet.py / data.py, e.g.
```
m=32 # Wider network
block_reps=2 # Deeper network
residual_blocks=True # ResNet style basic blocks
scale=50 # 1/50 m = 2cm voxels
val_reps=3 # Multiple views at test time
batch_size=5 # Fit in 16GB of GPU memory
```
...@@ -254,12 +254,13 @@ void Metadata<dimension>::appendMetadata(Metadata<dimension> &mAdd, ...@@ -254,12 +254,13 @@ void Metadata<dimension>::appendMetadata(Metadata<dimension> &mAdd,
} }
template <Int dimension> template <Int dimension>
at::Tensor std::vector<at::Tensor>
Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference, Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference,
Metadata<dimension> &mSparsified, Metadata<dimension> &mSparsified,
/*long*/ at::Tensor spatialSize) { /*long*/ at::Tensor spatialSize) {
auto p = LongTensorToPoint<dimension>(spatialSize); auto p = LongTensorToPoint<dimension>(spatialSize);
at::Tensor delta = torch::zeros({nActive[p]}, at::kFloat); at::Tensor delta = torch::zeros({nActive[p]}, at::kFloat);
at::Tensor ref_map = torch::empty({mReference.nActive[p]}, at::kLong);
float *deltaPtr = delta.data<float>(); float *deltaPtr = delta.data<float>();
auto &sgsReference = mReference.grids[p]; auto &sgsReference = mReference.grids[p];
auto &sgsFull = grids[p]; auto &sgsFull = grids[p];
...@@ -275,13 +276,16 @@ Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference, ...@@ -275,13 +276,16 @@ Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference,
for (auto const &iter : sgFull.mp) { for (auto const &iter : sgFull.mp) {
bool gt = sgReference.mp.find(iter.first) != sgReference.mp.end(); bool gt = sgReference.mp.find(iter.first) != sgReference.mp.end();
bool hot = sgSparsified.mp.find(iter.first) != sgSparsified.mp.end(); bool hot = sgSparsified.mp.find(iter.first) != sgSparsified.mp.end();
if (gt)
ref_map[sgReference.mp[iter.first] + sgReference.ctr] =
iter.second + sgFull.ctr;
if (gt and not hot) if (gt and not hot)
deltaPtr[iter.second + sgFull.ctr] = -1; deltaPtr[iter.second + sgFull.ctr] = -1;
if (hot and not gt) if (hot and not gt)
deltaPtr[iter.second + sgFull.ctr] = +1; deltaPtr[iter.second + sgFull.ctr] = +1;
} }
} }
return delta; return {delta, ref_map};
} }
// tensor is size[0] x .. x size[dimension-1] x size[dimension] // tensor is size[0] x .. x size[dimension-1] x size[dimension]
......
...@@ -104,9 +104,9 @@ public: ...@@ -104,9 +104,9 @@ public:
void appendMetadata(Metadata<dimension> &mAdd, void appendMetadata(Metadata<dimension> &mAdd,
/*long*/ at::Tensor spatialSize); /*long*/ at::Tensor spatialSize);
at::Tensor sparsifyCompare(Metadata<dimension> &mReference, std::vector<at::Tensor> sparsifyCompare(Metadata<dimension> &mReference,
Metadata<dimension> &mSparsified, Metadata<dimension> &mSparsified,
/*long*/ at::Tensor spatialSize); /*long*/ at::Tensor spatialSize);
// tensor is size[0] x .. x size[dimension-1] x size[dimension] // tensor is size[0] x .. x size[dimension-1] x size[dimension]
// size[0] x .. x size[dimension-1] == spatial volume // size[0] x .. x size[dimension-1] == spatial volume
......
...@@ -57,7 +57,7 @@ class NetworkInNetworkFunction(Function): ...@@ -57,7 +57,7 @@ class NetworkInNetworkFunction(Function):
class NetworkInNetwork(Module): class NetworkInNetwork(Module):
def __init__(self, nIn, nOut, bias=False): def __init__(self, nIn, nOut, bias):
Module.__init__(self) Module.__init__(self)
self.nIn = nIn self.nIn = nIn
self.nOut = nOut self.nOut = nOut
......
...@@ -10,7 +10,10 @@ from .utils import * ...@@ -10,7 +10,10 @@ from .utils import *
from .sparseConvNetTensor import SparseConvNetTensor from .sparseConvNetTensor import SparseConvNetTensor
class JoinTable(Module): class JoinTable(torch.nn.Sequential):
def __init__(self, *args):
torch.nn.Sequential.__init__(self, *args)
def forward(self, input): def forward(self, input):
output = SparseConvNetTensor() output = SparseConvNetTensor()
output.metadata = input[0].metadata output.metadata = input[0].metadata
...@@ -22,7 +25,10 @@ class JoinTable(Module): ...@@ -22,7 +25,10 @@ class JoinTable(Module):
return out_size return out_size
class AddTable(Module): class AddTable(torch.nn.Sequential):
def __init__(self, *args):
torch.nn.Sequential.__init__(self, *args)
def forward(self, input): def forward(self, input):
output = SparseConvNetTensor() output = SparseConvNetTensor()
output.metadata = input[0].metadata output.metadata = input[0].metadata
...@@ -34,7 +40,10 @@ class AddTable(Module): ...@@ -34,7 +40,10 @@ class AddTable(Module):
return out_size return out_size
class ConcatTable(Module): class ConcatTable(torch.nn.Sequential):
def __init__(self, *args):
torch.nn.Sequential.__init__(self, *args)
def forward(self, input): def forward(self, input):
return [module(input) for module in self._modules.values()] return [module(input) for module in self._modules.values()]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment