Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
879d0b68
Commit
879d0b68
authored
Dec 11, 2018
by
Benjamin Thomas Graham
Browse files
ScanNet example
parent
bd9f2c46
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
31 additions
and
10 deletions
+31
-10
examples/ScanNet/README.md
examples/ScanNet/README.md
+9
-1
sparseconvnet/SCN/Metadata/Metadata.cpp
sparseconvnet/SCN/Metadata/Metadata.cpp
+6
-2
sparseconvnet/SCN/Metadata/Metadata.h
sparseconvnet/SCN/Metadata/Metadata.h
+3
-3
sparseconvnet/networkInNetwork.py
sparseconvnet/networkInNetwork.py
+1
-1
sparseconvnet/tables.py
sparseconvnet/tables.py
+12
-3
No files found.
examples/ScanNet/README.md
View file @
879d0b68
...
@@ -9,4 +9,12 @@ To train a small U-Net with 5cm-cubed sparse voxels:
...
@@ -9,4 +9,12 @@ To train a small U-Net with 5cm-cubed sparse voxels:
4.
Run 'python prepare_data.py'
4.
Run 'python prepare_data.py'
5.
Run 'python unet.py'
5.
Run 'python unet.py'
You can change the computational cost (and hopefully accuracy too) by changing m / block_reps / residual_blocks / scale / val_reps in unet.py / data.py.
You can train a bigger/more accurate network by changing
`m`
/
`block_reps`
/
`residual_blocks`
/
`scale`
/
`val_reps`
in unet.py / data.py, e.g.
```
m=32 # Wider network
block_reps=2 # Deeper network
residual_blocks=True # ResNet style basic blocks
scale=50 # 1/50 m = 2cm voxels
val_reps=3 # Multiple views at test time
batch_size=5 # Fit in 16GB of GPU memory
```
sparseconvnet/SCN/Metadata/Metadata.cpp
View file @
879d0b68
...
@@ -254,12 +254,13 @@ void Metadata<dimension>::appendMetadata(Metadata<dimension> &mAdd,
...
@@ -254,12 +254,13 @@ void Metadata<dimension>::appendMetadata(Metadata<dimension> &mAdd,
}
}
template
<
Int
dimension
>
template
<
Int
dimension
>
at
::
Tensor
std
::
vector
<
at
::
Tensor
>
Metadata
<
dimension
>::
sparsifyCompare
(
Metadata
<
dimension
>
&
mReference
,
Metadata
<
dimension
>::
sparsifyCompare
(
Metadata
<
dimension
>
&
mReference
,
Metadata
<
dimension
>
&
mSparsified
,
Metadata
<
dimension
>
&
mSparsified
,
/*long*/
at
::
Tensor
spatialSize
)
{
/*long*/
at
::
Tensor
spatialSize
)
{
auto
p
=
LongTensorToPoint
<
dimension
>
(
spatialSize
);
auto
p
=
LongTensorToPoint
<
dimension
>
(
spatialSize
);
at
::
Tensor
delta
=
torch
::
zeros
({
nActive
[
p
]},
at
::
kFloat
);
at
::
Tensor
delta
=
torch
::
zeros
({
nActive
[
p
]},
at
::
kFloat
);
at
::
Tensor
ref_map
=
torch
::
empty
({
mReference
.
nActive
[
p
]},
at
::
kLong
);
float
*
deltaPtr
=
delta
.
data
<
float
>
();
float
*
deltaPtr
=
delta
.
data
<
float
>
();
auto
&
sgsReference
=
mReference
.
grids
[
p
];
auto
&
sgsReference
=
mReference
.
grids
[
p
];
auto
&
sgsFull
=
grids
[
p
];
auto
&
sgsFull
=
grids
[
p
];
...
@@ -275,13 +276,16 @@ Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference,
...
@@ -275,13 +276,16 @@ Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference,
for
(
auto
const
&
iter
:
sgFull
.
mp
)
{
for
(
auto
const
&
iter
:
sgFull
.
mp
)
{
bool
gt
=
sgReference
.
mp
.
find
(
iter
.
first
)
!=
sgReference
.
mp
.
end
();
bool
gt
=
sgReference
.
mp
.
find
(
iter
.
first
)
!=
sgReference
.
mp
.
end
();
bool
hot
=
sgSparsified
.
mp
.
find
(
iter
.
first
)
!=
sgSparsified
.
mp
.
end
();
bool
hot
=
sgSparsified
.
mp
.
find
(
iter
.
first
)
!=
sgSparsified
.
mp
.
end
();
if
(
gt
)
ref_map
[
sgReference
.
mp
[
iter
.
first
]
+
sgReference
.
ctr
]
=
iter
.
second
+
sgFull
.
ctr
;
if
(
gt
and
not
hot
)
if
(
gt
and
not
hot
)
deltaPtr
[
iter
.
second
+
sgFull
.
ctr
]
=
-
1
;
deltaPtr
[
iter
.
second
+
sgFull
.
ctr
]
=
-
1
;
if
(
hot
and
not
gt
)
if
(
hot
and
not
gt
)
deltaPtr
[
iter
.
second
+
sgFull
.
ctr
]
=
+
1
;
deltaPtr
[
iter
.
second
+
sgFull
.
ctr
]
=
+
1
;
}
}
}
}
return
delta
;
return
{
delta
,
ref_map
}
;
}
}
// tensor is size[0] x .. x size[dimension-1] x size[dimension]
// tensor is size[0] x .. x size[dimension-1] x size[dimension]
...
...
sparseconvnet/SCN/Metadata/Metadata.h
View file @
879d0b68
...
@@ -104,9 +104,9 @@ public:
...
@@ -104,9 +104,9 @@ public:
void
appendMetadata
(
Metadata
<
dimension
>
&
mAdd
,
void
appendMetadata
(
Metadata
<
dimension
>
&
mAdd
,
/*long*/
at
::
Tensor
spatialSize
);
/*long*/
at
::
Tensor
spatialSize
);
at
::
Tensor
sparsifyCompare
(
Metadata
<
dimension
>
&
mReference
,
std
::
vector
<
at
::
Tensor
>
sparsifyCompare
(
Metadata
<
dimension
>
&
mReference
,
Metadata
<
dimension
>
&
mSparsified
,
Metadata
<
dimension
>
&
mSparsified
,
/*long*/
at
::
Tensor
spatialSize
);
/*long*/
at
::
Tensor
spatialSize
);
// tensor is size[0] x .. x size[dimension-1] x size[dimension]
// tensor is size[0] x .. x size[dimension-1] x size[dimension]
// size[0] x .. x size[dimension-1] == spatial volume
// size[0] x .. x size[dimension-1] == spatial volume
...
...
sparseconvnet/networkInNetwork.py
View file @
879d0b68
...
@@ -57,7 +57,7 @@ class NetworkInNetworkFunction(Function):
...
@@ -57,7 +57,7 @@ class NetworkInNetworkFunction(Function):
class
NetworkInNetwork
(
Module
):
class
NetworkInNetwork
(
Module
):
def
__init__
(
self
,
nIn
,
nOut
,
bias
=
False
):
def
__init__
(
self
,
nIn
,
nOut
,
bias
):
Module
.
__init__
(
self
)
Module
.
__init__
(
self
)
self
.
nIn
=
nIn
self
.
nIn
=
nIn
self
.
nOut
=
nOut
self
.
nOut
=
nOut
...
...
sparseconvnet/tables.py
View file @
879d0b68
...
@@ -10,7 +10,10 @@ from .utils import *
...
@@ -10,7 +10,10 @@ from .utils import *
from
.sparseConvNetTensor
import
SparseConvNetTensor
from
.sparseConvNetTensor
import
SparseConvNetTensor
class
JoinTable
(
Module
):
class
JoinTable
(
torch
.
nn
.
Sequential
):
def
__init__
(
self
,
*
args
):
torch
.
nn
.
Sequential
.
__init__
(
self
,
*
args
)
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
output
=
SparseConvNetTensor
()
output
=
SparseConvNetTensor
()
output
.
metadata
=
input
[
0
].
metadata
output
.
metadata
=
input
[
0
].
metadata
...
@@ -22,7 +25,10 @@ class JoinTable(Module):
...
@@ -22,7 +25,10 @@ class JoinTable(Module):
return
out_size
return
out_size
class
AddTable
(
Module
):
class
AddTable
(
torch
.
nn
.
Sequential
):
def
__init__
(
self
,
*
args
):
torch
.
nn
.
Sequential
.
__init__
(
self
,
*
args
)
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
output
=
SparseConvNetTensor
()
output
=
SparseConvNetTensor
()
output
.
metadata
=
input
[
0
].
metadata
output
.
metadata
=
input
[
0
].
metadata
...
@@ -34,7 +40,10 @@ class AddTable(Module):
...
@@ -34,7 +40,10 @@ class AddTable(Module):
return
out_size
return
out_size
class
ConcatTable
(
Module
):
class
ConcatTable
(
torch
.
nn
.
Sequential
):
def
__init__
(
self
,
*
args
):
torch
.
nn
.
Sequential
.
__init__
(
self
,
*
args
)
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
return
[
module
(
input
)
for
module
in
self
.
_modules
.
values
()]
return
[
module
(
input
)
for
module
in
self
.
_modules
.
values
()]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment