Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
5f0860fc
Commit
5f0860fc
authored
Sep 13, 2017
by
Benjamin Thomas Graham
Browse files
DenseToSparse, tidying
parent
6de372c3
Changes
38
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
544 additions
and
562 deletions
+544
-562
PyTorch/setup.py
PyTorch/setup.py
+1
-0
PyTorch/sparseconvnet/SCN/generic/CPU/BatchNormalization.cpp
PyTorch/sparseconvnet/SCN/generic/CPU/BatchNormalization.cpp
+41
-33
PyTorch/sparseconvnet/SCN/generic/CPU/Convolution.cpp
PyTorch/sparseconvnet/SCN/generic/CPU/Convolution.cpp
+53
-47
PyTorch/sparseconvnet/SCN/generic/CPU/SparseToDense.cpp
PyTorch/sparseconvnet/SCN/generic/CPU/SparseToDense.cpp
+33
-34
PyTorch/sparseconvnet/SCN/generic/CPU/SparseToDense.h
PyTorch/sparseconvnet/SCN/generic/CPU/SparseToDense.h
+17
-18
PyTorch/sparseconvnet/SCN/generic/GPU/AffineReluTrivialConvolution.cu
...seconvnet/SCN/generic/GPU/AffineReluTrivialConvolution.cu
+0
-1
PyTorch/sparseconvnet/SCN/generic/GPU/AffineReluTrivialConvolution.h
...rseconvnet/SCN/generic/GPU/AffineReluTrivialConvolution.h
+67
-156
PyTorch/sparseconvnet/SCN/generic/GPU/BatchNormalization.cu
PyTorch/sparseconvnet/SCN/generic/GPU/BatchNormalization.cu
+24
-22
PyTorch/sparseconvnet/SCN/generic/GPU/Convolution.cu
PyTorch/sparseconvnet/SCN/generic/GPU/Convolution.cu
+78
-70
PyTorch/sparseconvnet/SCN/generic/GPU/Convolution.h
PyTorch/sparseconvnet/SCN/generic/GPU/Convolution.h
+20
-9
PyTorch/sparseconvnet/SCN/generic/GPU/Deconvolution.h
PyTorch/sparseconvnet/SCN/generic/GPU/Deconvolution.h
+20
-9
PyTorch/sparseconvnet/SCN/generic/GPU/SparseToDense.cu
PyTorch/sparseconvnet/SCN/generic/GPU/SparseToDense.cu
+31
-27
PyTorch/sparseconvnet/SCN/generic/GPU/SparseToDense.h
PyTorch/sparseconvnet/SCN/generic/GPU/SparseToDense.h
+19
-18
PyTorch/sparseconvnet/SCN/generic/GPU/THGenerateCudaFloatTypes.h
.../sparseconvnet/SCN/generic/GPU/THGenerateCudaFloatTypes.h
+18
-0
PyTorch/sparseconvnet/SCN/generic/Geometry/ConvolutionRules.h
...rch/sparseconvnet/SCN/generic/Geometry/ConvolutionRules.h
+28
-34
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
+4
-5
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.h
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.h
+12
-1
PyTorch/sparseconvnet/SCN/generic/Geometry/ValidConvolutionRules.h
...parseconvnet/SCN/generic/Geometry/ValidConvolutionRules.h
+5
-7
PyTorch/sparseconvnet/SCN/header_cpu.h
PyTorch/sparseconvnet/SCN/header_cpu.h
+62
-61
PyTorch/sparseconvnet/SCN/header_gpu.h
PyTorch/sparseconvnet/SCN/header_gpu.h
+11
-10
No files found.
PyTorch/setup.py
View file @
5f0860fc
...
...
@@ -35,6 +35,7 @@ if torch.cuda.is_available():
'sparseconvnet/SCN/header_cpu.h'
,
'sparseconvnet/SCN/header_gpu.h'
],
sources
=
[],
include_dirs
=
[
os
.
path
.
expandvars
(
'$CUDA_HOME'
)
+
'/include'
],
extra_objects
=
[
this_dir
+
'/sparseconvnet/SCN/init.cu.o'
],
...
...
PyTorch/sparseconvnet/SCN/generic/CPU/BatchNormalization.cpp
View file @
5f0860fc
...
...
@@ -14,18 +14,20 @@ extern "C" void scn_R_(BatchNormalization_updateOutput)(
THTensor
*
saveInvStd
,
THTensor
*
runningMean
,
THTensor
*
runningVar
,
THTensor
*
weight
,
THTensor
*
bias
,
real
eps
,
real
momentum
,
bool
train
,
real
leakiness
)
{
THTensor_
(
resizeAs
)(
output_features
,
input_features
);
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
BatchNormalization_ForwardPass
<
real
>
(
THTensor_
(
data
)(
input_features
),
THTensor_
(
data
)(
output_features
),
nPlanes
,
input_stride
,
output_stride
,
nActive
,
THTensor_
(
data
)(
saveMean
),
THTensor_
(
data
)(
saveInvStd
),
THTensor_
(
data
)(
runningMean
),
THTensor_
(
data
)(
runningVar
),
THOptionalTensorData
(
weight
),
THOptionalTensorData
(
bias
),
eps
,
momentum
,
train
,
leakiness
);
if
(
input_features
->
nDimension
==
2
)
{
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
BatchNormalization_ForwardPass
<
real
>
(
THTensor_
(
data
)(
input_features
),
THTensor_
(
data
)(
output_features
),
nPlanes
,
input_stride
,
output_stride
,
nActive
,
THTensor_
(
data
)(
saveMean
),
THTensor_
(
data
)(
saveInvStd
),
THTensor_
(
data
)(
runningMean
),
THTensor_
(
data
)(
runningVar
),
THOptionalTensorData
(
weight
),
THOptionalTensorData
(
bias
),
eps
,
momentum
,
train
,
leakiness
);
}
}
extern
"C"
void
scn_R_
(
BatchNormalizationInTensor_updateOutput
)(
...
...
@@ -34,17 +36,20 @@ extern "C" void scn_R_(BatchNormalizationInTensor_updateOutput)(
THTensor
*
weight
,
THTensor
*
bias
,
real
eps
,
real
momentum
,
bool
train
,
real
leakiness
)
{
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
if
(
input_features
->
nDimension
==
2
)
{
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
BatchNormalization_ForwardPass
<
real
>
(
THTensor_
(
data
)(
input_features
),
THTensor_
(
data
)(
output_features
),
nPlanes
,
input_stride
,
output_stride
,
nActive
,
THTensor_
(
data
)(
saveMean
),
THTensor_
(
data
)(
saveInvStd
),
THTensor_
(
data
)(
runningMean
),
THTensor_
(
data
)(
runningVar
),
THOptionalTensorData
(
weight
),
THOptionalTensorData
(
bias
),
eps
,
momentum
,
train
,
leakiness
);
BatchNormalization_ForwardPass
<
real
>
(
THTensor_
(
data
)(
input_features
),
THTensor_
(
data
)(
output_features
),
nPlanes
,
input_stride
,
output_stride
,
nActive
,
THTensor_
(
data
)(
saveMean
),
THTensor_
(
data
)(
saveInvStd
),
THTensor_
(
data
)(
runningMean
),
THTensor_
(
data
)(
runningVar
),
THOptionalTensorData
(
weight
),
THOptionalTensorData
(
bias
),
eps
,
momentum
,
train
,
leakiness
);
}
}
extern
"C"
void
scn_R_
(
BatchNormalization_backward
)(
...
...
@@ -55,17 +60,20 @@ extern "C" void scn_R_(BatchNormalization_backward)(
real
leakiness
)
{
THTensor_
(
resizeAs
)(
d_input_features
,
input_features
);
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
BatchNormalization_BackwardPass
<
real
>
(
THTensor_
(
data
)(
input_features
),
THTensor_
(
data
)(
d_input_features
),
THTensor_
(
data
)(
output_features
),
THTensor_
(
data
)(
d_output_features
),
nPlanes
,
input_stride
,
output_stride
,
nActive
,
THTensor_
(
data
)(
saveMean
),
THTensor_
(
data
)(
saveInvStd
),
THTensor_
(
data
)(
runningMean
),
THTensor_
(
data
)(
runningVar
),
THOptionalTensorData
(
weight
),
THOptionalTensorData
(
bias
),
THOptionalTensorData
(
d_weight
),
THOptionalTensorData
(
d_bias
),
leakiness
);
if
(
input_features
->
nDimension
==
2
)
{
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
BatchNormalization_BackwardPass
<
real
>
(
THTensor_
(
data
)(
input_features
),
THTensor_
(
data
)(
d_input_features
),
THTensor_
(
data
)(
output_features
),
THTensor_
(
data
)(
d_output_features
),
nPlanes
,
input_stride
,
output_stride
,
nActive
,
THTensor_
(
data
)(
saveMean
),
THTensor_
(
data
)(
saveInvStd
),
THTensor_
(
data
)(
runningMean
),
THTensor_
(
data
)(
runningVar
),
THOptionalTensorData
(
weight
),
THOptionalTensorData
(
bias
),
THOptionalTensorData
(
d_weight
),
THOptionalTensorData
(
d_bias
),
leakiness
);
}
}
#endif
PyTorch/sparseconvnet/SCN/generic/CPU/Convolution.cpp
View file @
5f0860fc
...
...
@@ -23,17 +23,19 @@ extern "C" double scn_DR_(Convolution_updateOutput)(
if
(
not
bias
)
THTensor_
(
zero
)(
output_features
);
auto
iF
=
THTensor_
(
data
)(
input_features
);
auto
oF
=
THTensor_
(
data
)(
output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
output_features
->
size
[
1
];
auto
w
=
THTensor_
(
data
)(
weight
);
auto
b
=
THOptionalTensorData
(
bias
);
Convolution_ForwardPass
(
iF
,
ip
,
ip
,
oF
,
op
,
op
,
w
,
b
,
_rules
,
nActive
,
THBlas_
(
gemm
));
double
flops
=
0
;
for
(
auto
&
r
:
_rules
)
flops
+=
r
.
size
()
/
2
*
ip
*
op
;
if
(
nActive
)
{
auto
iF
=
THTensor_
(
data
)(
input_features
);
auto
oF
=
THTensor_
(
data
)(
output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
output_features
->
size
[
1
];
auto
w
=
THTensor_
(
data
)(
weight
);
auto
b
=
THOptionalTensorData
(
bias
);
Convolution_ForwardPass
(
iF
,
ip
,
ip
,
oF
,
op
,
op
,
w
,
b
,
_rules
,
nActive
,
THBlas_
(
gemm
));
for
(
auto
&
r
:
_rules
)
flops
+=
r
.
size
()
/
2
*
ip
*
op
;
}
return
flops
;
}
...
...
@@ -51,17 +53,19 @@ extern "C" void scn_DR_(Convolution_backward)(
THTensor_
(
resizeAs
)(
d_input_features
,
input_features
);
THTensor_
(
zero
)(
d_input_features
);
auto
iF
=
THTensor_
(
data
)(
input_features
);
auto
diF
=
THTensor_
(
data
)(
d_input_features
);
auto
doF
=
THTensor_
(
data
)(
d_output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
d_output_features
->
size
[
1
];
auto
w
=
THTensor_
(
data
)(
weight
);
auto
dw
=
THTensor_
(
data
)(
d_weight
);
auto
db
=
THOptionalTensorData
(
d_bias
);
Convolution_BackwardPass
(
iF
,
diF
,
ip
,
ip
,
doF
,
op
,
op
,
w
,
dw
,
db
,
_rules
,
nActive
,
THBlas_
(
gemm
));
if
(
nActive
)
{
auto
iF
=
THTensor_
(
data
)(
input_features
);
auto
diF
=
THTensor_
(
data
)(
d_input_features
);
auto
doF
=
THTensor_
(
data
)(
d_output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
d_output_features
->
size
[
1
];
auto
w
=
THTensor_
(
data
)(
weight
);
auto
dw
=
THTensor_
(
data
)(
d_weight
);
auto
db
=
THOptionalTensorData
(
d_bias
);
Convolution_BackwardPass
(
iF
,
diF
,
ip
,
ip
,
doF
,
op
,
op
,
w
,
dw
,
db
,
_rules
,
nActive
,
THBlas_
(
gemm
));
}
}
extern
"C"
double
scn_DR_
(
ValidConvolution_updateOutput
)(
...
...
@@ -71,24 +75,25 @@ extern "C" double scn_DR_(ValidConvolution_updateOutput)(
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
auto
_rules
=
_m
.
getValidRuleBook
(
inputSize
,
filterSize
,
true
);
uInt
nActive
=
input_features
->
size
[
0
]
;
uInt
nActive
=
_m
.
getNActive
(
inputSize
)
;
THTensor_
(
resize2d
)(
output_features
,
nActive
,
weight
->
size
[
1
]);
if
(
not
bias
)
THTensor_
(
zero
)(
output_features
);
auto
iF
=
THTensor_
(
data
)(
input_features
);
auto
oF
=
THTensor_
(
data
)(
output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
output_features
->
size
[
1
];
auto
w
=
THTensor_
(
data
)(
weight
);
auto
b
=
THOptionalTensorData
(
bias
);
Convolution_ForwardPass
(
iF
,
ip
,
ip
,
oF
,
op
,
op
,
w
,
b
,
_rules
,
nActive
,
THBlas_
(
gemm
));
double
flops
=
0
;
for
(
auto
&
r
:
_rules
)
flops
+=
r
.
size
()
/
2
*
ip
*
op
;
if
(
nActive
)
{
auto
iF
=
THTensor_
(
data
)(
input_features
);
auto
oF
=
THTensor_
(
data
)(
output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
output_features
->
size
[
1
];
auto
w
=
THTensor_
(
data
)(
weight
);
auto
b
=
THOptionalTensorData
(
bias
);
Convolution_ForwardPass
(
iF
,
ip
,
ip
,
oF
,
op
,
op
,
w
,
b
,
_rules
,
nActive
,
THBlas_
(
gemm
));
for
(
auto
&
r
:
_rules
)
flops
+=
r
.
size
()
/
2
*
ip
*
op
;
}
return
flops
;
}
...
...
@@ -100,21 +105,22 @@ extern "C" void scn_DR_(ValidConvolution_backward)(
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
auto
_rules
=
_m
.
getValidRuleBook
(
inputSize
,
filterSize
,
true
);
uInt
nActive
=
input_features
->
size
[
0
]
;
uInt
nActive
=
_m
.
getNActive
(
inputSize
)
;
THTensor_
(
resizeAs
)(
d_input_features
,
input_features
);
THTensor_
(
zero
)(
d_input_features
);
auto
iF
=
THTensor_
(
data
)(
input_features
);
auto
diF
=
THTensor_
(
data
)(
d_input_features
);
auto
doF
=
THTensor_
(
data
)(
d_output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
d_output_features
->
size
[
1
];
auto
w
=
THTensor_
(
data
)(
weight
);
auto
dw
=
THTensor_
(
data
)(
d_weight
);
auto
db
=
THOptionalTensorData
(
d_bias
);
Convolution_BackwardPass
(
iF
,
diF
,
ip
,
ip
,
doF
,
op
,
op
,
w
,
dw
,
db
,
_rules
,
nActive
,
THBlas_
(
gemm
));
if
(
nActive
)
{
auto
iF
=
THTensor_
(
data
)(
input_features
);
auto
diF
=
THTensor_
(
data
)(
d_input_features
);
auto
doF
=
THTensor_
(
data
)(
d_output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
d_output_features
->
size
[
1
];
auto
w
=
THTensor_
(
data
)(
weight
);
auto
dw
=
THTensor_
(
data
)(
d_weight
);
auto
db
=
THOptionalTensorData
(
d_bias
);
Convolution_BackwardPass
(
iF
,
diF
,
ip
,
ip
,
doF
,
op
,
op
,
w
,
dw
,
db
,
_rules
,
nActive
,
THBlas_
(
gemm
));
}
}
#endif
PyTorch/sparseconvnet/SCN/generic/CPU/SparseToDense.cpp
View file @
5f0860fc
...
...
@@ -9,34 +9,32 @@
#else
#include "SparseToDense.h"
extern
"C"
void
scn_DR_
(
SparseToDense_updateOutput
)(
THLongTensor
*
inputSize
,
void
**
m
,
THTensor
*
in
put_features
,
THTensor
*
output_features
,
void
*
rulesBuffer
)
{
extern
"C"
void
scn_DR_
(
SparseToDense_updateOutput
)(
THLongTensor
*
inputSize
,
void
**
m
,
THTensor
*
input_features
,
THTensor
*
out
put_features
,
void
*
rulesBuffer
,
long
nPlanes
)
{
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
{
{
long
sz
[
Dimension
+
2
];
sz
[
0
]
=
_m
.
inputSGs
->
size
();
sz
[
1
]
=
input_features
->
size
[
1
];
for
(
int
i
=
0
;
i
<
Dimension
;
i
++
)
{
auto
x
=
THLongTensor_data
(
inputSize
)[
i
];
sz
[
i
+
2
]
=
x
;
}
sz
[
0
]
=
_m
.
grids
.
begin
()
->
second
.
size
();
sz
[
1
]
=
nPlanes
;
// input_features->size[1];
std
::
memcpy
(
sz
+
2
,
THLongTensor_data
(
inputSize
),
sizeof
(
long
)
*
Dimension
);
THTensor_
(
resizeNd
)(
output_features
,
Dimension
+
2
,
sz
,
NULL
);
THTensor_
(
zero
)(
output_features
);
}
auto
_rules
=
_m
.
getSparseToDenseRuleBook
(
inputSize
,
true
);
auto
spatialVolume
=
_rules
.
size
();
uInt
nPlanes
=
input_features
->
size
[
1
];
auto
iF
=
THTensor_
(
data
)(
input_features
);
auto
oF
=
THTensor_
(
data
)(
output_features
);
for
(
auto
&
r
:
_rules
)
{
uInt
nHot
=
r
.
size
()
/
2
;
SparseToDense_ForwardPass
<
real
>
(
iF
,
oF
,
nPlanes
,
spatialVolume
,
&
r
[
0
],
nHot
);
oF
++
;
if
(
input_features
->
nDimension
==
2
)
{
auto
_rules
=
_m
.
getSparseToDenseRuleBook
(
inputSize
,
true
);
uInt
nPlanes
=
input_features
->
size
[
1
];
auto
iF
=
THTensor_
(
data
)(
input_features
);
auto
oF
=
THTensor_
(
data
)(
output_features
);
long
spatialVolume
=
THLongTensor_prodall
(
inputSize
);
for
(
auto
&
r
:
_rules
)
{
uInt
nHot
=
r
.
size
()
/
2
;
SparseToDense_ForwardPass
<
real
>
(
iF
,
oF
,
nPlanes
,
spatialVolume
,
&
r
[
0
],
nHot
);
oF
+=
nPlanes
*
spatialVolume
;
}
}
}
extern
"C"
void
scn_DR_
(
SparseToDense_updateGradInput
)(
...
...
@@ -44,21 +42,22 @@ extern "C" void scn_DR_(SparseToDense_updateGradInput)(
THTensor
*
d_input_features
,
THTensor
*
d_output_features
,
void
*
rulesBuffer
)
{
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
THTensor_
(
resizeAs
)(
d_input_features
,
input_features
);
THTensor_
(
zero
)(
d_input_features
);
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
auto
_rules
=
_m
.
getSparseToDenseRuleBook
(
inputSize
,
true
);
auto
spatialVolume
=
_rules
.
size
();
uInt
nPlanes
=
d_input_features
->
size
[
1
];
auto
diF
=
THTensor_
(
data
)(
d_input_features
);
auto
doF
=
THTensor_
(
data
)(
d_output_features
);
if
(
input_features
->
nDimension
==
2
)
{
long
spatialVolume
=
THLongTensor_prodall
(
inputSize
);
uInt
nPlanes
=
d_input_features
->
size
[
1
];
auto
diF
=
THTensor_
(
data
)(
d_input_features
);
auto
doF
=
THTensor_
(
data
)(
d_output_features
);
for
(
auto
&
r
:
_rules
)
{
uInt
nHot
=
r
.
size
()
/
2
;
SparseToDense_BackwardPass
<
real
>
(
diF
,
doF
,
nPlanes
,
spatialVolume
,
&
r
[
0
],
nHot
);
doF
++
;
for
(
auto
&
r
:
_rules
)
{
uInt
nHot
=
r
.
size
()
/
2
;
SparseToDense_BackwardPass
<
real
>
(
diF
,
doF
,
nPlanes
,
spatialVolume
,
&
r
[
0
],
nHot
);
doF
+=
nPlanes
*
spatialVolume
;
}
}
}
#endif
PyTorch/sparseconvnet/SCN/generic/CPU/SparseToDense.h
View file @
5f0860fc
...
...
@@ -10,27 +10,26 @@
template
<
typename
T
>
void
SparseToDense_ForwardPass
(
T
*
input_features
,
T
*
output_features
,
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
int
nHot
)
{
for
(
uInt
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
T
*
i
=
&
input_features
[
rules
[
2
*
outSite
]
*
nPlanes
];
uInt
sample
=
rules
[
2
*
outSite
+
1
];
for
(
uInt
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
output_features
[(
sample
*
nPlanes
+
plane
)
*
spatialVolume
]
=
i
[
plane
];
}
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
int
nHot
)
{
for
(
uInt
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
T
*
i
=
input_features
+
rules
[
2
*
outSite
]
*
nPlanes
;
T
*
o
=
output_features
+
rules
[
2
*
outSite
+
1
];
for
(
uInt
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
o
[
plane
*
spatialVolume
]
=
i
[
plane
];
}
}
template
<
typename
T
>
void
SparseToDense_BackwardPass
(
T
*
d_input_features
,
T
*
d_output_features
,
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
int
nHot
)
{
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
int
nHot
)
{
for
(
uInt
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
T
*
di
=
&
d_input_features
[
rules
[
2
*
outSite
]
*
nPlanes
]
;
uInt
sample
=
rules
[
2
*
outSite
+
1
];
for
(
uInt
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
di
[
plane
]
=
d_output_features
[(
sample
*
nPlanes
+
plane
)
*
spatialVolume
];
}
}
for
(
uInt
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
T
*
d
_
i
=
d_input_features
+
rules
[
2
*
outSite
]
*
nPlanes
;
auto
d_o
=
d_output_features
+
rules
[
2
*
outSite
+
1
];
for
(
uInt
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
d
_
i
[
plane
]
=
d_o
[
plane
*
spatialVolume
];
}
}
#endif
/* CPU_SPARSETODENSE_H */
PyTorch/sparseconvnet/SCN/generic/GPU/AffineReluTrivialConvolution.cu
View file @
5f0860fc
...
...
@@ -10,7 +10,6 @@
#include "AffineReluTrivialConvolution.h"
#include <algorithm>
#include <iostream>
extern
"C"
void
scn_R_
(
AffineReluTrivialConvolution_updateOutput
)(
THCTensor
*
input_features
,
THCTensor
*
output_features
,
...
...
PyTorch/sparseconvnet/SCN/generic/GPU/AffineReluTrivialConvolution.h
View file @
5f0860fc
...
...
@@ -155,6 +155,27 @@ __global__ void dAffineReluTrivialConvolution_forwardB(
}
}
#define FOO(T, K, V) \
{ \
if (input_nPlanes % K == 0 and output_nPlanes % K == 0) { \
uInt o = (nActive / K) * K; \
if (o > 0) \
dAffineReluTrivialConvolution_forwardA<T, K, V> << < \
dim3(std::min(o / K, (uInt)512), output_nPlanes / K), \
dim3(K, K / V), 0, THCState_getCurrentStream(state)>>> \
(inFeatures, outFeatures, affineWeight, affineBias, convWeight, \
input_nPlanes, input_stride, output_nPlanes, output_stride, o); \
if (nActive > o) \
dAffineReluTrivialConvolution_forwardB<T, K, V> << < \
dim3(1, output_nPlanes / K), dim3(K, K / V), 0, \
THCState_getCurrentStream(state)>>> \
(inFeatures + o * input_stride, outFeatures + o * output_stride, \
affineWeight, affineBias, convWeight, input_nPlanes, \
input_stride, output_nPlanes, output_stride, nActive - o); \
return; \
} \
}
template
<
typename
T
>
void
dAffineReluTrivialConvolution_forward
(
T
*
inFeatures
,
T
*
outFeatures
,
T
*
affineWeight
,
T
*
affineBias
,
...
...
@@ -162,92 +183,25 @@ void dAffineReluTrivialConvolution_forward(T *inFeatures, T *outFeatures,
uInt
input_stride
,
uInt
output_nPlanes
,
uInt
output_stride
,
uInt
nActive
)
{
{
const
uInt
K
=
64
;
const
uInt
V
=
16
;
if
(
input_nPlanes
%
K
==
0
and
output_nPlanes
%
K
==
0
)
{
uInt
o
=
(
nActive
/
K
)
*
K
;
if
(
o
>
0
)
dAffineReluTrivialConvolution_forwardA
<
T
,
K
,
V
><<<
dim3
(
std
::
min
(
o
/
K
,
(
uInt
)
512
),
output_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
,
outFeatures
,
affineWeight
,
affineBias
,
convWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
o
);
if
(
nActive
>
o
)
dAffineReluTrivialConvolution_forwardB
<
T
,
K
,
V
><<<
dim3
(
1
,
output_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
+
o
*
input_stride
,
outFeatures
+
o
*
output_stride
,
affineWeight
,
affineBias
,
convWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
nActive
-
o
);
return
;
}
}
{
const
uInt
K
=
32
;
const
uInt
V
=
4
;
if
(
input_nPlanes
%
K
==
0
and
output_nPlanes
%
K
==
0
)
{
uInt
o
=
(
nActive
/
K
)
*
K
;
if
(
o
>
0
)
dAffineReluTrivialConvolution_forwardA
<
T
,
K
,
V
><<<
dim3
(
std
::
min
(
o
/
K
,
(
uInt
)
512
),
output_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
,
outFeatures
,
affineWeight
,
affineBias
,
convWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
o
);
if
(
nActive
>
o
)
dAffineReluTrivialConvolution_forwardB
<
T
,
K
,
V
><<<
dim3
(
1
,
output_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
+
o
*
input_stride
,
outFeatures
+
o
*
output_stride
,
affineWeight
,
affineBias
,
convWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
nActive
-
o
);
return
;
}
}
{
const
uInt
K
=
16
;
const
uInt
V
=
4
;
if
(
input_nPlanes
%
K
==
0
and
output_nPlanes
%
K
==
0
)
{
uInt
o
=
(
nActive
/
K
)
*
K
;
if
(
o
>
0
)
dAffineReluTrivialConvolution_forwardA
<
T
,
K
,
V
><<<
dim3
(
std
::
min
(
o
/
K
,
(
uInt
)
512
),
output_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
,
outFeatures
,
affineWeight
,
affineBias
,
convWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
o
);
if
(
nActive
>
o
)
dAffineReluTrivialConvolution_forwardB
<
T
,
K
,
V
><<<
dim3
(
1
,
output_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
+
o
*
input_stride
,
outFeatures
+
o
*
output_stride
,
affineWeight
,
affineBias
,
convWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
nActive
-
o
);
return
;
}
}
{
const
uInt
K
=
8
;
const
uInt
V
=
2
;
if
(
input_nPlanes
%
K
==
0
and
output_nPlanes
%
K
==
0
)
{
uInt
o
=
(
nActive
/
K
)
*
K
;
if
(
o
>
0
)
dAffineReluTrivialConvolution_forwardA
<
T
,
K
,
V
><<<
dim3
(
std
::
min
(
o
/
K
,
(
uInt
)
512
),
output_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
,
outFeatures
,
affineWeight
,
affineBias
,
convWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
o
);
if
(
nActive
>
o
)
dAffineReluTrivialConvolution_forwardB
<
T
,
K
,
V
><<<
dim3
(
1
,
output_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
+
o
*
input_stride
,
outFeatures
+
o
*
output_stride
,
affineWeight
,
affineBias
,
convWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
nActive
-
o
);
return
;
}
}
FOO
(
T
,
64
,
16
)
FOO
(
T
,
32
,
8
)
FOO
(
T
,
16
,
4
)
FOO
(
T
,
8
,
2
)
assert
(
false
);
}
template
<
>
void
dAffineReluTrivialConvolution_forward
<
double
>
(
double
*
inFeatures
,
double
*
outFeatures
,
double
*
affineWeight
,
double
*
affineBias
,
double
*
convWeight
,
uInt
input_nPlanes
,
uInt
input_stride
,
uInt
output_nPlanes
,
uInt
output_stride
,
uInt
nActive
)
{
FOO
(
double
,
32
,
8
)
FOO
(
double
,
16
,
4
)
FOO
(
double
,
8
,
2
)
assert
(
false
);
}
#undef FOO
// dOutput x W^T -> dInput and
// Input^T x dOutput -> dW
...
...
@@ -449,84 +403,41 @@ __global__ void dAffineReluTrivialConvolution_backward_dW_B(
atomicAdd
(
&
dAffineBias
[
tx
],
dAB
);
}
#define FOO(T, K, V) \
{ \
if (input_nPlanes % K == 0 and output_nPlanes % K == 0) { \
uInt o = (nActive / K) * K; \
if (o > 0) \
dAffineReluTrivialConvolution_backward_dW_A<T, K, V> << < \
dim3(std::min(o / K, (uInt)512), input_nPlanes / K), \
dim3(K, K / V), 0, THCState_getCurrentStream(state)>>> \
(inFeatures, dInFeatures, dOutFeatures, affineWeight, \
dAffineWeight, affineBias, dAffineBias, convWeight, dConvWeight, \
input_nPlanes, input_stride, output_nPlanes, output_stride, o, \
additiveGrad); \
if (nActive > o) \
dAffineReluTrivialConvolution_backward_dW_B<T, K, V> << < \
dim3(1, input_nPlanes / K), dim3(K, K / V), 0, \
THCState_getCurrentStream(state)>>> \
(inFeatures + o * input_stride, dInFeatures + o * input_stride, \
dOutFeatures + o * output_stride, affineWeight, dAffineWeight, \
affineBias, dAffineBias, convWeight, dConvWeight, input_nPlanes, \
input_stride, output_nPlanes, output_stride, nActive - o, \
additiveGrad); \
return; \
} \
}
template
<
typename
T
>
void
dAffineReluTrivialConvolution_backward_dW
(
T
*
inFeatures
,
T
*
dInFeatures
,
T
*
dOutFeatures
,
T
*
affineWeight
,
T
*
dAffineWeight
,
T
*
affineBias
,
T
*
dAffineBias
,
T
*
convWeight
,
T
*
dConvWeight
,
uInt
input_nPlanes
,
uInt
input_stride
,
uInt
output_nPlanes
,
uInt
output_stride
,
uInt
nActive
,
bool
additiveGrad
)
{
{
const
uInt
K
=
32
;
const
uInt
V
=
8
;
if
(
input_nPlanes
%
K
==
0
and
output_nPlanes
%
K
==
0
)
{
uInt
o
=
(
nActive
/
K
)
*
K
;
if
(
o
>
0
)
dAffineReluTrivialConvolution_backward_dW_A
<
T
,
K
,
V
><<<
dim3
(
std
::
min
(
o
/
K
,
(
uInt
)
512
),
input_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
,
dInFeatures
,
dOutFeatures
,
affineWeight
,
dAffineWeight
,
affineBias
,
dAffineBias
,
convWeight
,
dConvWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
o
,
additiveGrad
);
if
(
nActive
>
o
)
dAffineReluTrivialConvolution_backward_dW_B
<
T
,
K
,
V
><<<
dim3
(
1
,
input_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
+
o
*
input_stride
,
dInFeatures
+
o
*
input_stride
,
dOutFeatures
+
o
*
output_stride
,
affineWeight
,
dAffineWeight
,
affineBias
,
dAffineBias
,
convWeight
,
dConvWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
nActive
-
o
,
additiveGrad
);
return
;
}
}
{
const
uInt
K
=
16
;
const
uInt
V
=
4
;
if
(
input_nPlanes
%
K
==
0
and
output_nPlanes
%
K
==
0
)
{
uInt
o
=
(
nActive
/
K
)
*
K
;
if
(
o
>
0
)
dAffineReluTrivialConvolution_backward_dW_A
<
T
,
K
,
V
><<<
dim3
(
std
::
min
(
o
/
K
,
(
uInt
)
512
),
input_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
,
dInFeatures
,
dOutFeatures
,
affineWeight
,
dAffineWeight
,
affineBias
,
dAffineBias
,
convWeight
,
dConvWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
o
,
additiveGrad
);
if
(
nActive
>
o
)
dAffineReluTrivialConvolution_backward_dW_B
<
T
,
K
,
V
><<<
dim3
(
1
,
input_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
+
o
*
input_stride
,
dInFeatures
+
o
*
input_stride
,
dOutFeatures
+
o
*
output_stride
,
affineWeight
,
dAffineWeight
,
affineBias
,
dAffineBias
,
convWeight
,
dConvWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
nActive
-
o
,
additiveGrad
);
return
;
}
}
{
const
uInt
K
=
8
;
const
uInt
V
=
2
;
if
(
input_nPlanes
%
K
==
0
and
output_nPlanes
%
K
==
0
)
{
uInt
o
=
(
nActive
/
K
)
*
K
;
if
(
o
>
0
)
dAffineReluTrivialConvolution_backward_dW_A
<
T
,
K
,
V
><<<
dim3
(
std
::
min
(
o
/
K
,
(
uInt
)
512
),
input_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
,
dInFeatures
,
dOutFeatures
,
affineWeight
,
dAffineWeight
,
affineBias
,
dAffineBias
,
convWeight
,
dConvWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
o
,
additiveGrad
);
if
(
nActive
>
o
)
dAffineReluTrivialConvolution_backward_dW_B
<
T
,
K
,
V
><<<
dim3
(
1
,
input_nPlanes
/
K
),
dim3
(
K
,
K
/
V
),
0
,
THCState_getCurrentStream
(
state
)
>>>
(
inFeatures
+
o
*
input_stride
,
dInFeatures
+
o
*
input_stride
,
dOutFeatures
+
o
*
output_stride
,
affineWeight
,
dAffineWeight
,
affineBias
,
dAffineBias
,
convWeight
,
dConvWeight
,
input_nPlanes
,
input_stride
,
output_nPlanes
,
output_stride
,
nActive
-
o
,
additiveGrad
);
return
;
}
}
FOO
(
T
,
32
,
8
)
FOO
(
T
,
16
,
4
)
FOO
(
T
,
8
,
2
)
}
#undef FOO
#endif
PyTorch/sparseconvnet/SCN/generic/GPU/BatchNormalization.cu
View file @
5f0860fc
...
...
@@ -30,13 +30,14 @@ extern "C" void scn_R_(BatchNormalization_updateOutput)(
real
leakiness
)
{
THCTensor_
(
resizeAs
)(
state
,
output_features
,
input_features
);
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
BN_F_MACRO
(
16
)
else
BN_F_MACRO
(
12
)
else
BN_F_MACRO
(
8
)
else
BN_F_MACRO
(
4
)
else
BN_F_MACRO
(
1
)
if
(
input_features
->
nDimension
==
2
)
{
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
BN_F_MACRO
(
16
)
else
BN_F_MACRO
(
12
)
else
BN_F_MACRO
(
8
)
else
BN_F_MACRO
(
4
)
else
BN_F_MACRO
(
1
)
}
}
extern
"C"
void
scn_R_
(
BatchNormalizationInTensor_updateOutput
)(
...
...
@@ -44,14 +45,14 @@ extern "C" void scn_R_(BatchNormalizationInTensor_updateOutput)(
THCTensor
*
saveInvStd
,
THCTensor
*
runningMean
,
THCTensor
*
runningVar
,
THCTensor
*
weight
,
THCTensor
*
bias
,
real
eps
,
real
momentum
,
bool
train
,
real
leakiness
)
{
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
BN_F_MACRO
(
1
6
)
else
BN_F_MACRO
(
12
)
else
BN_F_MACRO
(
8
)
else
BN_F_MACRO
(
4
)
else
BN_F_MACRO
(
1
)
if
(
input_features
->
nDimension
==
2
)
{
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
BN_F_MACRO
(
16
)
else
BN_F_MACRO
(
1
2
)
else
BN_F_MACRO
(
8
)
else
BN_F_MACRO
(
4
)
else
BN_F_MACRO
(
1
)
}
}
#undef BN_F_MACRO
...
...
@@ -81,12 +82,13 @@ extern "C" void scn_R_(BatchNormalization_backward)(
THCTensor
*
d_weight
,
THCTensor
*
d_bias
,
real
leakiness
)
{
THCTensor_
(
resizeAs
)(
state
,
d_input_features
,
d_output_features
);
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
BN_B_MACRO
(
16
)
else
BN_B_MACRO
(
12
)
else
BN_B_MACRO
(
8
)
else
BN_B_MACRO
(
4
)
else
BN_B_MACRO
(
1
)
if
(
input_features
->
nDimension
==
2
)
{
auto
nActive
=
input_features
->
size
[
0
];
auto
nPlanes
=
input_features
->
size
[
1
];
auto
input_stride
=
input_features
->
stride
[
0
];
auto
output_stride
=
output_features
->
stride
[
0
];
BN_B_MACRO
(
16
)
else
BN_B_MACRO
(
12
)
else
BN_B_MACRO
(
8
)
else
BN_B_MACRO
(
4
)
else
BN_B_MACRO
(
1
)
}
}
#endif
PyTorch/sparseconvnet/SCN/generic/GPU/Convolution.cu
View file @
5f0860fc
...
...
@@ -25,28 +25,30 @@ extern "C" double scn_DR_(Convolution_updateOutput)(
if
(
not
bias
)
THCTensor_
(
zero
)(
state
,
output_features
);
auto
iF
=
THCTensor_
(
data
)(
state
,
input_features
);
auto
oF
=
THCTensor_
(
data
)(
state
,
output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
output_features
->
size
[
1
];
auto
w
=
THCTensor_
(
data
)(
state
,
weight
);
double
flops
=
0
;
if
(
nActive
)
{
auto
iF
=
THCTensor_
(
data
)(
state
,
input_features
);
auto
oF
=
THCTensor_
(
data
)(
state
,
output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
output_features
->
size
[
1
];
auto
w
=
THCTensor_
(
data
)(
state
,
weight
);
if
(
bias
)
{
auto
b
=
THCTensor_
(
data
)(
state
,
bias
);
for
(
uInt
i
=
0
;
i
<
op
;
i
+=
32
)
{
uInt
blockDim
=
min
(
32L
,
op
-
i
);
uInt
gridDim
=
min
(
4096
,
nActive
);
Convolution_fp_bias
<<
<
gridDim
,
blockDim
,
0
,
THCState_getCurrentStream
(
state
)
>>>
(
oF
+
i
,
b
+
i
,
op
,
op
,
nActive
);
if
(
bias
)
{
auto
b
=
THCTensor_
(
data
)(
state
,
bias
);
for
(
uInt
i
=
0
;
i
<
op
;
i
+=
32
)
{
uInt
blockDim
=
min
(
32L
,
op
-
i
);
uInt
gridDim
=
min
(
4096
,
nActive
);
Convolution_fp_bias
<<
<
gridDim
,
blockDim
,
0
,
THCState_getCurrentStream
(
state
)
>>>
(
oF
+
i
,
b
+
i
,
op
,
op
,
nActive
);
}
}
uInt
c
=
ip
*
op
;
RULEBOOKITERATOR
(
dConvolution_forward2
<
real
>
(
iF
,
oF
,
w
,
rbB
,
nHotB
,
ip
,
ip
,
op
,
op
,
THCState_getCurrentStream
(
state
));
,
w
+=
c
;
flops
+=
nHotB
*
c
;)
}
uInt
c
=
ip
*
op
;
RULEBOOKITERATOR
(
dConvolution_forward2
<
real
>
(
iF
,
oF
,
w
,
rbB
,
nHotB
,
ip
,
ip
,
op
,
op
,
THCState_getCurrentStream
(
state
));
,
w
+=
c
;
flops
+=
nHotB
*
c
;)
return
flops
;
}
...
...
@@ -63,23 +65,25 @@ extern "C" void scn_DR_(Convolution_backward)(
THCTensor_
(
resizeAs
)(
state
,
d_input_features
,
input_features
);
THCTensor_
(
zero
)(
state
,
d_input_features
);
auto
iF
=
THCTensor_
(
data
)(
state
,
input_features
);
auto
diF
=
THCTensor_
(
data
)(
state
,
d_input_features
);
auto
doF
=
THCTensor_
(
data
)(
state
,
d_output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
d_output_features
->
size
[
1
];
auto
w
=
THCTensor_
(
data
)(
state
,
weight
);
auto
dw
=
THCTensor_
(
data
)(
state
,
d_weight
);
uInt
c
=
ip
*
op
;
RULEBOOKITERATOR
(
dConvolution_backward_dW2
<
real
>
(
iF
,
diF
,
doF
,
w
,
dw
,
rbB
,
nHotB
,
ip
,
ip
,
op
,
op
,
THCState_getCurrentStream
(
state
));
,
w
+=
c
;
dw
+=
c
;)
if
(
nActive
)
{
auto
iF
=
THCTensor_
(
data
)(
state
,
input_features
);
auto
diF
=
THCTensor_
(
data
)(
state
,
d_input_features
);
auto
doF
=
THCTensor_
(
data
)(
state
,
d_output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
d_output_features
->
size
[
1
];
auto
w
=
THCTensor_
(
data
)(
state
,
weight
);
auto
dw
=
THCTensor_
(
data
)(
state
,
d_weight
);
uInt
c
=
ip
*
op
;
RULEBOOKITERATOR
(
dConvolution_backward_dW2
<
real
>
(
iF
,
diF
,
doF
,
w
,
dw
,
rbB
,
nHotB
,
ip
,
ip
,
op
,
op
,
THCState_getCurrentStream
(
state
));
,
w
+=
c
;
dw
+=
c
;)
if
(
d_bias
)
{
auto
db
=
THCTensor_
(
data
)(
state
,
d_bias
);
Convolution_bp_bias
(
doF
,
db
,
op
,
op
,
nActive
,
THCState_getCurrentStream
(
state
));
if
(
d_bias
)
{
auto
db
=
THCTensor_
(
data
)(
state
,
d_bias
);
Convolution_bp_bias
(
doF
,
db
,
op
,
op
,
nActive
,
THCState_getCurrentStream
(
state
));
}
}
}
...
...
@@ -89,33 +93,35 @@ extern "C" double scn_DR_(ValidConvolution_updateOutput)(
THCTensor
*
bias
,
long
filterVolume
,
THCITensor
*
rulesBuffer
)
{
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
auto
_rules
=
_m
.
getValidRuleBook
(
inputSize
,
filterSize
,
true
);
uInt
nActive
=
input_features
->
size
[
0
]
;
uInt
nActive
=
_m
.
getNActive
(
inputSize
)
;
THCTensor_
(
resize2d
)(
state
,
output_features
,
nActive
,
weight
->
size
[
1
]);
if
(
not
bias
)
THCTensor_
(
zero
)(
state
,
output_features
);
auto
iF
=
THCTensor_
(
data
)(
state
,
input_features
);
auto
oF
=
THCTensor_
(
data
)(
state
,
output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
output_features
->
size
[
1
];
auto
w
=
THCTensor_
(
data
)(
state
,
weight
);
double
flops
=
0
;
if
(
nActive
)
{
auto
iF
=
THCTensor_
(
data
)(
state
,
input_features
);
auto
oF
=
THCTensor_
(
data
)(
state
,
output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
output_features
->
size
[
1
];
auto
w
=
THCTensor_
(
data
)(
state
,
weight
);
if
(
bias
)
{
auto
b
=
THCTensor_
(
data
)(
state
,
bias
);
for
(
uInt
i
=
0
;
i
<
op
;
i
+=
32
)
{
uInt
blockDim
=
min
(
32L
,
op
-
i
);
uInt
gridDim
=
min
(
4096
,
nActive
);
Convolution_fp_bias
<<
<
gridDim
,
blockDim
,
0
,
THCState_getCurrentStream
(
state
)
>>>
(
oF
+
i
,
b
+
i
,
op
,
op
,
nActive
);
if
(
bias
)
{
auto
b
=
THCTensor_
(
data
)(
state
,
bias
);
for
(
uInt
i
=
0
;
i
<
op
;
i
+=
32
)
{
uInt
blockDim
=
min
(
32L
,
op
-
i
);
uInt
gridDim
=
min
(
4096
,
nActive
);
Convolution_fp_bias
<<
<
gridDim
,
blockDim
,
0
,
THCState_getCurrentStream
(
state
)
>>>
(
oF
+
i
,
b
+
i
,
op
,
op
,
nActive
);
}
}
uInt
c
=
ip
*
op
;
RULEBOOKITERATOR
(
dConvolution_forward2
<
real
>
(
iF
,
oF
,
w
,
rbB
,
nHotB
,
ip
,
ip
,
op
,
op
,
THCState_getCurrentStream
(
state
));
,
w
+=
c
;
flops
+=
nHotB
*
c
;)
}
uInt
c
=
ip
*
op
;
RULEBOOKITERATOR
(
dConvolution_forward2
<
real
>
(
iF
,
oF
,
w
,
rbB
,
nHotB
,
ip
,
ip
,
op
,
op
,
THCState_getCurrentStream
(
state
));
,
w
+=
c
;
flops
+=
nHotB
*
c
;)
return
flops
;
}
...
...
@@ -126,27 +132,29 @@ extern "C" void scn_DR_(ValidConvolution_backward)(
THCTensor
*
d_bias
,
long
filterVolume
,
THCITensor
*
rulesBuffer
)
{
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
auto
_rules
=
_m
.
getValidRuleBook
(
inputSize
,
filterSize
,
true
);
uInt
nActive
=
input_features
->
size
[
0
]
;
uInt
nActive
=
_m
.
getNActive
(
inputSize
)
;
THCTensor_
(
resizeAs
)(
state
,
d_input_features
,
input_features
);
THCTensor_
(
zero
)(
state
,
d_input_features
);
auto
iF
=
THCTensor_
(
data
)(
state
,
input_features
);
auto
diF
=
THCTensor_
(
data
)(
state
,
d_input_features
);
auto
doF
=
THCTensor_
(
data
)(
state
,
d_output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
d_output_features
->
size
[
1
];
auto
w
=
THCTensor_
(
data
)(
state
,
weight
);
auto
dw
=
THCTensor_
(
data
)(
state
,
d_weight
);
uInt
c
=
ip
*
op
;
RULEBOOKITERATOR
(
dConvolution_backward_dW2
<
real
>
(
iF
,
diF
,
doF
,
w
,
dw
,
rbB
,
nHotB
,
ip
,
ip
,
op
,
op
,
THCState_getCurrentStream
(
state
));
,
w
+=
c
;
dw
+=
c
;)
if
(
nActive
)
{
auto
iF
=
THCTensor_
(
data
)(
state
,
input_features
);
auto
diF
=
THCTensor_
(
data
)(
state
,
d_input_features
);
auto
doF
=
THCTensor_
(
data
)(
state
,
d_output_features
);
auto
ip
=
input_features
->
size
[
1
];
auto
op
=
d_output_features
->
size
[
1
];
auto
w
=
THCTensor_
(
data
)(
state
,
weight
);
auto
dw
=
THCTensor_
(
data
)(
state
,
d_weight
);
uInt
c
=
ip
*
op
;
RULEBOOKITERATOR
(
dConvolution_backward_dW2
<
real
>
(
iF
,
diF
,
doF
,
w
,
dw
,
rbB
,
nHotB
,
ip
,
ip
,
op
,
op
,
THCState_getCurrentStream
(
state
));
,
w
+=
c
;
dw
+=
c
;)
if
(
d_bias
)
{
auto
db
=
THCTensor_
(
data
)(
state
,
d_bias
);
Convolution_bp_bias
(
doF
,
db
,
op
,
op
,
nActive
,
THCState_getCurrentStream
(
state
));
if
(
d_bias
)
{
auto
db
=
THCTensor_
(
data
)(
state
,
d_bias
);
Convolution_bp_bias
(
doF
,
db
,
op
,
op
,
nActive
,
THCState_getCurrentStream
(
state
));
}
}
}
...
...
PyTorch/sparseconvnet/SCN/generic/GPU/Convolution.h
View file @
5f0860fc
...
...
@@ -184,7 +184,7 @@ dConvolution_KMxKN_forwardB(T *inFeatures, T *outFeatures, T *w, uInt *rules,
}
}
#define FOO(K, V)
\
#define FOO(
T,
K, V) \
{ \
if (input_nPlanes % K == 0 and output_nPlanes % K == 0) { \
uInt o = (nHot / K) * K; \
...
...
@@ -208,10 +208,21 @@ void dConvolution_forward(T *inFeatures, T *outFeatures, T *w, uInt *rules,
uInt
nHot
,
uInt
input_nPlanes
,
uInt
input_stride
,
uInt
output_nPlanes
,
uInt
output_stride
,
cudaStream_t
stream
)
{
FOO
(
64
,
16
)
FOO
(
32
,
8
)
FOO
(
16
,
4
)
FOO
(
8
,
2
)
FOO
(
T
,
64
,
16
)
FOO
(
T
,
32
,
8
)
FOO
(
T
,
16
,
4
)
FOO
(
T
,
8
,
2
)
assert
(
false
);
}
template
<
>
void
dConvolution_forward
<
double
>
(
double
*
inFeatures
,
double
*
outFeatures
,
double
*
w
,
uInt
*
rules
,
uInt
nHot
,
uInt
input_nPlanes
,
uInt
input_stride
,
uInt
output_nPlanes
,
uInt
output_stride
,
cudaStream_t
stream
)
{
FOO
(
double
,
32
,
8
)
FOO
(
double
,
16
,
4
)
FOO
(
double
,
8
,
2
)
assert
(
false
);
}
#undef FOO
...
...
@@ -378,7 +389,7 @@ dConvolution_KMxKN_backward_dW_B(T *inFeatures, T *dInFeatures, T *dOutFeatures,
}
}
#define FOO(K, V)
\
#define FOO(
T,
K, V) \
{ \
if (input_nPlanes % K == 0 and output_nPlanes % K == 0) { \
uInt o = (nHot / K) * K; \
...
...
@@ -404,9 +415,9 @@ void dConvolution_backward_dW(T *inFeatures, T *dInFeatures, T *dOutFeatures,
uInt
input_nPlanes
,
uInt
input_stride
,
uInt
output_nPlanes
,
uInt
output_stride
,
cudaStream_t
stream
)
{
FOO
(
32
,
8
)
FOO
(
16
,
4
)
FOO
(
8
,
2
)
FOO
(
T
,
32
,
8
)
FOO
(
T
,
16
,
4
)
FOO
(
T
,
8
,
2
)
assert
(
false
);
}
#undef FOO
...
...
PyTorch/sparseconvnet/SCN/generic/GPU/Deconvolution.h
View file @
5f0860fc
...
...
@@ -153,7 +153,7 @@ dDeconvolution_KMxKN_forwardB(T *inFeatures, T *outFeatures, T *w, uInt *rules,
}
}
#define FOO(K, V)
\
#define FOO(
T,
K, V) \
{ \
if (input_nPlanes % K == 0 and output_nPlanes % K == 0) { \
uInt o = (nHot / K) * K; \
...
...
@@ -177,10 +177,21 @@ void dDeconvolution_forward(T *inFeatures, T *outFeatures, T *w, uInt *rules,
uInt
nHot
,
uInt
input_nPlanes
,
uInt
input_stride
,
uInt
output_nPlanes
,
uInt
output_stride
,
cudaStream_t
stream
)
{
FOO
(
64
,
16
)
FOO
(
32
,
8
)
FOO
(
16
,
4
)
FOO
(
8
,
2
)
FOO
(
T
,
64
,
16
)
FOO
(
T
,
32
,
8
)
FOO
(
T
,
16
,
4
)
FOO
(
T
,
8
,
2
)
assert
(
false
);
}
template
<
>
void
dDeconvolution_forward
<
double
>
(
double
*
inFeatures
,
double
*
outFeatures
,
double
*
w
,
uInt
*
rules
,
uInt
nHot
,
uInt
input_nPlanes
,
uInt
input_stride
,
uInt
output_nPlanes
,
uInt
output_stride
,
cudaStream_t
stream
)
{
FOO
(
double
,
32
,
8
)
FOO
(
double
,
16
,
4
)
FOO
(
double
,
8
,
2
)
assert
(
false
);
}
#undef FOO
...
...
@@ -345,7 +356,7 @@ __global__ void dDeconvolution_KMxKN_backward_dW_B(
}
}
#define FOO(K, V)
\
#define FOO(
T,
K, V) \
{ \
if (input_nPlanes % K == 0 and output_nPlanes % K == 0) { \
uInt o = (nHot / K) * K; \
...
...
@@ -371,9 +382,9 @@ void dDeconvolution_backward_dW(T *inFeatures, T *dInFeatures, T *dOutFeatures,
uInt
input_nPlanes
,
uInt
input_stride
,
uInt
output_nPlanes
,
uInt
output_stride
,
cudaStream_t
stream
)
{
FOO
(
32
,
8
)
FOO
(
16
,
4
)
FOO
(
8
,
2
)
FOO
(
T
,
32
,
8
)
FOO
(
T
,
16
,
4
)
FOO
(
T
,
8
,
2
)
assert
(
false
);
}
#undef FOO
...
...
PyTorch/sparseconvnet/SCN/generic/GPU/SparseToDense.cu
View file @
5f0860fc
...
...
@@ -9,50 +9,54 @@
#else
#include "SparseToDense.h"
extern
"C"
void
scn_DR_
(
SparseToDense_updateOutput
)(
THLongTensor
*
inputSize
,
void
**
m
,
THCTensor
*
input_features
,
THCTensor
*
output_features
,
THCITensor
*
rulesBuffer
)
{
extern
"C"
void
scn_DR_
(
SparseToDense_updateOutput
)(
THLongTensor
*
inputSize
,
void
**
m
,
THCTensor
*
input_features
,
THCTensor
*
output_features
,
THCITensor
*
rulesBuffer
,
long
nPlanes
)
{
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
{
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
long
spatialVolume
=
1
;
{
long
sz
[
Dimension
+
2
];
sz
[
0
]
=
_m
.
inputSGs
->
size
();
sz
[
1
]
=
input_features
->
size
[
1
];
sz
[
0
]
=
_m
.
grids
.
begin
()
->
second
.
size
();
sz
[
1
]
=
nPlanes
;
//
input_features->size[1];
for
(
int
i
=
0
;
i
<
Dimension
;
i
++
)
{
auto
x
=
THLongTensor_data
(
inputSize
)[
i
];
sz
[
i
+
2
]
=
x
;
spatialVolume
*=
x
;
}
THCTensor_
(
resizeNd
)(
state
,
output_features
,
Dimension
+
2
,
sz
,
NULL
);
THCTensor_
(
zero
)(
state
,
output_features
);
}
auto
_rules
=
_m
.
getSparseToDenseRuleBook
(
inputSize
,
true
);
auto
spatialVolume
=
_rules
.
size
();
uInt
nPlanes
=
input_features
->
size
[
1
];
auto
iF
=
THCTensor_
(
data
)(
state
,
input_features
);
auto
oF
=
THCTensor_
(
data
)(
state
,
output_features
);
RULEBOOKITERATOR
(
SparseToDense_ForwardPass
<
real
>
(
THCState_getCurrentStream
(
state
),
iF
,
oF
,
nPlanes
,
spatialVolume
,
rbB
,
nHotB
);
,
oF
++
;)
// todo check ++ or +=spatialVolume????zzz
if
(
input_features
->
nDimension
==
2
)
{
auto
_rules
=
_m
.
getSparseToDenseRuleBook
(
inputSize
,
true
);
uInt
nPlanes
=
input_features
->
size
[
1
];
auto
iF
=
THCTensor_
(
data
)(
state
,
input_features
);
auto
oF
=
THCTensor_
(
data
)(
state
,
output_features
);
RULEBOOKITERATOR
(
SparseToDense_ForwardPass
<
real
>
(
THCState_getCurrentStream
(
state
),
iF
,
oF
,
nPlanes
,
spatialVolume
,
rbB
,
nHotB
);
,
oF
+=
nPlanes
*
spatialVolume
;)
}
}
extern
"C"
void
scn_DR_
(
SparseToDense_updateGradInput
)(
THLongTensor
*
inputSize
,
void
**
m
,
THCTensor
*
input_features
,
THCTensor
*
d_input_features
,
THCTensor
*
d_output_features
,
THCITensor
*
rulesBuffer
)
{
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
THCTensor_
(
resizeAs
)(
state
,
d_input_features
,
input_features
);
THCTensor_
(
zero
)(
state
,
d_input_features
);
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
auto
_rules
=
_m
.
getSparseToDenseRuleBook
(
inputSize
,
true
);
auto
spatialVolume
=
_rules
.
size
();
uInt
nPlanes
=
d_input_features
->
size
[
1
];
auto
diF
=
THCTensor_
(
data
)(
state
,
d_input_features
);
auto
doF
=
THCTensor_
(
data
)(
state
,
d_output_features
);
RULEBOOKITERATOR
(
SparseToDense_BackwardPass
<
real
>
(
THCState_getCurrentStream
(
state
),
diF
,
doF
,
nPlanes
,
spatialVolume
,
rbB
,
nHotB
);
,
doF
++
;)
if
(
input_features
->
nDimension
==
2
)
{
auto
_rules
=
_m
.
getSparseToDenseRuleBook
(
inputSize
,
true
);
long
spatialVolume
=
THLongTensor_prodall
(
inputSize
);
uInt
nPlanes
=
d_input_features
->
size
[
1
];
auto
diF
=
THCTensor_
(
data
)(
state
,
d_input_features
);
auto
doF
=
THCTensor_
(
data
)(
state
,
d_output_features
);
RULEBOOKITERATOR
(
SparseToDense_BackwardPass
<
real
>
(
THCState_getCurrentStream
(
state
),
diF
,
doF
,
nPlanes
,
spatialVolume
,
rbB
,
nHotB
);
,
doF
+=
nPlanes
*
spatialVolume
;)
}
}
#endif
PyTorch/sparseconvnet/SCN/generic/GPU/SparseToDense.h
View file @
5f0860fc
...
...
@@ -12,7 +12,8 @@
// NTX must be >=2 so r is filled properly
template
<
typename
T
,
uInt
NTX
,
uInt
NTY
>
__global__
void
SparseToDense_fp
(
T
*
input_features
,
T
*
output_features
,
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
uInt
nHot
)
{
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
uInt
nHot
)
{
__shared__
uInt
r
[
NTY
*
2
];
for
(
uInt
n
=
blockIdx
.
x
*
NTY
;
n
<
nHot
;
n
+=
gridDim
.
x
*
NTY
)
{
{
...
...
@@ -22,10 +23,10 @@ __global__ void SparseToDense_fp(T *input_features, T *output_features,
}
__syncthreads
();
if
(
n
+
threadIdx
.
y
<
nHot
)
{
T
*
i
=
&
input_features
[
r
[
2
*
threadIdx
.
y
]
*
nPlanes
]
;
T
*
o
=
&
output_features
[
r
[
2
*
threadIdx
.
y
+
1
]
*
spatialVolume
*
nPlanes
];
T
*
i
=
input_features
+
r
[
2
*
threadIdx
.
y
]
*
nPlanes
;
T
*
o
=
output_features
+
r
[
2
*
threadIdx
.
y
+
1
];
for
(
uInt
plane
=
threadIdx
.
x
;
plane
<
nPlanes
;
plane
+=
NTX
)
o
[
plane
*
spatialVolume
]
=
i
[
plane
];
o
[
plane
*
spatialVolume
]
=
i
[
plane
];
}
__syncthreads
();
}
...
...
@@ -33,16 +34,16 @@ __global__ void SparseToDense_fp(T *input_features, T *output_features,
template
<
typename
T
>
void
SparseToDense_ForwardPass
(
cudaStream_t
stream
,
T
*
input_features
,
T
*
output_features
,
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
uInt
nHot
)
{
SparseToDense_fp
<
T
,
32
,
32
><<<
32
,
dim3
(
32
,
32
),
0
,
stream
>>>
(
input_features
,
output_features
,
nPlanes
,
spatialVolume
,
rules
,
nHot
);
T
*
output_features
,
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
uInt
nHot
)
{
SparseToDense_fp
<
T
,
32
,
32
>
<<
<
32
,
dim3
(
32
,
32
),
0
,
stream
>>>
(
input_features
,
output_features
,
nPlanes
,
spatialVolume
,
rules
,
nHot
);
}
// NTX must be >=2 so r is filled properly
template
<
typename
T
,
uInt
NTX
,
uInt
NTY
>
__global__
void
SparseToDense_bp
(
T
*
d_input_features
,
T
*
d_output_features
,
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
uInt
nHot
)
{
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
uInt
nHot
)
{
__shared__
uInt
r
[
NTY
*
2
];
for
(
uInt
n
=
blockIdx
.
x
*
NTY
;
n
<
nHot
;
n
+=
gridDim
.
x
*
NTY
)
{
{
...
...
@@ -52,10 +53,10 @@ __global__ void SparseToDense_bp(T *d_input_features, T *d_output_features,
}
__syncthreads
();
if
(
n
+
threadIdx
.
y
<
nHot
)
{
T
*
i
=
&
d_input_features
[
r
[
2
*
threadIdx
.
y
]
*
nPlanes
]
;
T
*
o
=
&
d_output_features
[
r
[
2
*
threadIdx
.
y
+
1
]
*
spatialVolume
*
nPlanes
];
T
*
d_
i
=
d_input_features
+
r
[
2
*
threadIdx
.
y
]
*
nPlanes
;
T
*
d_
o
=
d_output_features
+
r
[
2
*
threadIdx
.
y
+
1
];
for
(
uInt
plane
=
threadIdx
.
x
;
plane
<
nPlanes
;
plane
+=
NTX
)
i
[
plane
]
=
o
[
plane
*
spatialVolume
];
d_
i
[
plane
]
=
d_
o
[
plane
*
spatialVolume
];
}
__syncthreads
();
}
...
...
@@ -63,10 +64,10 @@ __global__ void SparseToDense_bp(T *d_input_features, T *d_output_features,
template
<
typename
T
>
void
SparseToDense_BackwardPass
(
cudaStream_t
stream
,
T
*
d_input_features
,
T
*
d_output_features
,
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
uInt
nHot
)
{
SparseToDense_bp
<
T
,
32
,
32
><<<
32
,
dim3
(
32
,
32
),
0
,
stream
>>>
(
d_input_features
,
d_output_features
,
nPlanes
,
spatialVolume
,
rules
,
nHot
);
T
*
d_output_features
,
uInt
nPlanes
,
uInt
spatialVolume
,
uInt
*
rules
,
uInt
nHot
)
{
SparseToDense_bp
<
T
,
32
,
32
>
<<
<
32
,
dim3
(
32
,
32
),
0
,
stream
>>>
(
d_input_features
,
d_output_features
,
nPlanes
,
spatialVolume
,
rules
,
nHot
);
}
#endif
/* GPU_SPARSETODENSE_H */
PyTorch/sparseconvnet/SCN/generic/GPU/THGenerateCudaFloatTypes.h
View file @
5f0860fc
...
...
@@ -27,4 +27,22 @@
#undef TH_REAL_IS_FLOAT
#undef THBLAS_GEMM
// double
// #define real double
// #define accreal double
// #define Real Double
// #define CReal CudaDouble
// #define TH_REAL_IS_DOUBLE
// #define THBLAS_GEMM THCudaBlas_Dgemm
// #line 1 TH_GENERIC_FILE
// #include TH_GENERIC_FILE
// #undef accreal
// #undef real
// #undef Real
// #undef CReal
// #undef TH_REAL_IS_DOUBLE
// #undef THBLAS_GEMM
#undef TH_GENERIC_FILE
PyTorch/sparseconvnet/SCN/generic/Geometry/ConvolutionRules.h
View file @
5f0860fc
...
...
@@ -103,20 +103,26 @@ uInt Convolution_InputSgsToRulesAndOutputSgs_OMP(
return
output_nActive
;
}
// for each site in filterVolume, list of (inputFeatureNumber,batchIdx) pairs
// for each active site, list of (inputFeatureNumber,batchIdx, spatialOffset)
// triples
template
<
uInt
dimension
>
void
SparseToDense_InputSgsToRulesAndOutputSgs
(
SparseGrids
<
dimension
>
&
input_SGs
,
RuleBook
&
rules
,
long
*
spatialSize
)
{
uInt
batchSize
=
input_SGs
.
size
();
SparseGrids
<
dimension
>
output_SGs
(
batchSize
);
std
::
vector
<
long
>
ones
(
dimension
,
1
);
rules
.
clear
();
for
(
uInt
i
=
0
;
i
<
batchSize
;
i
++
)
{
auto
&
iSG
=
input_SGs
[
i
];
auto
&
oSG
=
output_SGs
[
i
];
oSG
.
ctr
=
i
;
// batchIdx
Convolution_InputSgToRulesAndOutputSg
<
dimension
>
(
iSG
,
oSG
,
rules
,
spatialSize
,
&
ones
[
0
],
spatialSize
,
&
ones
[
0
]);
rules
.
resize
(
batchSize
);
Point
<
dimension
>
lb
,
ub
;
for
(
int
i
=
0
;
i
<
dimension
;
++
i
)
{
lb
[
i
]
=
0
;
ub
[
i
]
=
spatialSize
[
i
]
-
1
;
}
auto
region
=
RectangularRegion
<
dimension
>
(
lb
,
ub
);
for
(
uInt
batchIdx
=
0
;
batchIdx
<
batchSize
;
batchIdx
++
)
{
auto
&
iSG
=
input_SGs
[
batchIdx
];
for
(
auto
const
&
inIter
:
iSG
.
mp
)
{
rules
[
batchIdx
].
push_back
(
inIter
.
second
+
iSG
.
ctr
);
rules
[
batchIdx
].
push_back
(
region
.
offset
(
inIter
.
first
));
}
}
}
...
...
@@ -124,33 +130,21 @@ template <uInt dimension>
void
SparseToDense_InputSgsToRulesAndOutputSgs_OMP
(
SparseGrids
<
dimension
>
&
input_SGs
,
RuleBook
&
rules
,
long
*
spatialSize
)
{
uInt
batchSize
=
input_SGs
.
size
();
SparseGrids
<
dimension
>
output_SGs
(
batchSize
);
std
::
vector
<
long
>
ones
(
dimension
,
1
);
rules
.
clear
();
rules
.
resize
(
volume
<
dimension
>
(
spatialSize
));
std
::
vector
<
RuleBook
>
rbs
(
batchSize
);
{
uInt
i
;
#pragma omp parallel for private(i)
for
(
i
=
0
;
i
<
batchSize
;
i
++
)
{
output_SGs
[
i
].
ctr
=
i
;
// batchIdx
Convolution_InputSgToRulesAndOutputSg
<
dimension
>
(
input_SGs
[
i
],
output_SGs
[
i
],
rbs
[
i
],
spatialSize
,
&
ones
[
0
],
spatialSize
,
&
ones
[
0
]);
}
rules
.
resize
(
batchSize
);
Point
<
dimension
>
lb
,
ub
;
for
(
int
i
=
0
;
i
<
dimension
;
++
i
)
{
lb
[
i
]
=
0
;
ub
[
i
]
=
spatialSize
[
i
]
-
1
;
}
{
uInt
i
;
#pragma omp parallel for private(i)
for
(
i
=
0
;
i
<
rules
.
size
();
i
++
)
{
auto
&
R
=
rules
[
i
];
for
(
uInt
j
=
0
;
j
<
batchSize
;
j
++
)
{
auto
&
r
=
rbs
[
j
][
i
];
for
(
uInt
k
=
0
;
k
<
r
.
size
();)
{
R
.
push_back
(
r
[
k
++
]);
R
.
push_back
(
r
[
k
++
]);
}
}
auto
region
=
RectangularRegion
<
dimension
>
(
lb
,
ub
);
uInt
batchIdx
;
#pragma omp parallel for private(batchIdx)
for
(
batchIdx
=
0
;
batchIdx
<
batchSize
;
batchIdx
++
)
{
auto
&
iSG
=
input_SGs
[
batchIdx
];
for
(
auto
const
&
inIter
:
iSG
.
mp
)
{
rules
[
batchIdx
].
push_back
(
inIter
.
second
+
iSG
.
ctr
);
rules
[
batchIdx
].
push_back
(
region
.
offset
(
inIter
.
first
));
}
}
}
...
...
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.cpp
View file @
5f0860fc
...
...
@@ -125,16 +125,15 @@ extern "C" void scn_D_(getSpatialLocations)(void **m, THLongTensor *spatialSize,
}
extern
"C"
void
scn_D_
(
createMetadataForDenseToSparse
)(
void
**
m
,
THLongTensor
*
spatialSize_
,
THLongTensor
*
pad_
,
THLongTensor
*
nz_
,
long
batchSize
)
{
THLongTensor
*
nz_
,
long
batchSize
)
{
SCN_INITIALIZE_AND_REFERENCE
(
Metadata
<
Dimension
>
,
m
)
_m
.
clear
();
_m
.
setInputSpatialSize
(
spatialSize_
);
_m
.
inputSGs
->
resize
(
batchSize
);
auto
&
nActive
=
*
_m
.
inputNActive
;
nActive
=
nz_
->
size
[
0
];
auto
nz
=
THLongTensor_data
(
nz_
);
auto
pad
=
THLongTensor_data
(
pad_
);
auto
spatialSize
=
THLongTensor_data
(
spatialSize_
);
std
::
vector
<
uInt
>
br
(
batchSize
+
1
);
...
...
@@ -157,8 +156,7 @@ scn_D_(createMetadataForDenseToSparse)(void **m, THLongTensor *spatialSize_,
for
(
uInt
i
=
br
[
b
];
i
<
br
[
b
+
1
];
i
++
)
{
Point
<
Dimension
>
x
;
for
(
uInt
j
=
0
;
j
<
Dimension
;
j
++
)
{
x
[
j
]
=
nz
[
i
*
(
Dimension
+
1
)
+
j
+
1
]
+
pad
[
b
*
Dimension
+
j
];
// 0-indexed
x
[
j
]
=
nz
[
i
*
(
Dimension
+
1
)
+
j
+
1
];
// 0-indexed
}
sg
.
mp
[
x
]
=
i
;
}
...
...
@@ -281,6 +279,7 @@ extern "C" void scn_D_(generateRuleBooks2s2)(void **m) {
p2
[
i
]
=
p3
[
i
]
=
inS
[
i
]
=
outS
[
i
];
}
}
extern
"C"
void
scn_D_
(
freeMetadata
)(
void
**
m
)
{
SCN_DELETE
(
Metadata
<
Dimension
>
,
m
)
}
...
...
PyTorch/sparseconvnet/SCN/generic/Geometry/Metadata.h
View file @
5f0860fc
...
...
@@ -11,7 +11,6 @@
#include "ActivePoolingRules.h"
#include "ConvolutionRules.h"
#include "ValidConvolutionRules.h"
#include <iostream>
#include <tuple>
#include <unordered_map>
...
...
@@ -40,6 +39,18 @@ public:
uInt
*
inputNActive
;
Metadata
()
{}
void
clear
()
{
nActive
.
clear
();
grids
.
clear
();
activePoolingRuleBooks
.
clear
();
validRuleBooks
.
clear
();
ruleBooks
.
clear
();
sparseToDenseRuleBooks
.
clear
();
inputSGs
=
nullptr
;
inputSG
=
nullptr
;
inputNActive
=
nullptr
;
}
void
setInputSpatialSize
(
THLongTensor
*
spatialSize
)
{
inputSpatialSize
=
LongTensorToPoint
<
dimension
>
(
spatialSize
);
inputSGs
=
&
grids
[
inputSpatialSize
];
...
...
PyTorch/sparseconvnet/SCN/generic/Geometry/ValidConvolutionRules.h
View file @
5f0860fc
...
...
@@ -6,8 +6,6 @@
#ifndef VALIDCONVOLUTIONRULES_H
#define VALIDCONVOLUTIONRULES_H
#include<iostream>
// Full input region for an output point
template
<
uInt
dimension
>
...
...
@@ -26,8 +24,8 @@ InputRegionCalculator_Valid(const Point<dimension> &output, long *size) {
// rules is used to carry out the "lowering" whilst carrying out the convolution
template
<
uInt
dimension
>
double
ValidConvolution_SgToRules
(
SparseGrid
<
dimension
>
&
grid
,
RuleBook
&
rules
,
long
*
size
)
{
double
ValidConvolution_SgToRules
(
SparseGrid
<
dimension
>
&
grid
,
RuleBook
&
rules
,
long
*
size
)
{
uInt
sd
=
volume
<
dimension
>
(
size
);
double
countActiveInputs
=
0
;
for
(
auto
const
&
outputIter
:
grid
.
mp
)
{
...
...
@@ -48,8 +46,8 @@ double ValidConvolution_SgToRules(SparseGrid<dimension> &grid,
}
template
<
uInt
dimension
>
uInt
ValidConvolution_SgsToRules
(
SparseGrids
<
dimension
>
&
SGs
,
RuleBook
&
rules
,
long
*
size
)
{
uInt
ValidConvolution_SgsToRules
(
SparseGrids
<
dimension
>
&
SGs
,
RuleBook
&
rules
,
long
*
size
)
{
uInt
sd
=
volume
<
dimension
>
(
size
);
uInt
countActiveInputs
=
0
;
rules
.
clear
();
...
...
@@ -61,7 +59,7 @@ uInt ValidConvolution_SgsToRules(SparseGrids<dimension> &SGs,
}
template
<
uInt
dimension
>
uInt
ValidConvolution_SgsToRules_OMP
(
SparseGrids
<
dimension
>
&
SGs
,
RuleBook
&
rules
,
long
*
size
)
{
RuleBook
&
rules
,
long
*
size
)
{
std
::
vector
<
RuleBook
>
rbs
(
SGs
.
size
());
std
::
vector
<
double
>
countActiveInputs
(
SGs
.
size
());
rules
.
clear
();
...
...
PyTorch/sparseconvnet/SCN/header_cpu.h
View file @
5f0860fc
This diff is collapsed.
Click to expand it.
PyTorch/sparseconvnet/SCN/header_gpu.h
View file @
5f0860fc
...
...
@@ -122,7 +122,8 @@ void scn_gpu_float1MaxPooling_updateGradInput(
// SparseToDense
void
scn_gpu_float1SparseToDense_updateOutput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
);
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
,
long
nPlanes
);
void
scn_gpu_float1SparseToDense_updateGradInput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
d_input_features
,
THCudaTensor
*
d_output_features
,
...
...
@@ -199,7 +200,7 @@ void scn_gpu_float2MaxPooling_updateGradInput(
// SparseToDense
void
scn_gpu_float2SparseToDense_updateOutput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
);
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
,
long
nPlanes
);
void
scn_gpu_float2SparseToDense_updateGradInput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
d_input_features
,
THCudaTensor
*
d_output_features
,
...
...
@@ -276,7 +277,7 @@ void scn_gpu_float3MaxPooling_updateGradInput(
// SparseToDense
void
scn_gpu_float3SparseToDense_updateOutput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
);
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
,
long
nPlanes
);
void
scn_gpu_float3SparseToDense_updateGradInput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
d_input_features
,
THCudaTensor
*
d_output_features
,
...
...
@@ -353,7 +354,7 @@ void scn_gpu_float4MaxPooling_updateGradInput(
// SparseToDense
void
scn_gpu_float4SparseToDense_updateOutput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
);
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
,
long
nPlanes
);
void
scn_gpu_float4SparseToDense_updateGradInput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
d_input_features
,
THCudaTensor
*
d_output_features
,
...
...
@@ -430,7 +431,7 @@ void scn_gpu_float5MaxPooling_updateGradInput(
// SparseToDense
void
scn_gpu_float5SparseToDense_updateOutput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
);
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
,
long
nPlanes
);
void
scn_gpu_float5SparseToDense_updateGradInput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
d_input_features
,
THCudaTensor
*
d_output_features
,
...
...
@@ -507,7 +508,7 @@ void scn_gpu_float6MaxPooling_updateGradInput(
// SparseToDense
void
scn_gpu_float6SparseToDense_updateOutput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
);
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
,
long
nPlanes
);
void
scn_gpu_float6SparseToDense_updateGradInput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
d_input_features
,
THCudaTensor
*
d_output_features
,
...
...
@@ -584,7 +585,7 @@ void scn_gpu_float7MaxPooling_updateGradInput(
// SparseToDense
void
scn_gpu_float7SparseToDense_updateOutput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
);
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
,
long
nPlanes
);
void
scn_gpu_float7SparseToDense_updateGradInput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
d_input_features
,
THCudaTensor
*
d_output_features
,
...
...
@@ -661,7 +662,7 @@ void scn_gpu_float8MaxPooling_updateGradInput(
// SparseToDense
void
scn_gpu_float8SparseToDense_updateOutput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
);
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
,
long
nPlanes
);
void
scn_gpu_float8SparseToDense_updateGradInput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
d_input_features
,
THCudaTensor
*
d_output_features
,
...
...
@@ -738,7 +739,7 @@ void scn_gpu_float9MaxPooling_updateGradInput(
// SparseToDense
void
scn_gpu_float9SparseToDense_updateOutput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
);
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
,
long
nPlanes
);
void
scn_gpu_float9SparseToDense_updateGradInput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
d_input_features
,
THCudaTensor
*
d_output_features
,
...
...
@@ -815,7 +816,7 @@ void scn_gpu_float10MaxPooling_updateGradInput(
// SparseToDense
void
scn_gpu_float10SparseToDense_updateOutput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
);
THCudaTensor
*
output_features
,
THCudaIntTensor
*
rulesBuffer
,
long
nPlanes
);
void
scn_gpu_float10SparseToDense_updateGradInput
(
THLongTensor
*
inputSize
,
void
**
m
,
THCudaTensor
*
input_features
,
THCudaTensor
*
d_input_features
,
THCudaTensor
*
d_output_features
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment