Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
f6e15d2f
"vscode:/vscode.git/clone" did not exist on "6aff8be6547967f1e14fcd65679dff5ee8445da3"
Commit
f6e15d2f
authored
Mar 10, 2020
by
Benjamin Thomas Graham
Browse files
Fixes
parent
b862d6a2
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
132 additions
and
93 deletions
+132
-93
build.sh
build.sh
+2
-0
develop.sh
develop.sh
+2
-2
setup.py
setup.py
+3
-4
sparseconvnet/SCN/Metadata/IOLayersRules.h
sparseconvnet/SCN/Metadata/IOLayersRules.h
+71
-73
sparseconvnet/batchNormalization.py
sparseconvnet/batchNormalization.py
+6
-6
sparseconvnet/networkArchitectures.py
sparseconvnet/networkArchitectures.py
+3
-3
sparseconvnet/sequential.py
sparseconvnet/sequential.py
+9
-0
sparseconvnet/utils.py
sparseconvnet/utils.py
+36
-5
No files found.
build.sh
View file @
f6e15d2f
...
...
@@ -4,5 +4,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#export TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5"
rm
-rf
build/ dist/ sparseconvnet.egg-info
python setup.py
install
&&
python examples/hello-world.py
develop.sh
View file @
f6e15d2f
...
...
@@ -5,6 +5,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#export TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5"
rm
-rf
build/ dist/ sparseconvnet.egg-info sparseconvnet_SCN
*
.so
python setup.py develop
python examples/hello-world.py
python setup.py develop
&&
python examples/hello-world.py
setup.py
View file @
f6e15d2f
...
...
@@ -14,8 +14,7 @@ if torch.cuda.is_available():
this_dir
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
torch_dir
=
os
.
path
.
dirname
(
torch
.
__file__
)
conda_include_dir
=
'/'
.
join
(
torch_dir
.
split
(
'/'
)[:
-
4
])
+
'/include'
extra
=
{
'cxx'
:
[
'-std=c++11'
,
'-fopenmp'
],
'nvcc'
:
[
'-std=c++11'
,
'-Xcompiler'
,
'-fopenmp'
]}
extra
=
{
'cxx'
:
[
'-std=c++14'
,
'-fopenmp'
],
'nvcc'
:
[
'-std=c++14'
,
'-Xcompiler'
,
'-fopenmp'
]}
setup
(
name
=
'sparseconvnet'
,
...
...
@@ -29,12 +28,12 @@ setup(
CUDAExtension
(
'sparseconvnet.SCN'
,
[
'sparseconvnet/SCN/cuda.cu'
,
'sparseconvnet/SCN/sparseconvnet_cuda.cpp'
,
'sparseconvnet/SCN/pybind.cpp'
],
include_dirs
=
[
conda_include_dir
,
this_dir
+
'/sparseconvnet/SCN/'
],
include_dirs
=
[
this_dir
+
'/sparseconvnet/SCN/'
],
extra_compile_args
=
extra
)
if
torch
.
cuda
.
is_available
()
else
CppExtension
(
'sparseconvnet.SCN'
,
[
'sparseconvnet/SCN/pybind.cpp'
,
'sparseconvnet/SCN/sparseconvnet_cpu.cpp'
],
include_dirs
=
[
conda_include_dir
,
this_dir
+
'/sparseconvnet/SCN/'
],
include_dirs
=
[
this_dir
+
'/sparseconvnet/SCN/'
],
extra_compile_args
=
extra
[
'cxx'
])],
cmdclass
=
{
'build_ext'
:
BuildExtension
},
zip_safe
=
False
,
...
...
sparseconvnet/SCN/Metadata/IOLayersRules.h
View file @
f6e15d2f
...
...
@@ -216,7 +216,6 @@ void blRules(SparseGrids<dimension> &SGs, RuleBook &rules, long *coords,
rules
[
0
].
push_back
(
length
);
rules
[
0
].
push_back
(
nActive
);
auto
&
rule
=
rules
[
1
];
if
(
mode
==
1
)
{
rule
.
resize
(
2
*
nActive
);
#pragma omp parallel for private(I)
for
(
I
=
0
;
I
<
batchSize
;
I
++
)
{
...
...
@@ -228,7 +227,6 @@ void blRules(SparseGrids<dimension> &SGs, RuleBook &rules, long *coords,
rr
+=
2
;
}
}
}
return
;
}
...
...
sparseconvnet/batchNormalization.py
View file @
f6e15d2f
...
...
@@ -15,7 +15,7 @@ class BatchNormalization(Module):
Parameters:
nPlanes : number of input planes
eps : small number used to stabilise standard deviation calculation
momentum : for calculating running average for testing (default 0.9)
momentum : for calculating running average for testing (default 0.9
9
)
affine : only 'true' is supported at present (default 'true')
noise : add multiplicative and additive noise during training if >0.
leakiness : Apply activation def inplace: 0<=leakiness<=1.
...
...
@@ -25,7 +25,7 @@ class BatchNormalization(Module):
self
,
nPlanes
,
eps
=
1e-4
,
momentum
=
0.9
,
momentum
=
0.9
9
,
affine
=
True
,
leakiness
=
1
):
Module
.
__init__
(
self
)
...
...
@@ -72,7 +72,7 @@ class BatchNormalization(Module):
class
BatchNormReLU
(
BatchNormalization
):
def
__init__
(
self
,
nPlanes
,
eps
=
1e-4
,
momentum
=
0.9
):
def
__init__
(
self
,
nPlanes
,
eps
=
1e-4
,
momentum
=
0.9
9
):
BatchNormalization
.
__init__
(
self
,
nPlanes
,
eps
,
momentum
,
True
,
0
)
def
__repr__
(
self
):
...
...
@@ -82,7 +82,7 @@ class BatchNormReLU(BatchNormalization):
class
BatchNormLeakyReLU
(
BatchNormalization
):
def
__init__
(
self
,
nPlanes
,
eps
=
1e-4
,
momentum
=
0.9
,
leakiness
=
0.333
):
def
__init__
(
self
,
nPlanes
,
eps
=
1e-4
,
momentum
=
0.9
9
,
leakiness
=
0.333
):
BatchNormalization
.
__init__
(
self
,
nPlanes
,
eps
,
momentum
,
True
,
leakiness
)
def
__repr__
(
self
):
...
...
@@ -166,7 +166,7 @@ class MeanOnlyBNLeakyReLU(Module):
"""
Parameters:
nPlanes : number of input planes
momentum : for calculating running average for testing (default 0.9)
momentum : for calculating running average for testing (default 0.9
9
)
leakiness : Apply activation def inplace: 0<=leakiness<=1.
0 for ReLU, values in (0,1) for LeakyReLU, 1 for no activation def.
"""
...
...
@@ -175,7 +175,7 @@ class MeanOnlyBNLeakyReLU(Module):
nPlanes
,
affine
=
True
,
leakiness
=
1
,
momentum
=
0.9
):
momentum
=
0.9
9
):
Module
.
__init__
(
self
)
self
.
nPlanes
=
nPlanes
self
.
momentum
=
momentum
...
...
sparseconvnet/networkArchitectures.py
View file @
f6e15d2f
...
...
@@ -318,7 +318,7 @@ def FullConvolutionalNetIntegratedLinear(dimension, reps, nPlanes, nClasses=-1,
return
x
+
nPlanes
def
foo
(
m
,
np
):
for
_
in
range
(
reps
):
if
residual
_blocks
:
#ResNet style blocks
if
residual
:
#ResNet style blocks
m
.
add
(
scn
.
ConcatTable
()
.
add
(
scn
.
Identity
())
.
add
(
scn
.
Sequential
()
...
...
@@ -333,7 +333,7 @@ def FullConvolutionalNetIntegratedLinear(dimension, reps, nPlanes, nClasses=-1,
def
bar
(
m
,
nPlanes
,
bias
):
m
.
add
(
scn
.
BatchNormLeakyReLU
(
nPlanes
,
leakiness
=
leakiness
))
m
.
add
(
scn
.
NetworkInNetwork
(
nPlanes
,
nClasses
,
bias
))
#accumulte softmax input, only one set of biases
def
baz
(
depth
,
nPlanes
):
def
baz
(
nPlanes
):
m
=
scn
.
Sequential
()
foo
(
m
,
nPlanes
[
0
])
if
len
(
nPlanes
)
==
1
:
...
...
@@ -348,4 +348,4 @@ def FullConvolutionalNetIntegratedLinear(dimension, reps, nPlanes, nClasses=-1,
scn
.
UnPooling
(
dimension
,
downsample
[
0
],
downsample
[
1
]))
m
.
add
(
ConcatTable
(
a
,
b
))
m
.
add
(
scn
.
AddTable
())
return
baz
(
depth
,
nPlanes
)
return
baz
(
nPlanes
)
sparseconvnet/sequential.py
View file @
f6e15d2f
...
...
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import
torch
,
torch
.
utils
.
checkpoint
from
.utils
import
checkpoint101
class
Sequential
(
torch
.
nn
.
Sequential
):
def
input_spatial_size
(
self
,
out_size
):
...
...
@@ -12,6 +13,14 @@ class Sequential(torch.nn.Sequential):
out_size
=
self
.
_modules
[
m
].
input_spatial_size
(
out_size
)
return
out_size
def
__add__
(
self
,
x
):
r
=
Sequential
()
for
m
in
self
:
r
.
append
(
m
)
for
m
in
x
:
r
.
append
(
m
)
return
r
def
add
(
self
,
module
):
self
.
_modules
[
str
(
len
(
self
.
_modules
))]
=
module
return
self
...
...
sparseconvnet/utils.py
View file @
f6e15d2f
...
...
@@ -127,7 +127,7 @@ def batch_location_tensors(location_tensors):
def
prepare_BLInput
(
l
,
f
):
with
torch
.
no_grad
():
n
=
max
([
x
.
size
(
0
)
for
x
in
l
])
L
=
torch
.
empty
(
len
(
l
),
n
,
l
[
0
].
size
(
1
)).
fill_
(
-
1
)
L
=
torch
.
empty
(
len
(
l
),
n
,
l
[
0
].
size
(
1
)
,
dtype
=
torch
.
int64
).
fill_
(
-
1
)
F
=
torch
.
zeros
(
len
(
l
),
n
,
f
[
0
].
size
(
1
))
for
i
,
(
ll
,
ff
)
in
enumerate
(
zip
(
l
,
f
)):
L
[
i
,:
ll
.
size
(
0
),:].
copy_
(
ll
)
...
...
@@ -156,6 +156,9 @@ def checkpoint_restore(model,exp_name,name2,use_cuda=True,epoch=0):
def
is_power2
(
num
):
return
num
!=
0
and
((
num
&
(
num
-
1
))
==
0
)
def
is_square
(
num
):
return
int
(
num
**
0.5
+
0.5
)
**
2
==
num
def
has_only_one_nonzero_digit
(
num
):
#https://oeis.org/A037124
return
num
!=
0
and
(
num
/
10
**
math
.
floor
(
math
.
log
(
num
,
10
))).
is_integer
()
...
...
@@ -291,9 +294,37 @@ def matplotlib_planes(ax, positions,colors):
pass
ax
.
set_axis_off
()
def
visdom_scatter
(
vis
,
xyz
,
rgb
,
win
=
'3d'
,
markersize
=
3
):
def
visdom_scatter
(
vis
,
xyz
,
rgb
,
win
=
'3d'
,
markersize
=
3
,
title
=
''
):
rgb
=
rgb
.
detach
()
rgb
-=
rgb
.
min
()
rgb
/=
rgb
.
max
()
/
255
+
1e-10
rgb
=
rgb
.
floor
().
cpu
().
numpy
()
vis
.
scatter
(
xyz
,
opts
=
{
'markersize'
:
markersize
,
'markercolor'
:
rgb
},
xyz
.
detach
().
cpu
().
numpy
()
,
opts
=
{
'markersize'
:
markersize
,
'markercolor'
:
rgb
,
'title'
:
title
},
win
=
win
)
def
ply_scatter
(
name
,
xyz
,
rgb
):
rgb
=
rgb
.
detach
()
rgb
-=
rgb
.
min
()
rgb
/=
rgb
.
max
()
/
255
+
1e-10
rgb
=
rgb
.
floor
().
cpu
().
numpy
()
with
open
(
name
+
'.ply'
,
'w'
)
as
f
:
print
(
"""ply
format ascii 1.0
element vertex %d
property float x
property float y
property float z
property uchar red
property uchar green
property uchar blue
end_header"""
%
(
xyz
.
size
(
0
)),
file
=
f
)
for
(
x
,
y
,
z
),(
r
,
g
,
b
)
in
zip
(
xyz
,
rgb
):
print
(
'%d %d %d %d %d %d'
%
(
x
,
y
,
z
,
r
,
g
,
b
),
file
=
f
)
class
VerboseIdentity
(
torch
.
nn
.
Module
):
def
forward
(
self
,
x
):
print
(
x
)
return
x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment