Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
d7a83c01
"examples/controlnet/train_controlnet_sd3.py" did not exist on "e7534542a2e736ab54328a7fb3a0a15fe4f31da2"
Commit
d7a83c01
authored
Mar 12, 2018
by
rusty1s
Browse files
bugfixes
parent
67904212
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
49 additions
and
34 deletions
+49
-34
test/test_spline_conv.py
test/test_spline_conv.py
+0
-1
torch_spline_conv/functions/ffi.py
torch_spline_conv/functions/ffi.py
+8
-8
torch_spline_conv/functions/spline_weighting.py
torch_spline_conv/functions/spline_weighting.py
+23
-9
torch_spline_conv/src/cpu.h
torch_spline_conv/src/cpu.h
+3
-2
torch_spline_conv/src/generic/cpu.c
torch_spline_conv/src/generic/cpu.c
+15
-14
No files found.
test/test_spline_conv.py
View file @
d7a83c01
...
@@ -48,7 +48,6 @@ def test_spline_conv_cpu(tensor):
...
@@ -48,7 +48,6 @@ def test_spline_conv_cpu(tensor):
def
test_spline_weighting_backward_cpu
():
def
test_spline_weighting_backward_cpu
():
return
kernel_size
=
torch
.
LongTensor
([
5
,
5
])
kernel_size
=
torch
.
LongTensor
([
5
,
5
])
is_open_spline
=
torch
.
ByteTensor
([
1
,
1
])
is_open_spline
=
torch
.
ByteTensor
([
1
,
1
])
op
=
SplineWeighting
(
kernel_size
,
is_open_spline
,
1
)
op
=
SplineWeighting
(
kernel_size
,
is_open_spline
,
1
)
...
...
torch_spline_conv/functions/ffi.py
View file @
d7a83c01
...
@@ -40,17 +40,17 @@ def spline_weighting_backward_input(grad_output, weight, basis, weight_index):
...
@@ -40,17 +40,17 @@ def spline_weighting_backward_input(grad_output, weight, basis, weight_index):
return
grad_input
return
grad_input
# pragma: no cover
def
spline_weighting_backward_weight
(
grad_output
,
x
,
basis
,
weight_index
,
K
):
grad_weight
=
x
.
new
(
K
,
x
.
size
(
1
),
grad_output
.
size
(
1
)).
fill_
(
0
)
func
=
get_func
(
'weighting_backward_weight'
,
x
)
func
(
grad_weight
,
grad_output
,
x
,
basis
,
weight_index
)
return
grad_weight
# pragma: no cover
# pragma: no cover
def
spline_weighting_backward_basis
(
grad_output
,
x
,
weight
,
weight_index
):
def
spline_weighting_backward_basis
(
grad_output
,
x
,
weight
,
weight_index
):
grad_basis
=
x
.
new
(
weight_index
.
size
())
grad_basis
=
x
.
new
(
weight_index
.
size
())
func
=
get_func
(
'weighting_backward_basis'
,
x
)
func
=
get_func
(
'weighting_backward_basis'
,
x
)
func
(
grad_basis
,
grad_output
,
x
,
weight
,
weight_index
)
func
(
grad_basis
,
grad_output
,
x
,
weight
,
weight_index
)
return
grad_basis
return
grad_basis
# pragma: no cover
def
spline_weighting_backward_weight
(
grad_output
,
x
,
basis
,
weight_index
,
K
):
grad_weight
=
x
.
new
(
K
,
x
.
size
(
1
),
grad_output
.
size
(
1
)).
fill_
(
0
)
func
=
get_func
(
'weighting_backward_weight'
,
x
)
func
(
grad_weight
,
grad_output
,
x
,
basis
,
weight_index
)
return
grad_weight
torch_spline_conv/functions/spline_weighting.py
View file @
d7a83c01
...
@@ -3,8 +3,8 @@ from torch.autograd import Function
...
@@ -3,8 +3,8 @@ from torch.autograd import Function
from
.ffi
import
(
spline_basis_forward
,
spline_weighting_forward
,
from
.ffi
import
(
spline_basis_forward
,
spline_weighting_forward
,
spline_weighting_backward_input
,
spline_weighting_backward_input
,
spline_weighting_backward_
weight
,
spline_weighting_backward_
basis
,
spline_weighting_backward_
basis
)
spline_weighting_backward_
weight
)
class
SplineWeighting
(
Function
):
class
SplineWeighting
(
Function
):
...
@@ -20,17 +20,31 @@ class SplineWeighting(Function):
...
@@ -20,17 +20,31 @@ class SplineWeighting(Function):
self
.
degree
,
pseudo
,
self
.
kernel_size
,
self
.
is_open_spline
,
K
)
self
.
degree
,
pseudo
,
self
.
kernel_size
,
self
.
is_open_spline
,
K
)
output
=
spline_weighting_forward
(
x
,
weight
,
basis
,
weight_index
)
output
=
spline_weighting_forward
(
x
,
weight
,
basis
,
weight_index
)
#
self.save_for_backward(x, weight)
self
.
save_for_backward
(
x
,
weight
)
#
self.basis, self.weight_index = basis, weight_index
self
.
basis
,
self
.
weight_index
=
basis
,
weight_index
return
output
return
output
def
backward
(
self
,
grad_output
):
# pragma: no cover
def
backward
(
self
,
grad_output
):
# pragma: no cover
pass
x
,
weight
=
self
.
saved_tensors
# x, weight = self.saved_tensors
basis
,
weight_index
=
self
.
basis
,
self
.
weight_index
# grad_input, grad_weight = spline_weighting_backward(
grad_input
,
grad_pseudo
,
grad_weight
=
None
,
None
,
None
# grad_output, x, weight, self.basis, self.weight_index)
# return grad_input, None, grad_weight
if
self
.
needs_input_grad
[
0
]:
grad_input
=
spline_weighting_backward_input
(
grad_output
,
weight
,
basis
,
weight_index
)
if
self
.
needs_input_grad
[
1
]:
grad_basis
=
spline_weighting_backward_basis
(
grad_output
,
x
,
weight
,
weight_index
)
print
(
'pseudo needs grad'
)
if
self
.
needs_input_grad
[
2
]:
K
=
weight
.
size
(
0
)
grad_weight
=
spline_weighting_backward_weight
(
grad_output
,
x
,
basis
,
weight_index
,
K
)
return
grad_input
,
grad_pseudo
,
grad_weight
def
spline_weighting
(
x
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
):
def
spline_weighting
(
x
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
):
...
...
torch_spline_conv/src/cpu.h
View file @
d7a83c01
...
@@ -13,8 +13,9 @@ void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *inp
...
@@ -13,8 +13,9 @@ void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *inp
void
spline_weighting_backward_input_Float
(
THFloatTensor
*
grad_input
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
weight
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_input_Float
(
THFloatTensor
*
grad_input
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
weight
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_input_Double
(
THDoubleTensor
*
grad_input
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
weight
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_input_Double
(
THDoubleTensor
*
grad_input
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
weight
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_basis_Float
(
THFloatTensor
*
grad_basis
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
input
,
THFloatTensor
*
weight
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_basis_Double
(
THDoubleTensor
*
grad_basis
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
weight
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_Float
(
THFloatTensor
*
grad_weight
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
input
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_Float
(
THFloatTensor
*
grad_weight
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
input
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_Double
(
THDoubleTensor
*
grad_weight
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_Double
(
THDoubleTensor
*
grad_weight
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_basis_Float
(
THFloatTensor
*
grad_basis
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
input
,
THFloatTensor
*
weight
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_basis_Double
(
THDoubleTensor
*
grad_basis
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
weight
,
THLongTensor
*
weight_index
);
torch_spline_conv/src/generic/cpu.c
View file @
d7a83c01
...
@@ -49,7 +49,7 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
...
@@ -49,7 +49,7 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
void
spline_
(
weighting_backward_input
)(
THTensor
*
grad_input
,
THTensor
*
grad_output
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
void
spline_
(
weighting_backward_input
)(
THTensor
*
grad_input
,
THTensor
*
grad_output
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
real
*
weight_data
=
weight
->
storage
->
data
+
weight
->
storageOffset
;
real
b
;
real
*
weight_data
=
weight
->
storage
->
data
+
weight
->
storageOffset
;
real
b
;
SPLINE_WEIGHTING_BACKWARD
(
grad_input
,
grad_output
,
basis
,
weight_index
,
THTensor_
(
size
)(
grad_inpu
t
,
1
),
THTensor_
(
size
)(
grad_outpu
t
,
1
),
THLongTensor_size
(
weight_index
,
1
),
SPLINE_WEIGHTING_BACKWARD
(
grad_input
,
grad_output
,
basis
,
weight_index
,
THTensor_
(
size
)(
weigh
t
,
1
),
THTensor_
(
size
)(
weigh
t
,
2
),
THLongTensor_size
(
weight_index
,
1
),
for
(
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
for
(
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
value
=
0
;
value
=
0
;
for
(
s
=
0
;
s
<
S
;
s
++
)
{
for
(
s
=
0
;
s
<
S
;
s
++
)
{
...
@@ -64,35 +64,36 @@ void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_outp
...
@@ -64,35 +64,36 @@ void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_outp
)
)
}
}
void
spline_
(
weighting_backward_
weight
)(
THTensor
*
grad_
weight
,
THTensor
*
grad_output
,
THTensor
*
input
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
void
spline_
(
weighting_backward_
basis
)(
THTensor
*
grad_
basis
,
THTensor
*
grad_output
,
THTensor
*
input
,
THTensor
*
weight
,
THLongTensor
*
weight_index
)
{
real
*
grad_
weight_data
=
grad_
weight
->
storage
->
data
+
grad_
weight
->
storageOffset
;
real
b
;
real
*
weight_data
=
weight
->
storage
->
data
+
weight
->
storageOffset
;
SPLINE_WEIGHTING_BACKWARD
(
grad_output
,
input
,
basis
,
weight_index
,
THTensor_
(
size
)(
grad_outpu
t
,
1
),
THTensor_
(
size
)(
inpu
t
,
1
),
THLongTensor_size
(
weight_index
,
1
),
SPLINE_WEIGHTING_BACKWARD
(
grad_basis
,
grad_output
,
input
,
weight_index
,
THTensor_
(
size
)(
weigh
t
,
1
),
THTensor_
(
size
)(
weigh
t
,
2
),
THLongTensor_size
(
weight_index
,
1
),
for
(
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
for
(
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
value
=
*
(
grad_output_data
+
m_out
*
grad_output_stride
);
for
(
s
=
0
;
s
<
S
;
s
++
)
{
for
(
s
=
0
;
s
<
S
;
s
++
)
{
b
=
*
(
basis_data
+
s
*
basis_stride
);
w_idx
=
*
(
weight_index_data
+
s
*
weight_index_stride
);
value
=
0
;
w_idx
=
*
(
weight_index_data
+
s
*
weight_index_stride
);
for
(
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
for
(
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
grad_
weight_data
[
w_idx
*
M_in
*
M_out
+
m_in
*
M_out
+
m_out
]
+=
b
*
value
*
*
(
input_data
+
m_in
*
input_stride
);
value
+=
*
(
input_data
+
m_in
*
input_stride
)
*
*
(
weight_data
+
w_idx
*
M_in
*
M_out
+
m_in
*
M_out
+
m_out
);
}
}
grad_basis_data
[
s
]
+=
value
*
*
(
grad_output_data
+
m_out
*
grad_output_stride
);
}
}
}
}
)
)
}
}
void
spline_
(
weighting_backward_
basis
)(
THTensor
*
grad_
basis
,
THTensor
*
grad_output
,
THTensor
*
input
,
THTensor
*
weight
,
THLongTensor
*
weight_index
)
{
void
spline_
(
weighting_backward_
weight
)(
THTensor
*
grad_
weight
,
THTensor
*
grad_output
,
THTensor
*
input
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
real
*
weight_data
=
weight
->
storage
->
data
+
weight
->
storageOffset
;
real
*
grad_
weight_data
=
grad_
weight
->
storage
->
data
+
grad_
weight
->
storageOffset
;
real
b
;
SPLINE_WEIGHTING_BACKWARD
(
grad_basis
,
grad_output
,
input
,
weight_index
,
THTensor_
(
size
)(
grad_out
put
,
1
),
THTensor_
(
size
)(
in
put
,
1
),
THLongTensor_size
(
weight_index
,
1
),
SPLINE_WEIGHTING_BACKWARD
(
grad_output
,
input
,
basis
,
weight_index
,
THTensor_
(
size
)(
in
put
,
1
),
THTensor_
(
size
)(
grad_out
put
,
1
),
THLongTensor_size
(
weight_index
,
1
),
for
(
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
for
(
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
value
=
*
(
grad_output_data
+
m_out
*
grad_output_stride
);
for
(
s
=
0
;
s
<
S
;
s
++
)
{
for
(
s
=
0
;
s
<
S
;
s
++
)
{
w_idx
=
*
(
weight_index_data
+
s
*
weight_index_stride
);
value
=
0
;
b
=
*
(
basis_data
+
s
*
basis_stride
);
w_idx
=
*
(
weight_index_data
+
s
*
weight_index_stride
);
for
(
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
for
(
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
value
+=
*
(
input_data
+
m_in
*
input_stride
)
*
*
(
weight_data
+
w_idx
*
M_in
*
M_out
+
m_in
*
M_out
+
m_out
);
grad_
weight_data
[
w_idx
*
M_in
*
M_out
+
m_in
*
M_out
+
m_out
]
+=
b
*
value
*
*
(
input_data
+
m_in
*
input_stride
);
}
}
grad_basis_data
[
s
]
+=
value
*
*
(
grad_output_data
+
m_out
*
grad_output_stride
);
}
}
}
}
)
)
}
}
#endif
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment