Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
fbd0ffaf
Commit
fbd0ffaf
authored
Mar 10, 2018
by
rusty1s
Browse files
added backward
parent
3fbeeabc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
2 deletions
+22
-2
torch_spline_conv/functions/utils.py
torch_spline_conv/functions/utils.py
+2
-2
torch_spline_conv/src/generic/cpu.c
torch_spline_conv/src/generic/cpu.c
+20
-0
No files found.
torch_spline_conv/functions/utils.py
View file @
fbd0ffaf
...
@@ -36,8 +36,8 @@ def spline_weighting_forward(x, weight, basis, weight_index):
...
@@ -36,8 +36,8 @@ def spline_weighting_forward(x, weight, basis, weight_index):
def
spline_weighting_backward
(
grad_output
,
x
,
weight
,
basis
,
weight_index
):
def
spline_weighting_backward
(
grad_output
,
x
,
weight
,
basis
,
weight_index
):
grad_input
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
1
))
grad_input
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
1
))
.
fill_
(
0
)
grad_weight
=
x
.
new
(
weight
)
grad_weight
=
x
.
new
(
weight
)
.
fill_
(
0
)
func
=
get_func
(
'weighting_backward'
,
x
)
func
=
get_func
(
'weighting_backward'
,
x
)
func
(
grad_input
,
grad_weight
,
grad_output
,
x
,
weight
,
basis
,
weight_index
)
func
(
grad_input
,
grad_weight
,
grad_output
,
x
,
weight
,
basis
,
weight_index
)
return
grad_input
,
grad_weight
return
grad_input
,
grad_weight
...
...
torch_spline_conv/src/generic/cpu.c
View file @
fbd0ffaf
...
@@ -73,7 +73,27 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
...
@@ -73,7 +73,27 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
}
}
void
spline_
(
weighting_backward
)(
THTensor
*
grad_input
,
THTensor
*
grad_weight
,
THTensor
*
grad_output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
void
spline_
(
weighting_backward
)(
THTensor
*
grad_input
,
THTensor
*
grad_weight
,
THTensor
*
grad_output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
real
*
weight_data
=
weight
->
storage
->
data
+
weight
->
storageOffset
;
real
*
grad_weight_data
=
grad_weight
->
storage
->
data
+
grad_weight
->
storageOffset
;
int64_t
M_out
=
THTensor_
(
size
)(
grad_input
,
1
);
int64_t
M_in
=
THTensor_
(
size
)(
input
,
1
);
int64_t
S
=
THLongTensor_size
(
weight_index
,
1
);
int64_t
m_out
,
m_in
,
s
,
i
,
w_idx
;
real
g
,
b
;
TH_TENSOR_DIM_APPLY5
(
real
,
grad_input
,
real
,
grad_output
,
real
,
input
,
real
,
basis
,
int64_t
,
weight_index
,
1
,
TH_TENSOR_DIM_APPLY5
(
real
,
grad_input
,
real
,
grad_output
,
real
,
input
,
real
,
basis
,
int64_t
,
weight_index
,
1
,
for
(
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
g
=
*
(
grad_output_data
+
m_out
*
grad_output_stride
);
for
(
s
=
0
;
s
<
S
;
s
++
)
{
b
=
*
(
basis_data
+
s
*
basis_stride
);
i
=
*
(
weight_index_data
+
s
*
weight_index_stride
);
for
(
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
w_idx
=
i
*
M_in
*
M_out
+
m_in
*
M_out
+
m_out
;
grad_input_data
[
m_in
]
+=
b
*
g
*
*
(
weight_data
+
w_idx
);
grad_weight_data
[
w_idx
]
+=
b
*
g
*
*
(
input_data
+
m_in
*
input_stride
);
}
}
}
)
)
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment