Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
66986767
Commit
66986767
authored
Jul 31, 2018
by
Benjamin Thomas Graham
Browse files
fixes
parent
edf89af3
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
425 additions
and
411 deletions
+425
-411
sparseconvnet/SCN/CPU/IOLayers.cpp
sparseconvnet/SCN/CPU/IOLayers.cpp
+8
-8
sparseconvnet/SCN/CUDA/Convolution.cpp
sparseconvnet/SCN/CUDA/Convolution.cpp
+6
-7
sparseconvnet/SCN/CUDA/Convolution.cu
sparseconvnet/SCN/CUDA/Convolution.cu
+181
-182
sparseconvnet/SCN/CUDA/Deconvolution.cpp
sparseconvnet/SCN/CUDA/Deconvolution.cpp
+1
-1
sparseconvnet/SCN/CUDA/Deconvolution.cu
sparseconvnet/SCN/CUDA/Deconvolution.cu
+161
-157
sparseconvnet/SCN/cuda.cu
sparseconvnet/SCN/cuda.cu
+55
-55
sparseconvnet/__init__.py
sparseconvnet/__init__.py
+1
-1
sparseconvnet/activations.py
sparseconvnet/activations.py
+12
-0
No files found.
sparseconvnet/SCN/CPU/IOLayers.cpp
View file @
66986767
...
...
@@ -17,14 +17,14 @@ void InputLayer_ForwardPass(T *input_features, T *output_features, Int nRows,
for
(
row
=
0
;
row
<
nRows
;
row
++
)
{
auto
nActive
=
rules
[
0
];
T
multiplier
=
(
average
and
nActive
>
0
)
?
(
T
)
1
/
nActive
:
(
T
)
1
;
auto
out_f
=
output_features
+
row
*
nPlanes
;
auto
r
=
rules
+
row
*
(
1
+
maxActive
);
for
(
Int
i
=
1
;
i
<=
nActive
;
++
i
)
{
auto
in_f
=
input_features
+
nPlanes
*
rules
[
i
]
;
auto
in_f
=
input_features
+
r
[
i
]
*
nPlanes
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
{
out
put_features
[
plane
]
+=
multiplier
*
in_f
[
plane
];
out
_f
[
plane
]
+=
multiplier
*
in_f
[
plane
];
}
}
output_features
+=
nPlanes
;
rules
+=
1
+
maxActive
;
}
}
template
<
typename
T
>
...
...
@@ -36,13 +36,13 @@ void InputLayer_BackwardPass(T *d_input_features, T *d_output_features,
for
(
row
=
0
;
row
<
nRows
;
row
++
)
{
auto
nActive
=
rules
[
0
];
T
multiplier
=
(
average
and
nActive
>
0
)
?
(
T
)
1
/
nActive
:
(
T
)
1
;
auto
d_out_f
=
d_output_features
+
row
*
nPlanes
;
auto
r
=
rules
+
row
*
(
1
+
maxActive
);
for
(
Int
i
=
1
;
i
<=
nActive
;
++
i
)
{
auto
d_in_f
=
d_input_features
+
nPlanes
*
rules
[
i
]
;
auto
d_in_f
=
d_input_features
+
r
[
i
]
*
nPlanes
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
d_in_f
[
plane
]
+=
multiplier
*
d_out
put_features
[
plane
];
d_in_f
[
plane
]
+=
multiplier
*
d_out
_f
[
plane
];
}
d_output_features
+=
nPlanes
;
rules
+=
1
+
maxActive
;
}
}
...
...
sparseconvnet/SCN/CUDA/Convolution.cpp
View file @
66986767
...
...
@@ -5,10 +5,9 @@
// LICENSE file in the root directory of this source tree.
template
<
typename
T
>
void
Convolution_fp_bias
(
T
*
o
f
,
T
*
b
,
Int
nPlanes
,
Int
nActive
Out
);
void
Convolution_fp_bias
(
T
*
o
F
,
T
*
b
,
Int
nPlanes
,
Int
nActive
);
template
<
typename
T
>
void
Convolution_bp_bias
(
T
*
matrix
,
T
*
target
,
Int
nRows
,
Int
nColumns
,
Int
nCOLUMNS
);
void
Convolution_bp_bias
(
T
*
d_oF
,
T
*
d_b
,
Int
nPlanes
,
Int
nActive
);
template
<
typename
T
>
double
dConvolution_forward2
(
T
*
inFeatures
,
T
*
outFeatures
,
T
*
w
,
RuleBook
_rules
,
Int
input_nPlanes
,
...
...
@@ -84,7 +83,7 @@ void cuda_Convolution_backward(
if
(
d_bias
.
numel
())
{
auto
db
=
d_bias
.
data
<
T
>
();
Convolution_bp_bias
(
doF
,
db
,
op
,
op
,
nActiveOut
);
Convolution_bp_bias
(
doF
,
db
,
op
,
nActiveOut
);
}
}
}
...
...
@@ -147,7 +146,7 @@ void cuda_SubmanifoldConvolution_backward(
if
(
d_bias
.
numel
())
{
auto
db
=
d_bias
.
data
<
T
>
();
Convolution_bp_bias
(
doF
,
db
,
op
,
op
,
nActive
);
Convolution_bp_bias
(
doF
,
db
,
op
,
nActive
);
}
}
}
...
...
@@ -216,7 +215,7 @@ void cuda_FullConvolution_backward(
if
(
d_bias
.
numel
())
{
auto
db
=
d_bias
.
data
<
T
>
();
Convolution_bp_bias
(
doF
,
db
,
op
,
op
,
nActiveOut
);
Convolution_bp_bias
(
doF
,
db
,
op
,
nActiveOut
);
}
}
}
...
...
@@ -283,7 +282,7 @@ void cuda_RandomizedStrideConvolution_backward(
if
(
d_bias
.
numel
())
{
auto
db
=
d_bias
.
data
<
T
>
();
Convolution_bp_bias
(
doF
,
db
,
op
,
op
,
nActiveOut
);
Convolution_bp_bias
(
doF
,
db
,
op
,
nActiveOut
);
}
}
}
sparseconvnet/SCN/CUDA/Convolution.cu
View file @
66986767
...
...
@@ -5,6 +5,7 @@
// LICENSE file in the root directory of this source tree.
#include "RuleBookIterator.h"
#define TACC double
template
<
typename
T
>
__global__
void
Convolution_fp_bias_
(
T
*
output_features
,
T
*
bias
,
Int
nPlanes
,
...
...
@@ -30,24 +31,21 @@ void Convolution_fp_bias(T *oF, T *b, Int nPlanes, Int nActive) {
}
// Bias-gradient reduction kernel.
// Each thread owns one output plane n = blockIdx.x * 32 + threadIdx.x and
// sums the gradient column d_oF[row * nPlanes + n] over the active output
// rows, striding by gridDim.y so the y-blocks partition the rows.
// Partial sums from the y-blocks are combined with atomicAdd into d_b[n].
// The accumulator uses TACC (double, defined at the top of this file) to
// reduce rounding error when T is float.
// NOTE(review): assumes the launch config makes n < nPlanes for every
// thread — the launcher below sizes the grid accordingly; confirm before
// reusing with other launch shapes.
template <typename T>
__global__ void Convolution_bp_bias_(T *d_oF, T *d_b, Int nPlanes,
                                     Int nActive) {
  Int n = blockIdx.x * 32 + threadIdx.x; // plane handled by this thread
  d_oF += n;                             // point at column n
  TACC t = 0;                            // high-precision partial sum
  for (Int row = blockIdx.y; row < nActive; row += gridDim.y)
    t += d_oF[row * nPlanes];
  atomicAdd(&d_b[n], t); // merge partial sums across y-blocks
}
// Host-side launcher for the bias-gradient reduction.
// d_oF: gradient w.r.t. output features, laid out as nActive rows of
//       nPlanes contiguous values; d_b: bias gradient, length nPlanes.
// Planes are processed 32 at a time (one warp-sized x-block per group of
// 32 planes, 32 y-blocks sharing the row loop). A second launch covers the
// nPlanes % 32 remainder with a partial-width block, offsetting both
// pointers by o so thread indices stay in range.
template <typename T>
void Convolution_bp_bias(T *d_oF, T *d_b, Int nPlanes, Int nActive) {
  if (nPlanes / 32 > 0)
    Convolution_bp_bias_<<<dim3(nPlanes / 32, 32), 32>>>(d_oF, d_b, nPlanes,
                                                         nActive);
  if (nPlanes % 32 > 0) {
    Int o = nPlanes / 32 * 32; // first plane of the remainder group
    Convolution_bp_bias_<<<dim3(1, 32), nPlanes - o>>>(d_oF + o, d_b + o,
                                                       nPlanes, nActive);
  }
}
...
...
@@ -70,7 +68,7 @@ dConvolution_KMxKN_forwardA(T *inFeatures, T *outFeatures, T *w, Int *rules,
outFeatures
+=
n
*
K
;
w
+=
n
*
K
;
T
O
[
V
];
T
ACC
O
[
V
];
__shared__
T
W
[
K
][
K
];
__shared__
T
I
[
K
][
K
];
Int
R0
[
V
];
...
...
@@ -138,7 +136,7 @@ dConvolution_KMxKN_forwardB(T *inFeatures, T *outFeatures, T *w, Int *rules,
outFeatures
+=
n
*
K
;
w
+=
n
*
K
;
T
O
[
V
];
T
ACC
O
[
V
];
__shared__
T
W
[
K
][
K
];
__shared__
T
I
[
K
][
K
];
Int
R0
[
V
];
...
...
@@ -253,8 +251,8 @@ dConvolution_KMxKN_backward_dW_A(T *inFeatures, T *dInFeatures, T *dOutFeatures,
w
+=
m
*
K
*
output_nPlanes
;
dw
+=
m
*
K
*
output_nPlanes
;
T
dI
[
V
];
T
dW
[
V
];
T
ACC
dI
[
V
];
T
ACC
dW
[
V
];
__shared__
T
I
[
K
][
K
];
__shared__
T
dO
[
K
][
K
];
__shared__
T
W
[
K
][
K
];
...
...
@@ -330,8 +328,8 @@ dConvolution_KMxKN_backward_dW_B(T *inFeatures, T *dInFeatures, T *dOutFeatures,
w
+=
m
*
K
*
output_nPlanes
;
dw
+=
m
*
K
*
output_nPlanes
;
T
dI
[
V
];
T
dW
[
V
];
T
ACC
dI
[
V
];
T
ACC
dW
[
V
];
__shared__
T
I
[
K
][
K
];
__shared__
T
dO
[
K
][
K
];
__shared__
T
W
[
K
][
K
];
...
...
@@ -449,7 +447,7 @@ dConvolution_KMxKN_forward2(T *inFeatures, T *outFeatures, T *w, Int *rules,
w
+=
n
*
K
;
Int
KO
=
min
(
K
,
output_nPlanes
-
K
*
n
);
T
O
[
V
];
T
ACC
O
[
V
];
__shared__
T
W
[
K
][
K
];
__shared__
T
I
[
K
][
K
];
__shared__
Int
R
[
K
*
2
];
...
...
@@ -525,8 +523,8 @@ dConvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures, T *dOutFeatures,
dw
+=
m
*
K
*
output_nPlanes
;
Int
KI
=
min
(
K
,
input_nPlanes
-
K
*
m
);
T
dI
[
V
];
T
dW
[
V
];
T
ACC
dI
[
V
];
T
ACC
dW
[
V
];
__shared__
T
I
[
K
][
K
];
__shared__
T
dO
[
K
][
K
];
__shared__
T
W
[
K
][
K
];
...
...
@@ -650,3 +648,4 @@ void dConvolution_backward_dW2(T *inFeatures, T *dInFeatures, T *dOutFeatures,
,
w
+=
c
;
dw
+=
c
;)
}
}
#undef TACC
\ No newline at end of file
sparseconvnet/SCN/CUDA/Deconvolution.cpp
View file @
66986767
...
...
@@ -78,7 +78,7 @@ void cuda_Deconvolution_backward(
dDeconvolution_backward_dW2
<
T
>
(
iF
,
diF
,
doF
,
w
,
dw
,
_rules
,
ip
,
ip
,
op
,
op
);
if
(
d_bias
.
numel
())
{
auto
db
=
d_bias
.
data
<
T
>
();
Convolution_bp_bias
(
doF
,
db
,
op
,
op
,
nActiveOut
);
Convolution_bp_bias
(
doF
,
db
,
op
,
nActiveOut
);
}
}
}
sparseconvnet/SCN/CUDA/Deconvolution.cu
View file @
66986767
...
...
@@ -4,6 +4,8 @@
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#define TACC double
template
<
typename
T
,
Int
K
,
Int
V
>
__global__
void
dDeconvolution_KMxKN_forwardA
(
T
*
inFeatures
,
T
*
outFeatures
,
T
*
w
,
Int
*
rules
,
...
...
@@ -23,7 +25,7 @@ dDeconvolution_KMxKN_forwardA(T *inFeatures, T *outFeatures, T *w, Int *rules,
outFeatures
+=
n
*
K
;
w
+=
n
*
K
;
T
O
[
V
];
T
ACC
O
[
V
];
__shared__
T
W
[
K
][
K
];
__shared__
T
I
[
K
][
K
];
Int
R0
[
V
];
...
...
@@ -91,7 +93,7 @@ dDeconvolution_KMxKN_forwardB(T *inFeatures, T *outFeatures, T *w, Int *rules,
outFeatures
+=
n
*
K
;
w
+=
n
*
K
;
T
O
[
V
];
T
ACC
O
[
V
];
__shared__
T
W
[
K
][
K
];
__shared__
T
I
[
K
][
K
];
Int
R0
[
V
];
...
...
@@ -205,8 +207,8 @@ __global__ void dDeconvolution_KMxKN_backward_dW_A(
w
+=
m
*
K
*
output_nPlanes
;
dw
+=
m
*
K
*
output_nPlanes
;
T
dI
[
V
];
T
dW
[
V
];
T
ACC
dI
[
V
];
T
ACC
dW
[
V
];
__shared__
T
I
[
K
][
K
];
__shared__
T
dO
[
K
][
K
];
__shared__
T
W
[
K
][
K
];
...
...
@@ -281,8 +283,8 @@ __global__ void dDeconvolution_KMxKN_backward_dW_B(
w
+=
m
*
K
*
output_nPlanes
;
dw
+=
m
*
K
*
output_nPlanes
;
T
dI
[
V
];
T
dW
[
V
];
T
ACC
dI
[
V
];
T
ACC
dW
[
V
];
__shared__
T
I
[
K
][
K
];
__shared__
T
dO
[
K
][
K
];
__shared__
T
W
[
K
][
K
];
...
...
@@ -400,7 +402,7 @@ dDeconvolution_KMxKN_forward2(T *inFeatures, T *outFeatures, T *w, Int *rules,
w
+=
n
*
K
;
Int
KO
=
min
(
K
,
output_nPlanes
-
K
*
n
);
T
O
[
V
];
T
ACC
O
[
V
];
__shared__
T
W
[
K
][
K
];
__shared__
T
I
[
K
][
K
];
__shared__
Int
R
[
K
*
2
];
...
...
@@ -476,8 +478,8 @@ dDeconvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures,
dw
+=
m
*
K
*
output_nPlanes
;
Int
KI
=
min
(
K
,
input_nPlanes
-
K
*
m
);
T
dI
[
V
];
T
dW
[
V
];
T
ACC
dI
[
V
];
T
ACC
dW
[
V
];
__shared__
T
I
[
K
][
K
];
__shared__
T
dO
[
K
][
K
];
__shared__
T
W
[
K
][
K
];
...
...
@@ -601,3 +603,5 @@ void dDeconvolution_backward_dW2(T *inFeatures, T *dInFeatures, T *dOutFeatures,
,
w
+=
c
;
dw
+=
c
;)
}
}
#undef TACC
\ No newline at end of file
sparseconvnet/SCN/cuda.cu
View file @
66986767
...
...
@@ -43,10 +43,10 @@ template void cuda_AveragePooling_BackwardPass<float>(
float
*
d_input_features
,
float
*
d_output_features
,
Int
nPlanes
,
Int
input_stride
,
Int
output_stride
,
RuleBook
_rules
,
Int
filterVolume
);
template
void
Convolution_fp_bias
<
float
>(
float
*
o
f
,
float
*
b
,
Int
op
,
template
void
Convolution_fp_bias
<
float
>(
float
*
o
F
,
float
*
b
,
Int
nPlanes
,
Int
nActive
);
template
void
Convolution_bp_bias
<
float
>(
float
*
matrix
,
float
*
target
,
Int
nRows
,
Int
nColumn
s
,
Int
n
COLUMNS
);
template
void
Convolution_bp_bias
<
float
>(
float
*
d_oF
,
float
*
d_b
,
Int
nPlane
s
,
Int
n
Active
);
template
double
dConvolution_forward2
<
float
>(
float
*
inFeatures
,
float
*
outFeatures
,
float
*
w
,
RuleBook
_rules
,
Int
input_nPlanes
,
Int
input_stride
,
Int
output_nPlanes
,
Int
output_stride
);
...
...
sparseconvnet/__init__.py
View file @
66986767
...
...
@@ -6,7 +6,7 @@
forward_pass_multiplyAdd_count
=
0
forward_pass_hidden_states
=
0
from
.activations
import
Tanh
,
Sigmoid
,
ReLU
,
ELU
,
SELU
,
BatchNormELU
from
.activations
import
Tanh
,
Sigmoid
,
ReLU
,
LeakyReLU
,
ELU
,
SELU
,
BatchNormELU
from
.averagePooling
import
AveragePooling
from
.batchNormalization
import
BatchNormalization
,
BatchNormReLU
,
BatchNormLeakyReLU
from
.classificationTrainValidate
import
ClassificationTrainValidate
...
...
sparseconvnet/activations.py
View file @
66986767
...
...
@@ -22,6 +22,18 @@ class Sigmoid(Module):
return
output
class LeakyReLU(Module):
    """Leaky ReLU activation for SparseConvNetTensor inputs.

    Applies F.leaky_relu (presumably torch.nn.functional.leaky_relu —
    confirm against this module's imports) to the tensor's dense feature
    matrix with negative slope ``leak``, while passing the sparse metadata
    and spatial size through unchanged.
    """

    def __init__(self, leak=1 / 3):
        # leak: negative-slope coefficient; default 1/3 matches the
        # original commit.
        Module.__init__(self)
        self.leak = leak

    def forward(self, input):
        output = SparseConvNetTensor()
        output.features = F.leaky_relu(input.features, self.leak)
        # Sparsity pattern is unaffected by a pointwise activation, so the
        # metadata and spatial size are shared with the input.
        output.metadata = input.metadata
        output.spatial_size = input.spatial_size
        return output
class
Tanh
(
Module
):
def
forward
(
self
,
input
):
output
=
SparseConvNetTensor
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment