Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
SparseConvNet
Commits
556a8c06
Commit
556a8c06
authored
Oct 11, 2018
by
Benjamin Thomas Graham
Browse files
refactor index_select
parent
ed2a1c04
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
56 deletions
+26
-56
sparseconvnet/SCN/CPU/Convolution.cpp
sparseconvnet/SCN/CPU/Convolution.cpp
+23
-48
sparseconvnet/SCN/CPU/Deconvolution.cpp
sparseconvnet/SCN/CPU/Deconvolution.cpp
+3
-8
No files found.
sparseconvnet/SCN/CPU/Convolution.cpp
View file @
556a8c06
...
@@ -6,15 +6,16 @@
...
@@ -6,15 +6,16 @@
#include <cstring>
#include <cstring>
template
<
typename
T
>
template
<
typename
T
>
void
rule_index_select
(
at
::
Tensor
target
,
at
::
Tensor
src
,
Int
nRules
,
at
::
Tensor
rule_index_select
(
at
::
Tensor
src
,
Int
nRules
,
Int
*
rules
)
{
Int
*
rules
)
{
auto
n
=
src
.
size
(
1
);
auto
target
=
at
::
empty
({
nRules
,
n
},
src
.
type
());
auto
t_ptr
=
target
.
data
<
T
>
();
auto
t_ptr
=
target
.
data
<
T
>
();
auto
s_ptr
=
src
.
data
<
T
>
();
auto
s_ptr
=
src
.
data
<
T
>
();
auto
n
=
target
.
size
(
1
);
#pragma omp parallel for
Int
i
;
for
(
Int
i
=
0
;
i
<
nRules
;
++
i
)
#pragma omp parallel for private(i)
for
(
i
=
0
;
i
<
nRules
;
++
i
)
std
::
memcpy
(
t_ptr
+
i
*
n
,
s_ptr
+
rules
[
2
*
i
]
*
n
,
sizeof
(
T
)
*
n
);
std
::
memcpy
(
t_ptr
+
i
*
n
,
s_ptr
+
rules
[
2
*
i
]
*
n
,
sizeof
(
T
)
*
n
);
return
target
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
rule_index_add_
(
at
::
Tensor
target
,
at
::
Tensor
src
,
Int
nRules
,
void
rule_index_add_
(
at
::
Tensor
target
,
at
::
Tensor
src
,
Int
nRules
,
...
@@ -22,9 +23,8 @@ void rule_index_add_(at::Tensor target, at::Tensor src, Int nRules,
...
@@ -22,9 +23,8 @@ void rule_index_add_(at::Tensor target, at::Tensor src, Int nRules,
auto
t_ptr
=
target
.
data
<
T
>
();
auto
t_ptr
=
target
.
data
<
T
>
();
auto
s_ptr
=
src
.
data
<
T
>
();
auto
s_ptr
=
src
.
data
<
T
>
();
auto
n
=
target
.
size
(
1
);
auto
n
=
target
.
size
(
1
);
Int
i
;
#pragma omp parallel for
#pragma omp parallel for private(i)
for
(
Int
i
=
0
;
i
<
nRules
;
++
i
)
{
for
(
i
=
0
;
i
<
nRules
;
++
i
)
{
auto
t
=
t_ptr
+
rules
[
2
*
i
]
*
n
;
auto
t
=
t_ptr
+
rules
[
2
*
i
]
*
n
;
auto
s
=
s_ptr
+
i
*
n
;
auto
s
=
s_ptr
+
i
*
n
;
for
(
int
j
=
0
;
j
<
n
;
++
j
)
for
(
int
j
=
0
;
j
<
n
;
++
j
)
...
@@ -62,8 +62,7 @@ double cpu_Convolution_updateOutput(
...
@@ -62,8 +62,7 @@ double cpu_Convolution_updateOutput(
// auto w = weight.select(0, i);
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
input_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
0
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
0
]);
auto
w
=
weight
.
select
(
0
,
i
);
auto
w
=
weight
.
select
(
0
,
i
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
1
]);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
1
]);
...
@@ -90,8 +89,6 @@ void cpu_Convolution_backward(
...
@@ -90,8 +89,6 @@ void cpu_Convolution_backward(
if
(
nActive
and
d_bias
.
numel
())
if
(
nActive
and
d_bias
.
numel
())
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
auto
ip
=
weight
.
size
(
1
);
auto
op
=
weight
.
size
(
2
);
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
auto
r
=
_rules
[
i
];
auto
r
=
_rules
[
i
];
int
nRules
=
r
.
size
()
/
2
;
int
nRules
=
r
.
size
()
/
2
;
...
@@ -105,10 +102,8 @@ void cpu_Convolution_backward(
...
@@ -105,10 +102,8 @@ void cpu_Convolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
input_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
0
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
0
]);
auto
d_output_rows
=
rule_index_select
<
T
>
(
d_output_features
,
nRules
,
&
r
[
1
]);
auto
d_output_rows
=
at
::
empty
({
nRules
,
op
},
d_output_features
.
type
());
rule_index_select
<
T
>
(
d_output_rows
,
d_output_features
,
nRules
,
&
r
[
1
]);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
0
]);
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
0
]);
...
@@ -144,8 +139,7 @@ double cpu_SubmanifoldConvolution_updateOutput(
...
@@ -144,8 +139,7 @@ double cpu_SubmanifoldConvolution_updateOutput(
// auto w = weight.select(0, i);
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
input_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
0
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
0
]);
auto
w
=
weight
.
select
(
0
,
i
);
auto
w
=
weight
.
select
(
0
,
i
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
1
]);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
1
]);
...
@@ -171,8 +165,6 @@ void cpu_SubmanifoldConvolution_backward(
...
@@ -171,8 +165,6 @@ void cpu_SubmanifoldConvolution_backward(
if
(
nActive
and
d_bias
.
numel
())
if
(
nActive
and
d_bias
.
numel
())
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
auto
ip
=
weight
.
size
(
1
);
auto
op
=
weight
.
size
(
2
);
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
auto
r
=
_rules
[
i
];
auto
r
=
_rules
[
i
];
int
nRules
=
r
.
size
()
/
2
;
int
nRules
=
r
.
size
()
/
2
;
...
@@ -186,10 +178,8 @@ void cpu_SubmanifoldConvolution_backward(
...
@@ -186,10 +178,8 @@ void cpu_SubmanifoldConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
input_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
0
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
0
]);
auto
d_output_rows
=
rule_index_select
<
T
>
(
d_output_features
,
nRules
,
&
r
[
1
]);
auto
d_output_rows
=
at
::
empty
({
nRules
,
op
},
d_output_features
.
type
());
rule_index_select
<
T
>
(
d_output_rows
,
d_output_features
,
nRules
,
&
r
[
1
]);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
0
]);
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
0
]);
...
@@ -224,8 +214,7 @@ double cpu_PermutohedralSubmanifoldConvolution_updateOutput(
...
@@ -224,8 +214,7 @@ double cpu_PermutohedralSubmanifoldConvolution_updateOutput(
// auto w = weight.select(0, i);
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
input_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
0
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
0
]);
auto
w
=
weight
.
select
(
0
,
i
);
auto
w
=
weight
.
select
(
0
,
i
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
1
]);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
1
]);
...
@@ -250,8 +239,6 @@ void cpu_PermutohedralSubmanifoldConvolution_backward(
...
@@ -250,8 +239,6 @@ void cpu_PermutohedralSubmanifoldConvolution_backward(
if
(
nActive
and
d_bias
.
numel
())
if
(
nActive
and
d_bias
.
numel
())
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
auto
ip
=
weight
.
size
(
1
);
auto
op
=
weight
.
size
(
2
);
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
auto
r
=
_rules
[
i
];
auto
r
=
_rules
[
i
];
int
nRules
=
r
.
size
()
/
2
;
int
nRules
=
r
.
size
()
/
2
;
...
@@ -265,10 +252,8 @@ void cpu_PermutohedralSubmanifoldConvolution_backward(
...
@@ -265,10 +252,8 @@ void cpu_PermutohedralSubmanifoldConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
input_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
0
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
0
]);
auto
d_output_rows
=
rule_index_select
<
T
>
(
d_output_features
,
nRules
,
&
r
[
1
]);
auto
d_output_rows
=
at
::
empty
({
nRules
,
op
},
d_output_features
.
type
());
rule_index_select
<
T
>
(
d_output_rows
,
d_output_features
,
nRules
,
&
r
[
1
]);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
0
]);
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
0
]);
...
@@ -307,8 +292,7 @@ double cpu_FullConvolution_updateOutput(
...
@@ -307,8 +292,7 @@ double cpu_FullConvolution_updateOutput(
// auto w = weight.select(0, i);
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
input_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
0
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
0
]);
auto
w
=
weight
.
select
(
0
,
i
);
auto
w
=
weight
.
select
(
0
,
i
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
1
]);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
1
]);
...
@@ -337,8 +321,6 @@ void cpu_FullConvolution_backward(
...
@@ -337,8 +321,6 @@ void cpu_FullConvolution_backward(
if
(
nActive
and
d_bias
.
numel
())
if
(
nActive
and
d_bias
.
numel
())
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
auto
ip
=
weight
.
size
(
1
);
auto
op
=
weight
.
size
(
2
);
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
auto
r
=
_rules
[
i
];
auto
r
=
_rules
[
i
];
int
nRules
=
r
.
size
()
/
2
;
int
nRules
=
r
.
size
()
/
2
;
...
@@ -352,10 +334,8 @@ void cpu_FullConvolution_backward(
...
@@ -352,10 +334,8 @@ void cpu_FullConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
input_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
0
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
0
]);
auto
d_output_rows
=
rule_index_select
<
T
>
(
d_output_features
,
nRules
,
&
r
[
1
]);
auto
d_output_rows
=
at
::
empty
({
nRules
,
op
},
d_output_features
.
type
());
rule_index_select
<
T
>
(
d_output_rows
,
d_output_features
,
nRules
,
&
r
[
1
]);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
0
]);
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
0
]);
...
@@ -393,8 +373,7 @@ double cpu_RandomizedStrideConvolution_updateOutput(
...
@@ -393,8 +373,7 @@ double cpu_RandomizedStrideConvolution_updateOutput(
// auto w = weight.select(0, i);
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
// output_features.index_add_(0, rt.select(1, 1), output_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
input_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
0
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
0
]);
auto
w
=
weight
.
select
(
0
,
i
);
auto
w
=
weight
.
select
(
0
,
i
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
1
]);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
1
]);
...
@@ -421,8 +400,6 @@ void cpu_RandomizedStrideConvolution_backward(
...
@@ -421,8 +400,6 @@ void cpu_RandomizedStrideConvolution_backward(
if
(
nActive
and
d_bias
.
numel
())
if
(
nActive
and
d_bias
.
numel
())
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
auto
ip
=
weight
.
size
(
1
);
auto
op
=
weight
.
size
(
2
);
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
auto
r
=
_rules
[
i
];
auto
r
=
_rules
[
i
];
int
nRules
=
r
.
size
()
/
2
;
int
nRules
=
r
.
size
()
/
2
;
...
@@ -436,10 +413,8 @@ void cpu_RandomizedStrideConvolution_backward(
...
@@ -436,10 +413,8 @@ void cpu_RandomizedStrideConvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
// d_input_features.index_add_(0, rt.select(1, 0), d_input_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
input_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
0
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
0
]);
auto
d_output_rows
=
rule_index_select
<
T
>
(
d_output_features
,
nRules
,
&
r
[
1
]);
auto
d_output_rows
=
at
::
empty
({
nRules
,
op
},
d_output_features
.
type
());
rule_index_select
<
T
>
(
d_output_rows
,
d_output_features
,
nRules
,
&
r
[
1
]);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
0
]);
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
0
]);
...
...
sparseconvnet/SCN/CPU/Deconvolution.cpp
View file @
556a8c06
...
@@ -34,8 +34,7 @@ double cpu_Deconvolution_updateOutput(
...
@@ -34,8 +34,7 @@ double cpu_Deconvolution_updateOutput(
// auto w = weight.select(0, i);
// auto w = weight.select(0, i);
// auto output_rows = at::mm(input_rows, w);
// auto output_rows = at::mm(input_rows, w);
// output_features.index_add_(0, rt.select(1, 0), output_rows);
// output_features.index_add_(0, rt.select(1, 0), output_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
input_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
1
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
1
]);
auto
w
=
weight
.
select
(
0
,
i
);
auto
w
=
weight
.
select
(
0
,
i
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
auto
output_rows
=
at
::
mm
(
input_rows
,
w
);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
0
]);
rule_index_add_
<
T
>
(
output_features
,
output_rows
,
nRules
,
&
r
[
0
]);
...
@@ -62,8 +61,6 @@ void cpu_Deconvolution_backward(
...
@@ -62,8 +61,6 @@ void cpu_Deconvolution_backward(
if
(
nActive
and
d_bias
.
numel
())
if
(
nActive
and
d_bias
.
numel
())
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
at
::
sum_out
(
d_bias
,
d_output_features
,
{
0
},
false
);
auto
ip
=
weight
.
size
(
1
);
auto
op
=
weight
.
size
(
2
);
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
for
(
Int
i
=
0
;
i
<
(
Int
)
_rules
.
size
();
i
++
)
{
auto
r
=
_rules
[
i
];
auto
r
=
_rules
[
i
];
int
nRules
=
r
.
size
()
/
2
;
int
nRules
=
r
.
size
()
/
2
;
...
@@ -77,10 +74,8 @@ void cpu_Deconvolution_backward(
...
@@ -77,10 +74,8 @@ void cpu_Deconvolution_backward(
// at::mm_out(dw, input_rows.t(), d_output_rows);
// at::mm_out(dw, input_rows.t(), d_output_rows);
// auto d_input_rows = at::mm(d_output_rows, w.t());
// auto d_input_rows = at::mm(d_output_rows, w.t());
// d_input_features.index_add_(0, rt.select(1, 1), d_input_rows);
// d_input_features.index_add_(0, rt.select(1, 1), d_input_rows);
auto
input_rows
=
at
::
empty
({
nRules
,
ip
},
d_output_features
.
type
());
auto
input_rows
=
rule_index_select
<
T
>
(
input_features
,
nRules
,
&
r
[
1
]);
rule_index_select
<
T
>
(
input_rows
,
input_features
,
nRules
,
&
r
[
1
]);
auto
d_output_rows
=
rule_index_select
<
T
>
(
d_output_features
,
nRules
,
&
r
[
0
]);
auto
d_output_rows
=
at
::
empty
({
nRules
,
op
},
d_output_features
.
type
());
rule_index_select
<
T
>
(
d_output_rows
,
d_output_features
,
nRules
,
&
r
[
0
]);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
at
::
mm_out
(
dw
,
input_rows
.
t
(),
d_output_rows
);
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
auto
d_input_rows
=
at
::
mm
(
d_output_rows
,
w
.
t
());
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
1
]);
rule_index_add_
<
T
>
(
d_input_features
,
d_input_rows
,
nRules
,
&
r
[
1
]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment