Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
4a543082
Commit
4a543082
authored
Nov 02, 2018
by
Benjamin Thomas Graham
Browse files
utils
parent
b596b107
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
183 additions
and
13 deletions
+183
-13
examples/3d_segmentation/data.py
examples/3d_segmentation/data.py
+1
-2
examples/3d_segmentation/unet.py
examples/3d_segmentation/unet.py
+2
-2
examples/Chinese_handwriting/data.py
examples/Chinese_handwriting/data.py
+2
-1
sparseconvnet/SCN/Metadata/Metadata.cpp
sparseconvnet/SCN/Metadata/Metadata.cpp
+1
-1
sparseconvnet/SCN/Metadata/RandomizedStrideRules.h
sparseconvnet/SCN/Metadata/RandomizedStrideRules.h
+2
-2
sparseconvnet/__init__.py
sparseconvnet/__init__.py
+1
-0
sparseconvnet/batchNormalization.py
sparseconvnet/batchNormalization.py
+1
-1
sparseconvnet/convolution.py
sparseconvnet/convolution.py
+1
-1
sparseconvnet/networkInNetwork.py
sparseconvnet/networkInNetwork.py
+1
-1
sparseconvnet/shapeContext.py
sparseconvnet/shapeContext.py
+125
-0
sparseconvnet/submanifoldConvolution.py
sparseconvnet/submanifoldConvolution.py
+1
-1
sparseconvnet/utils.py
sparseconvnet/utils.py
+45
-1
No files found.
examples/3d_segmentation/data.py
View file @
4a543082
...
@@ -5,8 +5,7 @@
...
@@ -5,8 +5,7 @@
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
,
torch
.
utils
.
data
import
torchnet
import
glob
,
math
,
os
import
glob
,
math
,
os
import
scipy
,
scipy
.
ndimage
import
scipy
,
scipy
.
ndimage
import
sparseconvnet
as
scn
import
sparseconvnet
as
scn
...
...
examples/3d_segmentation/unet.py
View file @
4a543082
...
@@ -14,9 +14,9 @@ import os, sys
...
@@ -14,9 +14,9 @@ import os, sys
import
math
import
math
import
numpy
as
np
import
numpy
as
np
data
.
init
(
-
1
,
24
,
24
*
8
+
15
,
16
)
data
.
init
(
-
1
,
24
,
24
*
8
,
16
)
dimension
=
3
dimension
=
3
reps
=
2
#Conv block repetition factor
reps
=
1
#Conv block repetition factor
m
=
32
#Unet number of features
m
=
32
#Unet number of features
nPlanes
=
[
m
,
2
*
m
,
3
*
m
,
4
*
m
,
5
*
m
]
#UNet number of features per level
nPlanes
=
[
m
,
2
*
m
,
3
*
m
,
4
*
m
,
5
*
m
]
#UNet number of features per level
...
...
examples/Chinese_handwriting/data.py
View file @
4a543082
# Copyright 2016-present, Facebook, Inc.
# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
# All rights reserved.
#
#
...
@@ -19,7 +20,7 @@ if not os.path.exists('pickle/'):
...
@@ -19,7 +20,7 @@ if not os.path.exists('pickle/'):
'wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/OLHWDB1.1trn_pot.zip'
)
'wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/OLHWDB1.1trn_pot.zip'
)
os
.
system
(
os
.
system
(
'wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/OLHWDB1.1tst_pot.zip'
)
'wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/OLHWDB1.1tst_pot.zip'
)
os
.
system
(
'mkdir -p
t7/train/ t7/test/
POT/ pickle/'
)
os
.
system
(
'mkdir -p POT/ pickle/'
)
os
.
system
(
'unzip OLHWDB1.1trn_pot.zip -d POT/'
)
os
.
system
(
'unzip OLHWDB1.1trn_pot.zip -d POT/'
)
os
.
system
(
'unzip OLHWDB1.1tst_pot.zip -d POT/'
)
os
.
system
(
'unzip OLHWDB1.1tst_pot.zip -d POT/'
)
os
.
system
(
'python readPotFiles.py'
)
os
.
system
(
'python readPotFiles.py'
)
...
...
sparseconvnet/SCN/Metadata/Metadata.cpp
View file @
4a543082
...
@@ -260,7 +260,7 @@ Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference,
...
@@ -260,7 +260,7 @@ Metadata<dimension>::sparsifyCompare(Metadata<dimension> &mReference,
Metadata
<
dimension
>
&
mSparsified
,
Metadata
<
dimension
>
&
mSparsified
,
/*long*/
at
::
Tensor
spatialSize
)
{
/*long*/
at
::
Tensor
spatialSize
)
{
auto
p
=
LongTensorToPoint
<
dimension
>
(
spatialSize
);
auto
p
=
LongTensorToPoint
<
dimension
>
(
spatialSize
);
at
::
Tensor
delta
=
at
::
zeros
({
nActive
[
p
]},
torch
::
CPU
(
at
::
kFloat
)
)
;
at
::
Tensor
delta
=
at
::
zeros
({
nActive
[
p
]},
at
::
kFloat
);
float
*
deltaPtr
=
delta
.
data
<
float
>
();
float
*
deltaPtr
=
delta
.
data
<
float
>
();
auto
&
sgsReference
=
mReference
.
grids
[
p
];
auto
&
sgsReference
=
mReference
.
grids
[
p
];
auto
&
sgsFull
=
grids
[
p
];
auto
&
sgsFull
=
grids
[
p
];
...
...
sparseconvnet/SCN/Metadata/RandomizedStrideRules.h
View file @
4a543082
...
@@ -18,8 +18,8 @@ public:
...
@@ -18,8 +18,8 @@ public:
RSRTicks
(
Int
input_spatialSize
,
Int
output_spatialSize
,
Int
size
,
Int
stride
,
RSRTicks
(
Int
input_spatialSize
,
Int
output_spatialSize
,
Int
size
,
Int
stride
,
std
::
default_random_engine
re
)
{
std
::
default_random_engine
re
)
{
std
::
vector
<
Int
>
steps
;
std
::
vector
<
Int
>
steps
;
//
steps.resize(output_spatialSize
/
3,stride
-
1);
steps
.
resize
(
output_spatialSize
/
3
,
stride
-
1
);
//
steps.resize(output_spatialSize
/3*
2,stride
+
1);
steps
.
resize
(
output_spatialSize
/
3
*
2
,
stride
+
1
);
steps
.
resize
(
output_spatialSize
-
1
,
stride
);
steps
.
resize
(
output_spatialSize
-
1
,
stride
);
std
::
shuffle
(
steps
.
begin
(),
steps
.
end
(),
re
);
std
::
shuffle
(
steps
.
begin
(),
steps
.
end
(),
re
);
inputL
.
push_back
(
0
);
inputL
.
push_back
(
0
);
...
...
sparseconvnet/__init__.py
View file @
4a543082
...
@@ -34,3 +34,4 @@ from .submanifoldConvolution import SubmanifoldConvolution, ValidConvolution
...
@@ -34,3 +34,4 @@ from .submanifoldConvolution import SubmanifoldConvolution, ValidConvolution
from
.tables
import
*
from
.tables
import
*
from
.unPooling
import
UnPooling
from
.unPooling
import
UnPooling
from
.utils
import
append_tensors
,
AddCoords
,
add_feature_planes
,
concatenate_feature_planes
,
compare_sparse
from
.utils
import
append_tensors
,
AddCoords
,
add_feature_planes
,
concatenate_feature_planes
,
compare_sparse
from
.shapeContext
import
ShapeContext
,
MultiscaleShapeContext
sparseconvnet/batchNormalization.py
View file @
4a543082
...
@@ -41,7 +41,7 @@ class BatchNormalization(Module):
...
@@ -41,7 +41,7 @@ class BatchNormalization(Module):
self
.
bias
=
Parameter
(
torch
.
Tensor
(
nPlanes
).
fill_
(
0
))
self
.
bias
=
Parameter
(
torch
.
Tensor
(
nPlanes
).
fill_
(
0
))
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
assert
input
.
features
.
nelement
()
==
0
or
input
.
features
.
size
(
1
)
==
self
.
nPlanes
assert
input
.
features
.
nelement
()
==
0
or
input
.
features
.
size
(
1
)
==
self
.
nPlanes
,
(
self
.
nPlanes
,
input
.
features
.
shape
)
output
=
SparseConvNetTensor
()
output
=
SparseConvNetTensor
()
output
.
metadata
=
input
.
metadata
output
.
metadata
=
input
.
metadata
output
.
spatial_size
=
input
.
spatial_size
output
.
spatial_size
=
input
.
spatial_size
...
...
sparseconvnet/convolution.py
View file @
4a543082
...
@@ -34,7 +34,7 @@ class Convolution(Module):
...
@@ -34,7 +34,7 @@ class Convolution(Module):
output
.
spatial_size
=
\
output
.
spatial_size
=
\
(
input
.
spatial_size
-
self
.
filter_size
)
/
self
.
filter_stride
+
1
(
input
.
spatial_size
-
self
.
filter_size
)
/
self
.
filter_stride
+
1
assert
((
output
.
spatial_size
-
1
)
*
self
.
filter_stride
+
assert
((
output
.
spatial_size
-
1
)
*
self
.
filter_stride
+
self
.
filter_size
==
input
.
spatial_size
).
all
()
self
.
filter_size
==
input
.
spatial_size
).
all
()
,
(
input
.
spatial_size
,
output
.
spatial_size
,
self
.
filter_size
,
self
.
filter_stride
)
output
.
features
=
ConvolutionFunction
.
apply
(
output
.
features
=
ConvolutionFunction
.
apply
(
input
.
features
,
input
.
features
,
self
.
weight
,
self
.
weight
,
...
...
sparseconvnet/networkInNetwork.py
View file @
4a543082
...
@@ -70,7 +70,7 @@ class NetworkInNetwork(Module):
...
@@ -70,7 +70,7 @@ class NetworkInNetwork(Module):
self
.
bias
=
Parameter
(
torch
.
Tensor
(
nOut
).
zero_
())
self
.
bias
=
Parameter
(
torch
.
Tensor
(
nOut
).
zero_
())
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
assert
input
.
features
.
nelement
()
==
0
or
input
.
features
.
size
(
1
)
==
self
.
nIn
assert
input
.
features
.
nelement
()
==
0
or
input
.
features
.
size
(
1
)
==
self
.
nIn
,
(
self
.
nIn
,
input
.
features
.
shape
)
output
=
SparseConvNetTensor
()
output
=
SparseConvNetTensor
()
output
.
metadata
=
input
.
metadata
output
.
metadata
=
input
.
metadata
output
.
spatial_size
=
input
.
spatial_size
output
.
spatial_size
=
input
.
spatial_size
...
...
sparseconvnet/shapeContext.py
0 → 100644
View file @
4a543082
# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Fixed weight submanifold convolution - ineffcieit implementation
# prod(filter_size)* nIn outputs
# weight format locations x nInput x nOutput
import
sparseconvnet
import
sparseconvnet.SCN
from
torch.autograd
import
Function
from
torch.nn
import
Module
,
Parameter
from
.utils
import
*
from
.sparseConvNetTensor
import
SparseConvNetTensor
class
ShapeContext
(
Module
):
def
__init__
(
self
,
dimension
,
nIn
,
filter_size
=
3
):
Module
.
__init__
(
self
)
self
.
dimension
=
dimension
self
.
filter_size
=
toLongTensor
(
dimension
,
filter_size
)
self
.
filter_volume
=
self
.
filter_size
.
prod
().
item
()
self
.
nIn
=
nIn
self
.
nOut
=
nIn
*
self
.
filter_volume
self
.
register_buffer
(
"weight"
,
torch
.
eye
(
self
.
nOut
).
view
(
self
.
filter_volume
,
self
.
nIn
,
self
.
nOut
))
def
forward
(
self
,
input
):
assert
input
.
features
.
nelement
()
==
0
or
input
.
features
.
size
(
1
)
==
self
.
nIn
,
(
self
.
nIn
,
self
.
nOut
,
input
)
output
=
SparseConvNetTensor
()
output
.
metadata
=
input
.
metadata
output
.
spatial_size
=
input
.
spatial_size
output
.
features
=
ShapeContextFunction
.
apply
(
input
.
features
,
self
.
weight
,
optionalTensor
(
self
,
'bias'
),
input
.
metadata
,
input
.
spatial_size
,
self
.
dimension
,
self
.
filter_size
)
return
output
def
__repr__
(
self
):
s
=
'ShapeContext '
+
\
str
(
self
.
nIn
)
+
'->'
+
str
(
self
.
nOut
)
+
' C'
if
self
.
filter_size
.
max
()
==
self
.
filter_size
.
min
():
s
=
s
+
str
(
self
.
filter_size
[
0
].
item
())
else
:
s
=
s
+
'('
+
str
(
self
.
filter_size
[
0
].
item
())
for
i
in
self
.
filter_size
[
1
:]:
s
=
s
+
','
+
str
(
i
.
item
())
s
=
s
+
')'
return
s
def
input_spatial_size
(
self
,
out_size
):
return
out_size
class
ShapeContextFunction
(
Function
):
@
staticmethod
def
forward
(
ctx
,
input_features
,
weight
,
bias
,
input_metadata
,
spatial_size
,
dimension
,
filter_size
):
ctx
.
input_metadata
=
input_metadata
ctx
.
dimension
=
dimension
output_features
=
input_features
.
new
()
ctx
.
save_for_backward
(
input_features
,
spatial_size
,
weight
,
bias
,
filter_size
)
sparseconvnet
.
SCN
.
SubmanifoldConvolution_updateOutput
(
spatial_size
,
filter_size
,
input_metadata
,
input_features
,
output_features
,
weight
,
bias
)
return
output_features
@
staticmethod
def
backward
(
ctx
,
grad_output
):
assert
False
,
"Don't backprop through ShapeContext!"
input_features
,
spatial_size
,
weight
,
bias
,
filter_size
=
ctx
.
saved_tensors
grad_input
=
grad_output
.
new
()
grad_weight
=
torch
.
zeros_like
(
weight
)
grad_bias
=
torch
.
zeros_like
(
bias
)
sparseconvnet
.
SCN
.
SubmanifoldConvolution_backward
(
spatial_size
,
filter_size
,
ctx
.
input_metadata
,
input_features
,
grad_input
,
grad_output
.
contiguous
(),
weight
,
grad_weight
,
grad_bias
)
return
grad_input
,
grad_weight
,
optionalTensorReturn
(
grad_bias
),
None
,
None
,
None
,
None
def
MultiscaleShapeContext
(
dimension
,
n_features
=
1
,
n_layers
=
3
,
shape_context_size
=
3
,
downsample_size
=
2
,
downsample_stride
=
2
,
bn
=
True
):
m
=
sparseconvnet
.
Sequential
()
if
n_layers
==
1
:
m
.
add
(
sparseconvnet
.
ShapeContext
(
dimension
,
n_features
,
shape_context_size
))
else
:
m
.
add
(
sparseconvnet
.
ConcatTable
().
add
(
sparseconvnet
.
ShapeContext
(
dimension
,
n_features
,
shape_context_size
)).
add
(
sparseconvnet
.
Sequential
(
sparseconvnet
.
AveragePooling
(
dimension
,
downsample_size
,
downsample_stride
),
MultiscaleShapeContext
(
dimension
,
n_features
,
n_layers
-
1
,
shape_context_size
,
downsample_size
,
downsample_stride
,
False
),
sparseconvnet
.
UnPooling
(
dimension
,
downsample_size
,
downsample_stride
)))).
add
(
sparseconvnet
.
JoinTable
())
if
bn
:
m
.
add
(
sparseconvnet
.
BatchNormalization
(
shape_context_size
**
dimension
*
n_features
*
n_layers
))
return
m
sparseconvnet/submanifoldConvolution.py
View file @
4a543082
...
@@ -29,7 +29,7 @@ class SubmanifoldConvolution(Module):
...
@@ -29,7 +29,7 @@ class SubmanifoldConvolution(Module):
self
.
bias
=
Parameter
(
torch
.
Tensor
(
nOut
).
zero_
())
self
.
bias
=
Parameter
(
torch
.
Tensor
(
nOut
).
zero_
())
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
assert
input
.
features
.
nelement
()
==
0
or
input
.
features
.
size
(
1
)
==
self
.
nIn
assert
input
.
features
.
nelement
()
==
0
or
input
.
features
.
size
(
1
)
==
self
.
nIn
,
(
self
.
nIn
,
self
.
nOut
,
input
)
output
=
SparseConvNetTensor
()
output
=
SparseConvNetTensor
()
output
.
metadata
=
input
.
metadata
output
.
metadata
=
input
.
metadata
output
.
spatial_size
=
input
.
spatial_size
output
.
spatial_size
=
input
.
spatial_size
...
...
sparseconvnet/utils.py
View file @
4a543082
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
# This source code is licensed under the license found in the
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
import
torch
import
torch
,
glob
,
os
from
.sparseConvNetTensor
import
SparseConvNetTensor
from
.sparseConvNetTensor
import
SparseConvNetTensor
from
.metadata
import
Metadata
from
.metadata
import
Metadata
...
@@ -113,3 +113,47 @@ def spectral_norm_svd(module):
...
@@ -113,3 +113,47 @@ def spectral_norm_svd(module):
w
=
w
.
view
(
-
1
,
w
.
size
(
2
))
w
=
w
.
view
(
-
1
,
w
.
size
(
2
))
_
,
s
,
_
=
torch
.
svd
(
w
)
_
,
s
,
_
=
torch
.
svd
(
w
)
return
s
[
0
]
return
s
[
0
]
def
pad_with_batch_idx
(
x
,
idx
):
#add a batch index to the list of coordinates
return
torch
.
cat
([
x
,
torch
.
LongTensor
(
x
.
size
(
0
),
1
).
fill_
(
idx
)],
1
)
def
batch_location_tensors
(
location_tensors
):
a
=
[]
for
batch_idx
,
lt
in
enumerate
(
location_tensors
):
if
lt
.
numel
():
a
.
append
(
pad_with_batch_idx
(
lt
,
batch_idx
))
return
torch
.
cat
(
a
,
0
)
def
checkpoint_restore
(
model
,
exp_name
,
name2
,
use_cuda
=
True
,
epoch
=
0
):
if
use_cuda
:
model
.
cpu
()
if
epoch
>
0
:
f
=
exp_name
+
'-%09d-'
%
epoch
+
name2
+
'.pth'
assert
os
.
path
.
isfile
(
f
)
print
(
'Restore from '
+
f
)
model
.
load_state_dict
(
torch
.
load
(
f
))
else
:
f
=
sorted
(
glob
.
glob
(
exp_name
+
'-*-'
+
name2
+
'.pth'
))
if
len
(
f
)
>
0
:
f
=
f
[
-
1
]
print
(
'Restore from '
+
f
)
model
.
load_state_dict
(
torch
.
load
(
f
))
epoch
=
int
(
f
[
len
(
exp_name
)
+
1
:
-
len
(
name2
)
-
5
])
if
use_cuda
:
model
.
cuda
()
return
epoch
+
1
def
is_power2
(
num
):
return
num
!=
0
and
((
num
&
(
num
-
1
))
==
0
)
def
checkpoint_save
(
model
,
exp_name
,
name2
,
epoch
,
use_cuda
=
True
):
f
=
exp_name
+
'-%09d-'
%
epoch
+
name2
+
'.pth'
model
.
cpu
()
torch
.
save
(
model
.
state_dict
(),
f
)
if
use_cuda
:
model
.
cuda
()
#remove previous checkpoints unless they are a power of 2 to save disk space
epoch
=
epoch
-
1
f
=
exp_name
+
'-%09d-'
%
epoch
+
name2
+
'.pth'
if
os
.
path
.
isfile
(
f
):
if
not
is_power2
(
epoch
):
os
.
remove
(
f
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment