Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
de11bfdf
Commit
de11bfdf
authored
Mar 11, 2018
by
rusty1s
Browse files
swap m_in/m_out loop
parent
1de99c93
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
8 deletions
+9
-8
torch_spline_conv/functions/spline_conv.py
torch_spline_conv/functions/spline_conv.py
+0
-1
torch_spline_conv/functions/utils.py
torch_spline_conv/functions/utils.py
+1
-1
torch_spline_conv/src/generic/cpu.c
torch_spline_conv/src/generic/cpu.c
+8
-6
No files found.
torch_spline_conv/functions/spline_conv.py
View file @
de11bfdf
...
@@ -17,7 +17,6 @@ def spline_conv(x,
...
@@ -17,7 +17,6 @@ def spline_conv(x,
print
(
'TODO: Degree of 0'
)
print
(
'TODO: Degree of 0'
)
print
(
'TODO: Kernel size of 1'
)
print
(
'TODO: Kernel size of 1'
)
print
(
'swap M_in and M_out in backward implementation'
)
n
,
e
=
x
.
size
(
0
),
edge_index
.
size
(
1
)
n
,
e
=
x
.
size
(
0
),
edge_index
.
size
(
1
)
K
,
m_in
,
m_out
=
weight
.
size
()
K
,
m_in
,
m_out
=
weight
.
size
()
...
...
torch_spline_conv/functions/utils.py
View file @
de11bfdf
...
@@ -36,7 +36,7 @@ def spline_weighting_forward(x, weight, basis, weight_index):
...
@@ -36,7 +36,7 @@ def spline_weighting_forward(x, weight, basis, weight_index):
def
spline_weighting_backward
(
grad_output
,
x
,
weight
,
basis
,
weight_index
):
def
spline_weighting_backward
(
grad_output
,
x
,
weight
,
basis
,
weight_index
):
grad_input
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
1
))
.
fill_
(
0
)
grad_input
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
1
))
grad_weight
=
x
.
new
(
weight
.
size
()).
fill_
(
0
)
grad_weight
=
x
.
new
(
weight
.
size
()).
fill_
(
0
)
func
=
get_func
(
'weighting_backward'
,
x
)
func
=
get_func
(
'weighting_backward'
,
x
)
func
(
grad_input
,
grad_weight
,
grad_output
,
x
,
weight
,
basis
,
weight_index
)
func
(
grad_input
,
grad_weight
,
grad_output
,
x
,
weight
,
basis
,
weight_index
)
...
...
torch_spline_conv/src/generic/cpu.c
View file @
de11bfdf
...
@@ -78,20 +78,22 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH
...
@@ -78,20 +78,22 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH
int64_t
M_out
=
THTensor_
(
size
)(
grad_output
,
1
);
int64_t
M_out
=
THTensor_
(
size
)(
grad_output
,
1
);
int64_t
M_in
=
THTensor_
(
size
)(
input
,
1
);
int64_t
M_in
=
THTensor_
(
size
)(
input
,
1
);
int64_t
S
=
THLongTensor_size
(
weight_index
,
1
);
int64_t
S
=
THLongTensor_size
(
weight_index
,
1
);
int64_t
m_out
,
m_in
,
s
,
i
,
w_idx
;
real
g
,
b
;
int64_t
m_out
,
m_in
,
s
,
i
,
w_idx
;
real
g
_in
,
value
,
b
,
g_out
;
TH_TENSOR_DIM_APPLY5
(
real
,
grad_input
,
real
,
grad_output
,
real
,
input
,
real
,
basis
,
int64_t
,
weight_index
,
1
,
TH_TENSOR_DIM_APPLY5
(
real
,
grad_input
,
real
,
grad_output
,
real
,
input
,
real
,
basis
,
int64_t
,
weight_index
,
1
,
for
(
m_
out
=
0
;
m_
out
<
M_
out
;
m_
out
++
)
{
for
(
m_
in
=
0
;
m_
in
<
M_
in
;
m_
in
++
)
{
g
=
*
(
grad_output_data
+
m_out
*
grad_out
put_stride
);
g
_in
=
0
;
value
=
*
(
input_data
+
m_in
*
in
put_stride
);
for
(
s
=
0
;
s
<
S
;
s
++
)
{
for
(
s
=
0
;
s
<
S
;
s
++
)
{
b
=
*
(
basis_data
+
s
*
basis_stride
);
b
=
*
(
basis_data
+
s
*
basis_stride
);
i
=
*
(
weight_index_data
+
s
*
weight_index_stride
);
i
=
*
(
weight_index_data
+
s
*
weight_index_stride
);
for
(
m_
in
=
0
;
m_
in
<
M_
in
;
m_
in
++
)
{
for
(
m_
out
=
0
;
m_
out
<
M_
out
;
m_
out
++
)
{
w_idx
=
i
*
M_in
*
M_out
+
m_in
*
M_out
+
m_out
;
w_idx
=
i
*
M_in
*
M_out
+
m_in
*
M_out
+
m_out
;
grad_input_data
[
m_in
]
+=
b
*
g
*
*
(
weight_data
+
w_idx
);
g_out
=
*
(
grad_output_data
+
m_out
*
grad_output_stride
);
grad_weight_data
[
w_idx
]
+=
b
*
g
*
*
(
input_data
+
m_in
*
input_stride
);
grad_weight_data
[
w_idx
]
+=
b
*
g_out
*
value
;
g_in
+=
b
*
g_out
*
*
(
weight_data
+
w_idx
);
}
}
}
}
grad_input_data
[
m_in
]
=
g_in
;
}
}
)
)
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment