Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
29c810b2
"vscode:/vscode.git/clone" did not exist on "a35f429768b3aa6ea2e7cd0a6452e1fcb7671c8d"
Commit
29c810b2
authored
Jul 16, 2018
by
Benjamin Thomas Graham
Browse files
OpenMP for some CPU ops
parent
2f6072ed
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
91 additions
and
50 deletions
+91
-50
sparseconvnet/SCN/CPU/ActivePooling.cpp
sparseconvnet/SCN/CPU/ActivePooling.cpp
+8
-4
sparseconvnet/SCN/CPU/AveragePooling.cpp
sparseconvnet/SCN/CPU/AveragePooling.cpp
+6
-2
sparseconvnet/SCN/CPU/BatchNormalization.cpp
sparseconvnet/SCN/CPU/BatchNormalization.cpp
+22
-24
sparseconvnet/SCN/CPU/IOLayers.cpp
sparseconvnet/SCN/CPU/IOLayers.cpp
+8
-4
sparseconvnet/SCN/CPU/LeakyReLU.cpp
sparseconvnet/SCN/CPU/LeakyReLU.cpp
+9
-4
sparseconvnet/SCN/CPU/MaxPooling.cpp
sparseconvnet/SCN/CPU/MaxPooling.cpp
+6
-2
sparseconvnet/SCN/CPU/SparseToDense.cpp
sparseconvnet/SCN/CPU/SparseToDense.cpp
+6
-3
sparseconvnet/SCN/CPU/UnPooling.cpp
sparseconvnet/SCN/CPU/UnPooling.cpp
+6
-2
sparseconvnet/SCN/CUDA/ActivePooling.cu
sparseconvnet/SCN/CUDA/ActivePooling.cu
+2
-2
sparseconvnet/SCN/CUDA/IOLayers.cu
sparseconvnet/SCN/CUDA/IOLayers.cu
+2
-2
sparseconvnet/__init__.py
sparseconvnet/__init__.py
+1
-1
sparseconvnet/utils.py
sparseconvnet/utils.py
+15
-0
No files found.
sparseconvnet/SCN/CPU/ActivePooling.cpp
View file @
29c810b2
...
...
@@ -9,11 +9,13 @@ template <typename T>
void
ActivePooling_ForwardPass
(
T
*
input_features
,
T
*
output_features
,
Int
batchSize
,
Int
maxActive
,
Int
nPlanes
,
RuleBook
&
rules
,
bool
average
)
{
for
(
Int
outSite
=
0
;
outSite
<
batchSize
;
outSite
++
)
{
Int
outSite
;
#pragma omp parallel for private(outSite)
for
(
outSite
=
0
;
outSite
<
batchSize
;
outSite
++
)
{
T
*
out
=
&
output_features
[
outSite
*
nPlanes
];
Int
*
r
=
&
rules
[
0
][
outSite
*
(
maxActive
+
1
)];
Int
nActive
=
*
r
++
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
1.0
f
/
nActive
:
1.0
f
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
(
T
)
1
/
nActive
:
(
T
)
1
;
while
(
nActive
--
>
0
)
{
T
*
inp
=
&
input_features
[(
*
r
++
)
*
nPlanes
];
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
...
...
@@ -25,11 +27,13 @@ template <typename T>
void
ActivePooling_BackwardPass
(
T
*
d_input_features
,
T
*
d_output_features
,
Int
batchSize
,
Int
maxActive
,
Int
nPlanes
,
RuleBook
&
rules
,
bool
average
)
{
for
(
Int
outSite
=
0
;
outSite
<
batchSize
;
outSite
++
)
{
Int
outSite
;
#pragma omp parallel for private(outSite)
for
(
outSite
=
0
;
outSite
<
batchSize
;
outSite
++
)
{
T
*
out
=
&
d_output_features
[
outSite
*
nPlanes
];
Int
*
r
=
&
rules
[
0
][
outSite
*
(
maxActive
+
1
)];
Int
nActive
=
*
r
++
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
1.0
f
/
nActive
:
1.0
f
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
(
T
)
1
/
nActive
:
(
T
)
1
;
while
(
nActive
--
>
0
)
{
T
*
inp
=
&
d_input_features
[(
*
r
++
)
*
nPlanes
];
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
...
...
sparseconvnet/SCN/CPU/AveragePooling.cpp
View file @
29c810b2
...
...
@@ -9,7 +9,9 @@ void AveragePooling_ForwardPass(T *input_features, T *output_features,
Int
nPlanes
,
Int
input_stride
,
Int
output_stride
,
Int
*
rules
,
Int
nHot
,
Int
filterVolume
)
{
for
(
Int
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
outSite
;
#pragma omp parallel for private(outSite)
for
(
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
i
=
rules
[
2
*
outSite
]
*
input_stride
;
Int
o
=
rules
[
2
*
outSite
+
1
]
*
output_stride
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
...
...
@@ -21,7 +23,9 @@ void AveragePooling_BackwardPass(T *d_input_features, T *d_output_features,
Int
nPlanes
,
Int
input_stride
,
Int
output_stride
,
Int
*
rules
,
Int
nHot
,
Int
filterVolume
)
{
for
(
Int
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
outSite
;
#pragma omp parallel for private(outSite)
for
(
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
i
=
rules
[
2
*
outSite
]
*
input_stride
;
Int
o
=
rules
[
2
*
outSite
+
1
]
*
output_stride
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
...
...
sparseconvnet/SCN/CPU/BatchNormalization.cpp
View file @
29c810b2
...
...
@@ -19,28 +19,21 @@ void BatchNormalization_ForwardPass(T *input_features, T *output_features,
if
(
train
)
{
std
::
memset
(
saveMean
,
0
,
nPlanes
*
sizeof
(
T
));
std
::
memset
(
saveInvStd
,
0
,
nPlanes
*
sizeof
(
T
));
for
(
Int
row
=
0
,
ci
=
0
;
row
<
nActive
;
row
++
,
ci
+=
input_stride
-
nPlanes
)
{
for
(
Int
row
=
0
;
row
<
nActive
;
row
++
)
{
Int
ci
=
row
*
input_stride
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
,
ci
++
)
{
saveMean
[
plane
]
+=
input_features
[
ci
];
T
ifci
=
input_features
[
ci
];
saveMean
[
plane
]
+=
ifci
;
saveInvStd
[
plane
]
+=
ifci
*
ifci
;
// accumulate sum-squares
// before inverse square
// rooting
}
}
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
{
saveMean
[
plane
]
/=
nActive
;
runningMean
[
plane
]
=
momentum
*
runningMean
[
plane
]
+
(
1
-
momentum
)
*
saveMean
[
plane
];
}
for
(
Int
row
=
0
,
ci
=
0
;
row
<
nActive
;
row
++
,
ci
+=
input_stride
-
nPlanes
)
{
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
,
ci
++
)
{
saveInvStd
[
plane
]
+=
(
input_features
[
ci
]
-
saveMean
[
plane
])
*
(
input_features
[
ci
]
-
saveMean
[
plane
]);
// accumulate sum-squares
// before inverse square
// rooting
}
}
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
{
saveInvStd
[
plane
]
-=
saveMean
[
plane
]
*
saveMean
[
plane
]
*
nActive
;
runningVar
[
plane
]
=
momentum
*
runningVar
[
plane
]
+
(
1
-
momentum
)
*
saveInvStd
[
plane
]
/
(
nActive
-
1
);
saveInvStd
[
plane
]
=
powf
(
saveInvStd
[
plane
]
/
nActive
+
eps
,
-
0.5
);
...
...
@@ -57,12 +50,13 @@ void BatchNormalization_ForwardPass(T *input_features, T *output_features,
w
[
plane
]
=
saveInvStd
[
plane
]
*
(
weight
?
weight
[
plane
]
:
1
);
b
[
plane
]
=
-
saveMean
[
plane
]
*
w
[
plane
]
+
(
bias
?
bias
[
plane
]
:
0
);
}
for
(
Int
row
=
0
,
ci
=
0
,
co
=
0
;
row
<
nActive
;
row
++
,
ci
+=
input_stride
-
nPlanes
,
co
+=
output_stride
-
nPlanes
)
{
for
(
Int
row
=
0
;
row
<
nActive
;
row
++
)
{
Int
ci
=
row
*
input_stride
;
Int
co
=
row
*
output_stride
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
,
ci
++
,
co
++
)
{
T
out
=
input_features
[
ci
]
*
w
[
plane
]
+
b
[
plane
];
out
=
(
out
>
0
)
?
out
:
(
out
*
leakiness
)
;
output_features
[
co
]
=
out
;
const
T
r
=
(
out
>
0
)
?
1
:
leakiness
;
output_features
[
co
]
=
out
*
r
;
}
}
}
...
...
@@ -78,11 +72,13 @@ void BatchNormalization_BackwardPass(T *input_features, T *d_input_features,
std
::
vector
<
T
>
gradMean
(
nPlanes
);
std
::
vector
<
T
>
dotp
(
nPlanes
);
std
::
vector
<
T
>
k
(
nPlanes
);
for
(
Int
row
=
0
,
ci
=
0
,
co
=
0
;
row
<
nActive
;
row
++
,
ci
+=
input_stride
-
nPlanes
,
co
+=
output_stride
-
nPlanes
)
{
for
(
Int
row
=
0
;
row
<
nActive
;
row
++
)
{
Int
ci
=
row
*
input_stride
;
Int
co
=
row
*
output_stride
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
,
ci
++
,
co
++
)
{
T
d
=
d_output_features
[
co
];
d
=
(
output_features
[
co
]
>
0
)
?
d
:
(
d
*
leakiness
);
const
T
r
=
(
output_features
[
co
]
>
0
)
?
1
:
leakiness
;
d
*=
r
;
d_output_features
[
co
]
=
d
;
gradMean
[
plane
]
+=
d
;
dotp
[
plane
]
+=
(
input_features
[
ci
]
-
saveMean
[
plane
])
*
d
;
...
...
@@ -94,8 +90,9 @@ void BatchNormalization_BackwardPass(T *input_features, T *d_input_features,
gradMean
[
plane
]
/=
nActive
;
// ...now
k
[
plane
]
=
dotp
[
plane
]
*
saveInvStd
[
plane
]
*
saveInvStd
[
plane
]
/
nActive
;
}
for
(
Int
row
=
0
,
ci
=
0
,
co
=
0
;
row
<
nActive
;
row
++
,
ci
+=
input_stride
-
nPlanes
,
co
+=
output_stride
-
nPlanes
)
{
for
(
Int
row
=
0
;
row
<
nActive
;
row
++
)
{
Int
ci
=
row
*
input_stride
;
Int
co
=
row
*
output_stride
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
,
ci
++
,
co
++
)
{
d_input_features
[
ci
]
=
(
d_output_features
[
co
]
-
gradMean
[
plane
]
-
...
...
@@ -158,3 +155,4 @@ void cpu_BatchNormalization_backward(
leakiness
);
}
}
sparseconvnet/SCN/CPU/IOLayers.cpp
View file @
29c810b2
...
...
@@ -12,9 +12,11 @@ template <typename T>
void
InputLayer_ForwardPass
(
T
*
input_features
,
T
*
output_features
,
Int
nRows
,
Int
maxActive
,
Int
nPlanes
,
Int
*
rules
,
bool
average
)
{
for
(
Int
row
=
0
;
row
<
nRows
;
row
++
)
{
Int
row
;
#pragma omp parallel for private(row)
for
(
row
=
0
;
row
<
nRows
;
row
++
)
{
auto
nActive
=
rules
[
0
];
T
multiplier
=
(
average
and
nActive
>
0
)
?
1.0
f
/
nActive
:
1.0
f
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
(
T
)
1
/
nActive
:
(
T
)
1
;
for
(
Int
i
=
1
;
i
<=
nActive
;
++
i
)
{
auto
in_f
=
input_features
+
nPlanes
*
rules
[
i
];
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
{
...
...
@@ -29,9 +31,11 @@ template <typename T>
void
InputLayer_BackwardPass
(
T
*
d_input_features
,
T
*
d_output_features
,
Int
nRows
,
Int
maxActive
,
Int
nPlanes
,
Int
*
rules
,
bool
average
)
{
for
(
Int
row
=
0
;
row
<
nRows
;
row
++
)
{
Int
row
;
#pragma omp parallel for private(row)
for
(
row
=
0
;
row
<
nRows
;
row
++
)
{
auto
nActive
=
rules
[
0
];
T
multiplier
=
(
average
and
nActive
>
0
)
?
1.0
f
/
nActive
:
1.0
f
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
(
T
)
1
/
nActive
:
(
T
)
1
;
for
(
Int
i
=
1
;
i
<=
nActive
;
++
i
)
{
auto
d_in_f
=
d_input_features
+
nPlanes
*
rules
[
i
];
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
...
...
sparseconvnet/SCN/CPU/LeakyReLU.cpp
View file @
29c810b2
...
...
@@ -12,8 +12,11 @@ void cpu_LeakyReLU_updateOutput(/*float*/ at::Tensor input_features,
auto
oF
=
output_features
.
data
<
T
>
();
auto
n
=
input_features
.
numel
();
for
(
Int
i
=
0
;
i
<
n
;
i
++
)
oF
[
i
]
=
(
iF
[
i
]
>
0
)
?
iF
[
i
]
:
iF
[
i
]
*
alpha
;
for
(
Int
i
=
0
;
i
<
n
;
i
++
)
{
const
T
x
=
iF
[
i
];
const
T
r
=
(
x
>
0
)
?
1
:
alpha
;
oF
[
i
]
=
x
*
r
;
}
}
template
<
typename
T
>
void
cpu_LeakyReLU_updateGradInput
(
/*float*/
at
::
Tensor
input_features
,
...
...
@@ -26,6 +29,8 @@ void cpu_LeakyReLU_updateGradInput(/*float*/ at::Tensor input_features,
auto
doF
=
d_output_features
.
data
<
T
>
();
auto
n
=
d_input_features
.
numel
();
for
(
Int
i
=
0
;
i
<
n
;
i
++
)
diF
[
i
]
=
(
iF
[
i
]
>
0
)
?
doF
[
i
]
:
doF
[
i
]
*
alpha
;
for
(
Int
i
=
0
;
i
<
n
;
i
++
)
{
const
T
r
=
(
iF
[
i
]
>
0
)
?
1
:
alpha
;
diF
[
i
]
=
doF
[
i
]
*
r
;
}
}
sparseconvnet/SCN/CPU/MaxPooling.cpp
View file @
29c810b2
...
...
@@ -8,7 +8,9 @@ template <typename T>
void
MaxPooling_ForwardPass
(
T
*
input_features
,
T
*
output_features
,
Int
nPlanes
,
Int
input_stride
,
Int
output_stride
,
Int
*
rules
,
Int
nHot
)
{
for
(
Int
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
outSite
;
#pragma omp parallel for private(outSite)
for
(
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
i
=
rules
[
2
*
outSite
]
*
input_stride
;
Int
o
=
rules
[
2
*
outSite
+
1
]
*
output_stride
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
...
...
@@ -21,7 +23,9 @@ void MaxPooling_BackwardPass(T *input_features, T *d_input_features,
T
*
output_features
,
T
*
d_output_features
,
Int
nPlanes
,
Int
input_stride
,
Int
output_stride
,
Int
*
rules
,
Int
nHot
)
{
for
(
Int
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
outSite
;
#pragma omp parallel for private(outSite)
for
(
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
i
=
rules
[
2
*
outSite
]
*
input_stride
;
Int
o
=
rules
[
2
*
outSite
+
1
]
*
output_stride
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
...
...
sparseconvnet/SCN/CPU/SparseToDense.cpp
View file @
29c810b2
...
...
@@ -8,7 +8,9 @@ template <typename T>
void
SparseToDense_ForwardPass
(
T
*
input_features
,
T
*
output_features
,
Int
nPlanes
,
Int
spatialVolume
,
Int
*
rules
,
int
nHot
)
{
for
(
Int
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
outSite
;
#pragma omp parallel for private(outSite)
for
(
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
T
*
i
=
input_features
+
rules
[
2
*
outSite
]
*
nPlanes
;
T
*
o
=
output_features
+
rules
[
2
*
outSite
+
1
];
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
...
...
@@ -20,8 +22,9 @@ template <typename T>
void
SparseToDense_BackwardPass
(
T
*
d_input_features
,
T
*
d_output_features
,
Int
nPlanes
,
Int
spatialVolume
,
Int
*
rules
,
int
nHot
)
{
for
(
Int
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
outSite
;
#pragma omp parallel for private(outSite)
for
(
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
T
*
d_i
=
d_input_features
+
rules
[
2
*
outSite
]
*
nPlanes
;
T
*
d_o
=
d_output_features
+
rules
[
2
*
outSite
+
1
];
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
...
...
sparseconvnet/SCN/CPU/UnPooling.cpp
View file @
29c810b2
...
...
@@ -8,7 +8,9 @@ template <typename T>
void
UnPooling_ForwardPass
(
T
*
input_features
,
T
*
output_features
,
Int
nPlanes
,
Int
input_stride
,
Int
output_stride
,
Int
*
rules
,
Int
nHot
)
{
for
(
Int
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
outSite
;
#pragma omp parallel for private(outSite)
for
(
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
i
=
rules
[
2
*
outSite
+
1
]
*
input_stride
;
Int
o
=
rules
[
2
*
outSite
]
*
output_stride
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
...
...
@@ -19,7 +21,9 @@ template <typename T>
void
UnPooling_BackwardPass
(
T
*
d_input_features
,
T
*
d_output_features
,
Int
nPlanes
,
Int
input_stride
,
Int
output_stride
,
Int
*
rules
,
Int
nHot
)
{
for
(
Int
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
outSite
;
#pragma omp parallel for private(outSite)
for
(
outSite
=
0
;
outSite
<
nHot
;
outSite
++
)
{
Int
i
=
rules
[
2
*
outSite
+
1
]
*
input_stride
;
Int
o
=
rules
[
2
*
outSite
]
*
output_stride
;
for
(
Int
plane
=
0
;
plane
<
nPlanes
;
plane
++
)
...
...
sparseconvnet/SCN/CUDA/ActivePooling.cu
View file @
29c810b2
...
...
@@ -11,7 +11,7 @@ __global__ void ActivePooling_fp(T *input_features, T *output_features,
T
*
out
=
&
output_features
[
blockIdx
.
x
*
nPlanes
];
Int
*
r
=
&
rules
[
blockIdx
.
x
*
(
maxActive
+
1
)];
Int
nActive
=
*
r
++
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
1.0
f
/
nActive
:
1.0
f
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
(
T
)
1
/
nActive
:
(
T
)
1
;
while
(
nActive
--
>
0
)
{
T
*
inp
=
&
input_features
[(
*
r
++
)
*
nPlanes
];
for
(
Int
plane
=
threadIdx
.
x
;
plane
<
nPlanes
;
plane
+=
32
)
...
...
@@ -46,7 +46,7 @@ __global__ void ActivePooling_bp(T *d_input_features, T *d_output_features,
T
*
out
=
&
d_output_features
[
blockIdx
.
x
*
nPlanes
];
Int
*
r
=
&
rules
[
blockIdx
.
x
*
(
maxActive
+
1
)];
Int
nActive
=
*
r
++
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
1.0
f
/
nActive
:
1.0
f
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
(
T
)
1
/
nActive
:
(
T
)
1
;
while
(
nActive
--
>
0
)
{
T
*
inp
=
&
d_input_features
[(
*
r
++
)
*
nPlanes
];
for
(
Int
plane
=
threadIdx
.
x
;
plane
<
nPlanes
;
plane
+=
32
)
...
...
sparseconvnet/SCN/CUDA/IOLayers.cu
View file @
29c810b2
...
...
@@ -19,7 +19,7 @@ __global__ void InputLayer_fp_(T *input_features, T *output_features, Int nRows,
T
*
out
=
output_features
+
row
*
nPlanes
;
Int
*
r
=
rules
+
row
*
(
1
+
maxActive
);
Int
nActive
=
r
[
0
];
T
multiplier
=
(
average
and
nActive
>
0
)
?
1.0
f
/
nActive
:
1.0
f
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
(
T
)
1
/
nActive
:
(
T
)
1
;
for
(
int
i
=
1
;
i
<=
nActive
;
i
++
)
{
T
*
inp
=
input_features
+
r
[
i
]
*
nPlanes
;
for
(
Int
plane
=
threadIdx
.
x
;
plane
<
nPlanes
;
plane
+=
blockDim
.
x
)
...
...
@@ -48,7 +48,7 @@ __global__ void InputLayer_bp_(T *d_input_features, T *d_output_features,
T
*
out
=
d_output_features
+
row
*
nPlanes
;
Int
*
r
=
rules
+
row
*
(
1
+
maxActive
);
Int
nActive
=
r
[
0
];
T
multiplier
=
(
average
and
nActive
>
0
)
?
1.0
f
/
nActive
:
1.0
f
;
T
multiplier
=
(
average
and
nActive
>
0
)
?
(
T
)
1
/
nActive
:
(
T
)
1
;
for
(
int
i
=
1
;
i
<=
nActive
;
i
++
)
{
T
*
inp
=
d_input_features
+
r
[
i
]
*
nPlanes
;
for
(
Int
plane
=
threadIdx
.
x
;
plane
<
nPlanes
;
plane
+=
blockDim
.
x
)
...
...
sparseconvnet/__init__.py
View file @
29c810b2
...
...
@@ -32,7 +32,7 @@ from .spectral_norm import spectral_norm
from
.submanifoldConvolution
import
SubmanifoldConvolution
,
ValidConvolution
from
.tables
import
*
from
.unPooling
import
UnPooling
from
.utils
import
appendSparseConvTensors
from
.utils
import
appendSparseConvTensors
,
AddCoords
def
concatenate_feature_planes
(
input
):
output
=
SparseConvNetTensor
()
...
...
sparseconvnet/utils.py
View file @
29c810b2
...
...
@@ -59,3 +59,18 @@ def appendSparseConvTensors(tensors):
for
t
in
tensors
:
x
.
metadata
.
appendMetadata
(
t
.
metadata
,
spatial_size
)
return
x
class
AddCoords
(
torch
.
nn
.
Module
):
def
forward
(
self
,
input
):
output
=
SparseConvNetTensor
()
if
input
.
features
.
numel
():
with
torch
.
no_grad
():
coords
=
input
.
get_spatial_locations
()
d
=
(
input
.
spatial_size
.
type_as
(
input
.
features
)
-
1
)
/
2
coords
=
coords
[:,:
-
1
].
type_as
(
input
.
features
)
/
d
[
None
,:]
-
1
output
.
features
=
torch
.
cat
([
input
.
features
,
coords
],
1
)
else
:
output
.
features
=
input
.
features
output
.
metadata
=
input
.
metadata
output
.
spatial_size
=
input
.
spatial_size
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment