Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
d8b64558
Commit
d8b64558
authored
Oct 24, 2017
by
Benjamin Thomas Graham
Browse files
UNet; fix example data loader permutation
parent
00891eb5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
77 additions
and
6 deletions
+77
-6
PyTorch/sparseconvnet/networkArchitectures.py
PyTorch/sparseconvnet/networkArchitectures.py
+73
-2
examples/Assamese_handwriting/data.py
examples/Assamese_handwriting/data.py
+2
-2
examples/Chinese_handwriting/data.py
examples/Chinese_handwriting/data.py
+2
-2
No files found.
PyTorch/sparseconvnet/networkArchitectures.py
View file @
d8b64558
...
@@ -17,6 +17,7 @@ from .sparseToDense import SparseToDense
...
@@ -17,6 +17,7 @@ from .sparseToDense import SparseToDense
from
.denseToSparse
import
DenseToSparse
from
.denseToSparse
import
DenseToSparse
from
.tables
import
*
from
.tables
import
*
def
SparseVggNet
(
dimension
,
nInputPlanes
,
layers
):
def
SparseVggNet
(
dimension
,
nInputPlanes
,
layers
):
"""
"""
VGG style nets
VGG style nets
...
@@ -36,7 +37,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
...
@@ -36,7 +37,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
m
.
add
(
BatchNormReLU
(
nPlanes
))
m
.
add
(
BatchNormReLU
(
nPlanes
))
elif
x
[
0
]
==
'C'
and
len
(
x
)
==
3
:
elif
x
[
0
]
==
'C'
and
len
(
x
)
==
3
:
m
.
add
(
ConcatTable
()
m
.
add
(
ConcatTable
()
.
add
(
.
add
(
SubmanifoldConvolution
(
dimension
,
nPlanes
,
x
[
1
],
3
,
False
)
SubmanifoldConvolution
(
dimension
,
nPlanes
,
x
[
1
],
3
,
False
)
).
add
(
).
add
(
Sequential
()
Sequential
()
...
@@ -105,7 +106,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
...
@@ -105,7 +106,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
.
add
(
SubmanifoldConvolution
(
dimension
,
x
[
3
],
x
[
3
],
3
,
False
))
.
add
(
SubmanifoldConvolution
(
dimension
,
x
[
3
],
x
[
3
],
3
,
False
))
.
add
(
BatchNormReLU
(
x
[
3
]))
.
add
(
BatchNormReLU
(
x
[
3
]))
.
add
(
Deconvolution
(
dimension
,
x
[
3
],
x
[
3
],
3
,
2
,
False
))
.
add
(
Deconvolution
(
dimension
,
x
[
3
],
x
[
3
],
3
,
2
,
False
))
)
)
.
add
(
Sequential
()
.
add
(
Sequential
()
.
add
(
Convolution
(
dimension
,
nPlanes
,
x
[
4
],
3
,
2
,
False
))
.
add
(
Convolution
(
dimension
,
nPlanes
,
x
[
4
],
3
,
2
,
False
))
.
add
(
BatchNormReLU
(
x
[
4
]))
.
add
(
BatchNormReLU
(
x
[
4
]))
...
@@ -133,6 +134,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
...
@@ -133,6 +134,7 @@ def SparseVggNet(dimension, nInputPlanes, layers):
m
.
add
(
BatchNormReLU
(
nPlanes
))
m
.
add
(
BatchNormReLU
(
nPlanes
))
return
m
return
m
def
SparseResNet
(
dimension
,
nInputPlanes
,
layers
):
def
SparseResNet
(
dimension
,
nInputPlanes
,
layers
):
"""
"""
pre-activated ResNet
pre-activated ResNet
...
@@ -202,3 +204,72 @@ def SparseResNet(dimension, nInputPlanes, layers):
...
@@ -202,3 +204,72 @@ def SparseResNet(dimension, nInputPlanes, layers):
m
.
add
(
AddTable
())
m
.
add
(
AddTable
())
m
.
add
(
BatchNormReLU
(
nPlanes
))
m
.
add
(
BatchNormReLU
(
nPlanes
))
return
m
return
m
def
ResNetUNet
(
dimension
,
nPlanes
,
reps
,
depth
=
4
):
"""
U-Net style network with ResNet-style blocks.
For voxel level prediction:
import sparseconvnet as scn
import torch.nn
class Model(nn.Module):
def __init__(self):
nn.Module.__init__(self)
self.sparseModel = scn.Sequential().add(
scn.ValidConvolution(3, nInputFeatures, 64, 3, False)).add(
scn.ResNetUNet(3, 64, 2, 4))
self.linear = nn.Linear(64, nClasses)
def forward(self,x):
x=self.sparseModel(x).features
x=self.linear(x)
return x
"""
def
res
(
m
,
a
,
b
):
m
.
add
(
ConcatTable
()
.
add
(
Identity
()
if
a
==
b
else
NetworkInNetwork
(
a
,
b
,
False
))
.
add
(
Sequential
()
.
add
(
BatchNormReLU
(
a
))
.
add
(
SubmanifoldConvolution
(
dimension
,
a
,
b
,
3
,
False
))
.
add
(
BatchNormReLU
(
b
))
.
add
(
SubmanifoldConvolution
(
dimension
,
b
,
b
,
3
,
False
))))
\
.
add
(
AddTable
())
def
v
(
depth
,
nPlanes
):
m
=
Sequential
()
if
depth
==
1
:
for
_
in
range
(
reps
):
res
(
m
,
nPlanes
,
nPlanes
)
else
:
m
=
Sequential
()
for
_
in
range
(
reps
):
res
(
m
,
nPlanes
,
nPlanes
)
m
.
add
(
ConcatTable
()
.
add
(
Identity
())
.
add
(
Sequential
()
.
add
(
BatchNormReLU
(
nPlanes
))
.
add
(
Convolution
(
dimension
,
nPlanes
,
nPlanes
,
2
,
2
,
False
))
.
add
(
v
(
depth
-
1
,
nPlanes
))
.
add
(
BatchNormReLU
(
nPlanes
))
.
add
(
Deconvolution
(
dimension
,
nPlanes
,
nPlanes
,
2
,
2
,
False
))))
m
.
add
(
JoinTable
())
for
i
in
range
(
reps
):
res
(
m
,
2
*
nPlanes
if
i
==
0
else
nPlanes
,
nPlanes
)
return
m
m
=
v
(
depth
,
nPlanes
)
m
.
add
(
BatchNormReLU
(
nPlanes
))
return
m
examples/Assamese_handwriting/data.py
View file @
d8b64558
...
@@ -85,7 +85,7 @@ def train(spatial_size, Scale, precomputeStride):
...
@@ -85,7 +85,7 @@ def train(spatial_size, Scale, precomputeStride):
tdi
=
scn
.
threadDatasetIterator
(
bd
)
tdi
=
scn
.
threadDatasetIterator
(
bd
)
def
iter
():
def
iter
():
randperm
=
torch
.
randperm
(
len
(
d
))
randperm
.
copy_
(
torch
.
randperm
(
len
(
d
))
)
return
tdi
()
return
tdi
()
return
iter
return
iter
...
@@ -120,7 +120,7 @@ def val(spatial_size, Scale, precomputeStride):
...
@@ -120,7 +120,7 @@ def val(spatial_size, Scale, precomputeStride):
tdi
=
scn
.
threadDatasetIterator
(
bd
)
tdi
=
scn
.
threadDatasetIterator
(
bd
)
def
iter
():
def
iter
():
randperm
=
torch
.
randperm
(
len
(
d
))
randperm
.
copy_
(
torch
.
randperm
(
len
(
d
))
)
return
tdi
()
return
tdi
()
return
iter
return
iter
...
...
examples/Chinese_handwriting/data.py
View file @
d8b64558
...
@@ -74,7 +74,7 @@ def train(spatial_size, Scale, precomputeStride):
...
@@ -74,7 +74,7 @@ def train(spatial_size, Scale, precomputeStride):
tdi
=
scn
.
threadDatasetIterator
(
bd
)
tdi
=
scn
.
threadDatasetIterator
(
bd
)
def
iter
():
def
iter
():
randperm
=
torch
.
randperm
(
len
(
d
))
randperm
.
copy_
(
torch
.
randperm
(
len
(
d
))
)
return
tdi
()
return
tdi
()
return
iter
return
iter
...
@@ -109,7 +109,7 @@ def val(spatial_size, Scale, precomputeStride):
...
@@ -109,7 +109,7 @@ def val(spatial_size, Scale, precomputeStride):
tdi
=
scn
.
threadDatasetIterator
(
bd
)
tdi
=
scn
.
threadDatasetIterator
(
bd
)
def
iter
():
def
iter
():
randperm
=
torch
.
randperm
(
len
(
d
))
randperm
.
copy_
(
torch
.
randperm
(
len
(
d
))
)
return
tdi
()
return
tdi
()
return
iter
return
iter
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment