Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
e55e3bea
Commit
e55e3bea
authored
Oct 18, 2018
by
Benjamin Thomas Graham
Browse files
ATen API change
parent
55d55a6a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
31 additions
and
31 deletions
+31
-31
sparseconvnet/SCN/CUDA/ActivePooling.cu
sparseconvnet/SCN/CUDA/ActivePooling.cu
+18
-18
sparseconvnet/SCN/CUDA/IOLayers.cpp
sparseconvnet/SCN/CUDA/IOLayers.cpp
+8
-8
sparseconvnet/SCN/CUDA/RuleBookIterator.h
sparseconvnet/SCN/CUDA/RuleBookIterator.h
+1
-1
sparseconvnet/SCN/Metadata/Metadata.cpp
sparseconvnet/SCN/Metadata/Metadata.cpp
+4
-4
No files found.
sparseconvnet/SCN/CUDA/ActivePooling.cu
View file @
e55e3bea
...
@@ -6,8 +6,8 @@
...
@@ -6,8 +6,8 @@
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ActivePooling_fp
(
T
*
input_features
,
T
*
output_features
,
__global__
void
ActivePooling_fp
(
T
*
input_features
,
T
*
output_features
,
Int
maxActive
,
Int
nPlanes
,
Int
*
rules
,
Int
maxActive
,
Int
nPlanes
,
Int
*
rules
,
bool
average
)
{
bool
average
)
{
T
*
out
=
&
output_features
[
blockIdx
.
x
*
nPlanes
];
T
*
out
=
&
output_features
[
blockIdx
.
x
*
nPlanes
];
Int
*
r
=
&
rules
[
blockIdx
.
x
*
(
maxActive
+
1
)];
Int
*
r
=
&
rules
[
blockIdx
.
x
*
(
maxActive
+
1
)];
Int
nActive
=
*
r
++
;
Int
nActive
=
*
r
++
;
...
@@ -20,10 +20,10 @@ __global__ void ActivePooling_fp(T *input_features, T *output_features,
...
@@ -20,10 +20,10 @@ __global__ void ActivePooling_fp(T *input_features, T *output_features,
}
}
template
<
typename
T
>
template
<
typename
T
>
void
ActivePooling_ForwardPass
(
T
*
input_features
,
T
*
output_features
,
void
ActivePooling_ForwardPass
(
T
*
input_features
,
T
*
output_features
,
Int
batchSize
,
Int
maxActive
,
Int
nPlanes
,
Int
batchSize
,
Int
maxActive
,
Int
nPlanes
,
Int
*
rules
,
bool
average
)
{
Int
*
rules
,
bool
average
)
{
auto
rulesBuffer
=
at
::
CUDA
(
at_kINT
).
tensor
({
1
<<
22
}
);
auto
rulesBuffer
=
at
::
empty
({
1
<<
22
},
at
::
CUDA
(
at_kINT
)
);
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
rowBatchSize
=
std
::
min
((
Int
)
32768
,
(
1
<<
22
)
/
(
maxActive
+
1
));
Int
rowBatchSize
=
std
::
min
((
Int
)
32768
,
(
1
<<
22
)
/
(
maxActive
+
1
));
assert
(
rowBatchSize
>
0
);
assert
(
rowBatchSize
>
0
);
...
@@ -32,17 +32,17 @@ void ActivePooling_ForwardPass(T *input_features, T *output_features,
...
@@ -32,17 +32,17 @@ void ActivePooling_ForwardPass(T *input_features, T *output_features,
for
(
Int
o
=
0
;
o
<
batchSize
;
o
+=
rowBatchSize
)
{
for
(
Int
o
=
0
;
o
<
batchSize
;
o
+=
rowBatchSize
)
{
Int
batchSize_
=
std
::
min
(
rowBatchSize
,
(
Int
(
batchSize
-
o
)));
Int
batchSize_
=
std
::
min
(
rowBatchSize
,
(
Int
(
batchSize
-
o
)));
cudaMemcpy
(
rb
,
rules
+
o
*
(
maxActive
+
1
),
cudaMemcpy
(
rb
,
rules
+
o
*
(
maxActive
+
1
),
sizeof
(
Int
)
*
(
maxActive
+
1
)
*
batchSize_
,
sizeof
(
Int
)
*
(
maxActive
+
1
)
*
batchSize_
,
cudaMemcpyHostToDevice
);
cudaMemcpyHostToDevice
);
ActivePooling_fp
<
T
><<<
batchSize_
,
kernelBlockDim
>>>
(
ActivePooling_fp
<
T
><<<
batchSize_
,
kernelBlockDim
>>>
(
input_features
,
output_features
+
0
*
nPlanes
,
maxActive
,
nPlanes
,
input_features
,
output_features
+
0
*
nPlanes
,
maxActive
,
nPlanes
,
rules
,
average
);
rules
,
average
);
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ActivePooling_bp
(
T
*
d_input_features
,
T
*
d_output_features
,
__global__
void
ActivePooling_bp
(
T
*
d_input_features
,
T
*
d_output_features
,
Int
maxActive
,
Int
nPlanes
,
Int
*
rules
,
Int
maxActive
,
Int
nPlanes
,
Int
*
rules
,
bool
average
)
{
bool
average
)
{
T
*
out
=
&
d_output_features
[
blockIdx
.
x
*
nPlanes
];
T
*
out
=
&
d_output_features
[
blockIdx
.
x
*
nPlanes
];
Int
*
r
=
&
rules
[
blockIdx
.
x
*
(
maxActive
+
1
)];
Int
*
r
=
&
rules
[
blockIdx
.
x
*
(
maxActive
+
1
)];
Int
nActive
=
*
r
++
;
Int
nActive
=
*
r
++
;
...
@@ -56,9 +56,9 @@ __global__ void ActivePooling_bp(T *d_input_features, T *d_output_features,
...
@@ -56,9 +56,9 @@ __global__ void ActivePooling_bp(T *d_input_features, T *d_output_features,
template
<
typename
T
>
template
<
typename
T
>
void
ActivePooling_BackwardPass
(
T
*
d_input_features
,
T
*
d_output_features
,
void
ActivePooling_BackwardPass
(
T
*
d_input_features
,
T
*
d_output_features
,
Int
batchSize
,
Int
maxActive
,
Int
nPlanes
,
Int
batchSize
,
Int
maxActive
,
Int
nPlanes
,
Int
*
rules
,
bool
average
)
{
Int
*
rules
,
bool
average
)
{
auto
rulesBuffer
=
at
::
CUDA
(
at_kINT
).
tensor
({
1
<<
22
}
);
auto
rulesBuffer
=
at
::
empty
({
1
<<
22
},
at
::
CUDA
(
at_kINT
)
);
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
rowBatchSize
=
std
::
min
((
Int
)
32768
,
(
1
<<
22
)
/
(
maxActive
+
1
));
Int
rowBatchSize
=
std
::
min
((
Int
)
32768
,
(
1
<<
22
)
/
(
maxActive
+
1
));
assert
(
rowBatchSize
>
0
);
assert
(
rowBatchSize
>
0
);
...
@@ -67,10 +67,10 @@ void ActivePooling_BackwardPass(T *d_input_features, T *d_output_features,
...
@@ -67,10 +67,10 @@ void ActivePooling_BackwardPass(T *d_input_features, T *d_output_features,
for
(
Int
o
=
0
;
o
<
batchSize
;
o
+=
rowBatchSize
)
{
for
(
Int
o
=
0
;
o
<
batchSize
;
o
+=
rowBatchSize
)
{
Int
batchSize_
=
std
::
min
(
rowBatchSize
,
(
Int
(
batchSize
-
o
)));
Int
batchSize_
=
std
::
min
(
rowBatchSize
,
(
Int
(
batchSize
-
o
)));
cudaMemcpy
(
rb
,
rules
+
o
*
(
maxActive
+
1
),
cudaMemcpy
(
rb
,
rules
+
o
*
(
maxActive
+
1
),
sizeof
(
Int
)
*
(
maxActive
+
1
)
*
batchSize_
,
sizeof
(
Int
)
*
(
maxActive
+
1
)
*
batchSize_
,
cudaMemcpyHostToDevice
);
cudaMemcpyHostToDevice
);
ActivePooling_bp
<
T
><<<
batchSize_
,
kernelBlockDim
>>>
(
ActivePooling_bp
<
T
><<<
batchSize_
,
kernelBlockDim
>>>
(
d_input_features
,
d_output_features
+
o
*
nPlanes
,
maxActive
,
nPlanes
,
d_input_features
,
d_output_features
+
o
*
nPlanes
,
maxActive
,
nPlanes
,
rules
,
average
);
rules
,
average
);
}
}
}
}
sparseconvnet/SCN/CUDA/IOLayers.cpp
View file @
e55e3bea
...
@@ -33,7 +33,7 @@ void cuda_InputLayer_updateOutput(Metadata<Dimension> &m,
...
@@ -33,7 +33,7 @@ void cuda_InputLayer_updateOutput(Metadata<Dimension> &m,
}
else
{
}
else
{
output_features
.
resize_
({
*
m
.
inputNActive
,
nPlanes
});
output_features
.
resize_
({
*
m
.
inputNActive
,
nPlanes
});
output_features
.
zero_
();
output_features
.
zero_
();
auto
rulesBuffer
=
at
::
CUDA
(
at_kINT
).
tensor
({(
int
)
rules
[
1
].
size
()});
auto
rulesBuffer
=
at
::
empty
({(
int
)
rules
[
1
].
size
()}
,
at
::
CUDA
(
at_kINT
)
);
auto
iF
=
input_features
.
data
<
T
>
();
auto
iF
=
input_features
.
data
<
T
>
();
auto
oF
=
output_features
.
data
<
T
>
();
auto
oF
=
output_features
.
data
<
T
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
...
@@ -58,7 +58,7 @@ void cuda_InputLayer_updateGradInput(
...
@@ -58,7 +58,7 @@ void cuda_InputLayer_updateGradInput(
}
else
{
}
else
{
d_input_features
.
resize_
({
rules
[
0
][
2
],
nPlanes
});
d_input_features
.
resize_
({
rules
[
0
][
2
],
nPlanes
});
d_input_features
.
zero_
();
d_input_features
.
zero_
();
auto
rulesBuffer
=
at
::
CUDA
(
at_kINT
).
tensor
({(
int
)
rules
[
1
].
size
()});
auto
rulesBuffer
=
at
::
empty
({(
int
)
rules
[
1
].
size
()}
,
at
::
CUDA
(
at_kINT
)
);
auto
diF
=
d_input_features
.
data
<
T
>
();
auto
diF
=
d_input_features
.
data
<
T
>
();
auto
doF
=
d_output_features
.
data
<
T
>
();
auto
doF
=
d_output_features
.
data
<
T
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
...
@@ -83,7 +83,7 @@ void cuda_OutputLayer_updateOutput(Metadata<Dimension> &m,
...
@@ -83,7 +83,7 @@ void cuda_OutputLayer_updateOutput(Metadata<Dimension> &m,
}
else
{
}
else
{
output_features
.
resize_
({
rules
[
0
][
2
],
nPlanes
});
output_features
.
resize_
({
rules
[
0
][
2
],
nPlanes
});
output_features
.
zero_
();
output_features
.
zero_
();
auto
rulesBuffer
=
at
::
CUDA
(
at_kINT
).
tensor
({(
int
)
rules
[
1
].
size
()});
auto
rulesBuffer
=
at
::
empty
({(
int
)
rules
[
1
].
size
()}
,
at
::
CUDA
(
at_kINT
)
);
auto
iF
=
input_features
.
data
<
T
>
();
auto
iF
=
input_features
.
data
<
T
>
();
auto
oF
=
output_features
.
data
<
T
>
();
auto
oF
=
output_features
.
data
<
T
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
...
@@ -107,7 +107,7 @@ void cuda_OutputLayer_updateGradInput(
...
@@ -107,7 +107,7 @@ void cuda_OutputLayer_updateGradInput(
}
else
{
}
else
{
d_input_features
.
resize_
({
nRows
,
nPlanes
});
d_input_features
.
resize_
({
nRows
,
nPlanes
});
d_input_features
.
zero_
();
d_input_features
.
zero_
();
auto
rulesBuffer
=
at
::
CUDA
(
at_kINT
).
tensor
({(
int
)
rules
[
1
].
size
()});
auto
rulesBuffer
=
at
::
empty
({(
int
)
rules
[
1
].
size
()}
,
at
::
CUDA
(
at_kINT
)
);
auto
diF
=
d_input_features
.
data
<
T
>
();
auto
diF
=
d_input_features
.
data
<
T
>
();
auto
doF
=
d_output_features
.
data
<
T
>
();
auto
doF
=
d_output_features
.
data
<
T
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
...
@@ -137,7 +137,7 @@ void cuda_BLInputLayer_updateOutput(Metadata<Dimension> &m,
...
@@ -137,7 +137,7 @@ void cuda_BLInputLayer_updateOutput(Metadata<Dimension> &m,
output_features
.
copy_
(
input_features
);
output_features
.
copy_
(
input_features
);
output_features
.
resize_
({
*
m
.
inputNActive
,
nPlanes
});
output_features
.
resize_
({
*
m
.
inputNActive
,
nPlanes
});
}
else
{
}
else
{
auto
rulesBuffer
=
at
::
CUDA
(
at_kINT
).
tensor
({(
int
)
rules
[
1
].
size
()});
auto
rulesBuffer
=
at
::
empty
({(
int
)
rules
[
1
].
size
()}
,
at
::
CUDA
(
at_kINT
)
);
auto
iF
=
input_features
.
data
<
T
>
();
auto
iF
=
input_features
.
data
<
T
>
();
auto
oF
=
output_features
.
data
<
T
>
();
auto
oF
=
output_features
.
data
<
T
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
...
@@ -164,7 +164,7 @@ void cuda_BLInputLayer_updateGradInput(
...
@@ -164,7 +164,7 @@ void cuda_BLInputLayer_updateGradInput(
}
else
{
}
else
{
d_input_features
.
resize_
({
rules
[
0
][
2
],
rules
[
0
][
3
],
nPlanes
});
d_input_features
.
resize_
({
rules
[
0
][
2
],
rules
[
0
][
3
],
nPlanes
});
d_input_features
.
zero_
();
d_input_features
.
zero_
();
auto
rulesBuffer
=
at
::
CUDA
(
at_kINT
).
tensor
({(
int
)
rules
[
1
].
size
()});
auto
rulesBuffer
=
at
::
empty
({(
int
)
rules
[
1
].
size
()}
,
at
::
CUDA
(
at_kINT
)
);
auto
diF
=
d_input_features
.
data
<
T
>
();
auto
diF
=
d_input_features
.
data
<
T
>
();
auto
doF
=
d_output_features
.
data
<
T
>
();
auto
doF
=
d_output_features
.
data
<
T
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
...
@@ -191,7 +191,7 @@ void cuda_BLOutputLayer_updateOutput(
...
@@ -191,7 +191,7 @@ void cuda_BLOutputLayer_updateOutput(
}
else
{
}
else
{
output_features
.
resize_
({
rules
[
0
][
2
],
rules
[
0
][
3
],
nPlanes
});
output_features
.
resize_
({
rules
[
0
][
2
],
rules
[
0
][
3
],
nPlanes
});
output_features
.
zero_
();
output_features
.
zero_
();
auto
rulesBuffer
=
at
::
CUDA
(
at_kINT
).
tensor
({(
int
)
rules
[
1
].
size
()});
auto
rulesBuffer
=
at
::
empty
({(
int
)
rules
[
1
].
size
()}
,
at
::
CUDA
(
at_kINT
)
);
auto
iF
=
input_features
.
data
<
T
>
();
auto
iF
=
input_features
.
data
<
T
>
();
auto
oF
=
output_features
.
data
<
T
>
();
auto
oF
=
output_features
.
data
<
T
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
...
@@ -216,7 +216,7 @@ void cuda_BLOutputLayer_updateGradInput(
...
@@ -216,7 +216,7 @@ void cuda_BLOutputLayer_updateGradInput(
}
else
{
}
else
{
d_input_features
.
resize_
({
nRows
,
nPlanes
});
d_input_features
.
resize_
({
nRows
,
nPlanes
});
d_input_features
.
zero_
();
d_input_features
.
zero_
();
auto
rulesBuffer
=
at
::
CUDA
(
at_kINT
).
tensor
({(
int
)
rules
[
1
].
size
()});
auto
rulesBuffer
=
at
::
empty
({(
int
)
rules
[
1
].
size
()}
,
at
::
CUDA
(
at_kINT
)
);
auto
diF
=
d_input_features
.
data
<
T
>
();
auto
diF
=
d_input_features
.
data
<
T
>
();
auto
doF
=
d_output_features
.
data
<
T
>
();
auto
doF
=
d_output_features
.
data
<
T
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
Int
*
rb
=
rulesBuffer
.
data
<
Int
>
();
...
...
sparseconvnet/SCN/CUDA/RuleBookIterator.h
View file @
e55e3bea
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
Int rbMaxSize = 0; \
Int rbMaxSize = 0; \
for (auto &r : _rules) \
for (auto &r : _rules) \
rbMaxSize = std::max(rbMaxSize, (Int)r.size()); \
rbMaxSize = std::max(rbMaxSize, (Int)r.size()); \
at::Tensor rulesBuffer = at::
CUDA(at_kINT).tensor({rbMaxSize});
\
at::Tensor rulesBuffer = at::
empty({rbMaxSize}, at::CUDA(at_kINT));
\
Int *rbB = rulesBuffer.data<Int>(); \
Int *rbB = rulesBuffer.data<Int>(); \
for (int k = 0; k < _rules.size(); ++k) { \
for (int k = 0; k < _rules.size(); ++k) { \
auto &r = _rules[k]; \
auto &r = _rules[k]; \
...
...
sparseconvnet/SCN/Metadata/Metadata.cpp
View file @
e55e3bea
...
@@ -588,13 +588,13 @@ Metadata<dimension>::compareSparseHelper(Metadata<dimension> &mR,
...
@@ -588,13 +588,13 @@ Metadata<dimension>::compareSparseHelper(Metadata<dimension> &mR,
}
}
}
}
}
}
at
::
Tensor
cL_
=
torch
::
CPU
(
at
::
kLong
).
tensor
({(
long
)
cL
.
size
()}
);
at
::
Tensor
cL_
=
at
::
empty
({(
long
)
cL
.
size
()},
at
::
CPU
(
at
::
kLong
)
);
std
::
memcpy
(
cL_
.
data
<
long
>
(),
&
cL
[
0
],
cL
.
size
()
*
sizeof
(
long
));
std
::
memcpy
(
cL_
.
data
<
long
>
(),
&
cL
[
0
],
cL
.
size
()
*
sizeof
(
long
));
at
::
Tensor
cR_
=
torch
::
CPU
(
at
::
kLong
).
tensor
({(
long
)
cR
.
size
()}
);
at
::
Tensor
cR_
=
at
::
empty
({(
long
)
cR
.
size
()},
at
::
CPU
(
at
::
kLong
)
);
std
::
memcpy
(
cR_
.
data
<
long
>
(),
&
cR
[
0
],
cR
.
size
()
*
sizeof
(
long
));
std
::
memcpy
(
cR_
.
data
<
long
>
(),
&
cR
[
0
],
cR
.
size
()
*
sizeof
(
long
));
at
::
Tensor
L_
=
torch
::
CPU
(
at
::
kLong
).
tensor
({(
long
)
L
.
size
()}
);
at
::
Tensor
L_
=
at
::
empty
({(
long
)
L
.
size
()},
at
::
CPU
(
at
::
kLong
)
);
std
::
memcpy
(
L_
.
data
<
long
>
(),
&
L
[
0
],
L
.
size
()
*
sizeof
(
long
));
std
::
memcpy
(
L_
.
data
<
long
>
(),
&
L
[
0
],
L
.
size
()
*
sizeof
(
long
));
at
::
Tensor
R_
=
torch
::
CPU
(
at
::
kLong
).
tensor
({(
long
)
R
.
size
()}
);
at
::
Tensor
R_
=
at
::
empty
({(
long
)
R
.
size
()},
at
::
CPU
(
at
::
kLong
)
);
std
::
memcpy
(
R_
.
data
<
long
>
(),
&
R
[
0
],
R
.
size
()
*
sizeof
(
long
));
std
::
memcpy
(
R_
.
data
<
long
>
(),
&
R
[
0
],
R
.
size
()
*
sizeof
(
long
));
return
{
cL_
,
cR_
,
L_
,
R_
};
return
{
cL_
,
cR_
,
L_
,
R_
};
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment