Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
9362b2d3
Commit
9362b2d3
authored
Apr 11, 2018
by
rusty1s
Browse files
added cpu weighting backward
parent
6146660b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
103 additions
and
9 deletions
+103
-9
aten/TH/generic/THWeighting.c
aten/TH/generic/THWeighting.c
+80
-4
test/test_weighting.py
test/test_weighting.py
+21
-3
torch_spline_conv/weighting.py
torch_spline_conv/weighting.py
+2
-2
No files found.
aten/TH/generic/THWeighting.c
View file @
9362b2d3
...
...
@@ -11,16 +11,18 @@ void THTensor_(weightingForward)(THTensor *self, THTensor *src, THTensor *weight
int64_t
*
weightIndexData
=
THLongTensor_data
(
weightIndex
);
ptrdiff_t
e
,
mOut
,
s
,
mIn
;
real
v
,
b
;
real
v
,
b
,
tmp
;
int64_t
wi
;
for
(
e
=
0
;
e
<
THTensor_
(
size
)(
src
,
0
);
e
++
)
{
for
(
mOut
=
0
;
mOut
<
THTensor_
(
size
)(
weight
,
2
);
mOut
++
)
{
for
(
mOut
=
0
;
mOut
<
THTensor_
(
size
)(
self
,
1
);
mOut
++
)
{
v
=
0
;
for
(
s
=
0
;
s
<
THTensor_
(
size
)(
basis
,
1
);
s
++
)
{
b
=
basisData
[
e
*
basis
->
stride
[
0
]
+
s
*
basis
->
stride
[
1
]];
wi
=
weightIndexData
[
e
*
weightIndex
->
stride
[
0
]
+
s
*
weightIndex
->
stride
[
1
]];
for
(
mIn
=
0
;
mIn
<
THTensor_
(
size
)(
weight
,
1
);
mIn
++
)
{
v
+=
b
*
weightData
[
wi
*
weight
->
stride
[
0
]
+
mIn
*
weight
->
stride
[
1
]
+
mOut
*
weight
->
stride
[
2
]]
*
srcData
[
e
*
src
->
stride
[
0
]
+
mIn
*
src
->
stride
[
1
]];
for
(
mIn
=
0
;
mIn
<
THTensor_
(
size
)(
src
,
1
);
mIn
++
)
{
tmp
=
weightData
[
wi
*
weight
->
stride
[
0
]
+
mIn
*
weight
->
stride
[
1
]
+
mOut
*
weight
->
stride
[
2
]];
tmp
*=
b
*
srcData
[
e
*
src
->
stride
[
0
]
+
mIn
*
src
->
stride
[
1
]];
v
+=
tmp
;
}
}
selfData
[
e
*
self
->
stride
[
0
]
+
mOut
*
self
->
stride
[
1
]]
=
v
;
...
...
@@ -30,14 +32,88 @@ void THTensor_(weightingForward)(THTensor *self, THTensor *src, THTensor *weight
void
THTensor_
(
weightingBackwardSrc
)(
THTensor
*
self
,
THTensor
*
gradOutput
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weightIndex
)
{
THTensor_
(
fill
)(
self
,
0
);
real
*
selfData
=
THTensor_
(
data
)(
self
);
real
*
gradOutputData
=
THTensor_
(
data
)(
gradOutput
);
real
*
weightData
=
THTensor_
(
data
)(
weight
);
real
*
basisData
=
THTensor_
(
data
)(
basis
);
int64_t
*
weightIndexData
=
THLongTensor_data
(
weightIndex
);
ptrdiff_t
e
,
mOut
,
s
,
mIn
;
real
g
,
b
,
v
;
int64_t
wi
;
for
(
e
=
0
;
e
<
THTensor_
(
size
)(
self
,
0
);
e
++
)
{
for
(
mOut
=
0
;
mOut
<
THTensor_
(
size
)(
gradOutput
,
1
);
mOut
++
)
{
g
=
gradOutputData
[
e
*
gradOutput
->
stride
[
0
]
+
mOut
*
gradOutput
->
stride
[
1
]];
for
(
s
=
0
;
s
<
THTensor_
(
size
)(
basis
,
1
);
s
++
)
{
b
=
basisData
[
e
*
basis
->
stride
[
0
]
+
s
*
basis
->
stride
[
1
]];
wi
=
weightIndexData
[
e
*
weightIndex
->
stride
[
0
]
+
s
*
weightIndex
->
stride
[
1
]];
for
(
mIn
=
0
;
mIn
<
THTensor_
(
size
)(
self
,
1
);
mIn
++
)
{
v
=
weightData
[
wi
*
weight
->
stride
[
0
]
+
mIn
*
weight
->
stride
[
1
]
+
mOut
*
weight
->
stride
[
2
]];
selfData
[
e
*
self
->
stride
[
0
]
+
mIn
*
self
->
stride
[
1
]]
+=
g
*
b
*
v
;
}
}
}
}
}
void
THTensor_
(
weightingBackwardWeight
)(
THTensor
*
self
,
THTensor
*
gradOutput
,
THTensor
*
src
,
THTensor
*
basis
,
THLongTensor
*
weightIndex
)
{
THTensor_
(
fill
)(
self
,
0
);
real
*
selfData
=
THTensor_
(
data
)(
self
);
real
*
gradOutputData
=
THTensor_
(
data
)(
gradOutput
);
real
*
srcData
=
THTensor_
(
data
)(
src
);
real
*
basisData
=
THTensor_
(
data
)(
basis
);
int64_t
*
weightIndexData
=
THLongTensor_data
(
weightIndex
);
ptrdiff_t
e
,
mOut
,
s
,
mIn
;
real
g
,
b
,
v
;
int64_t
wi
;
for
(
e
=
0
;
e
<
THTensor_
(
size
)(
src
,
0
);
e
++
)
{
for
(
mOut
=
0
;
mOut
<
THTensor_
(
size
)(
gradOutput
,
1
);
mOut
++
)
{
g
=
gradOutputData
[
e
*
gradOutput
->
stride
[
0
]
+
mOut
*
gradOutput
->
stride
[
1
]];
for
(
s
=
0
;
s
<
THTensor_
(
size
)(
basis
,
1
);
s
++
)
{
b
=
basisData
[
e
*
basis
->
stride
[
0
]
+
s
*
basis
->
stride
[
1
]];
wi
=
weightIndexData
[
e
*
weightIndex
->
stride
[
0
]
+
s
*
weightIndex
->
stride
[
1
]];
for
(
mIn
=
0
;
mIn
<
THTensor_
(
size
)(
src
,
1
);
mIn
++
)
{
v
=
b
*
g
*
srcData
[
e
*
src
->
stride
[
0
]
+
mIn
*
src
->
stride
[
1
]];
selfData
[
wi
*
self
->
stride
[
0
]
+
mIn
*
self
->
stride
[
1
]
+
mOut
*
self
->
stride
[
2
]]
+=
v
;
}
}
}
}
}
void
THTensor_
(
weightingBackwardBasis
)(
THTensor
*
self
,
THTensor
*
gradOutput
,
THTensor
*
src
,
THTensor
*
weight
,
THLongTensor
*
weightIndex
)
{
THTensor_
(
fill
)(
self
,
0
);
real
*
selfData
=
THTensor_
(
data
)(
self
);
real
*
gradOutputData
=
THTensor_
(
data
)(
gradOutput
);
real
*
srcData
=
THTensor_
(
data
)(
src
);
real
*
weightData
=
THTensor_
(
data
)(
weight
);
int64_t
*
weightIndexData
=
THLongTensor_data
(
weightIndex
);
ptrdiff_t
e
,
mOut
,
s
,
mIn
;
real
g
,
v
,
tmp
;
int64_t
wi
;
for
(
e
=
0
;
e
<
THTensor_
(
size
)(
src
,
0
);
e
++
)
{
for
(
mOut
=
0
;
mOut
<
THTensor_
(
size
)(
gradOutput
,
1
);
mOut
++
)
{
g
=
gradOutputData
[
e
*
gradOutput
->
stride
[
0
]
+
mOut
*
gradOutput
->
stride
[
1
]];
for
(
s
=
0
;
s
<
THLongTensor_size
(
weightIndex
,
1
);
s
++
)
{
v
=
0
;
wi
=
weightIndexData
[
e
*
weightIndex
->
stride
[
0
]
+
s
*
weightIndex
->
stride
[
1
]];
for
(
mIn
=
0
;
mIn
<
THTensor_
(
size
)(
src
,
1
);
mIn
++
)
{
tmp
=
weightData
[
wi
*
weight
->
stride
[
0
]
+
mIn
*
weight
->
stride
[
1
]
+
mOut
*
weight
->
stride
[
2
]];
tmp
*=
srcData
[
e
*
src
->
stride
[
0
]
+
mIn
*
src
->
stride
[
1
]];
v
+=
tmp
;
}
selfData
[
e
*
self
->
stride
[
0
]
+
s
*
self
->
stride
[
1
]]
+=
g
*
v
;
}
}
}
}
#endif // TH_GENERIC_FILE
test/test_weighting.py
View file @
9362b2d3
...
...
@@ -2,7 +2,9 @@ from itertools import product
import
pytest
import
torch
from
torch_spline_conv.weighting
import
spline_weighting
from
torch.autograd
import
Variable
,
gradcheck
from
torch_spline_conv.weighting
import
spline_weighting
,
SplineWeighting
from
torch_spline_conv.basis
import
spline_basis
from
.tensor
import
tensors
...
...
@@ -19,7 +21,7 @@ tests = [{
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
tests
))))
def
test_spline_
basis
_forward_cpu
(
tensor
,
i
):
def
test_spline_
weighting
_forward_cpu
(
tensor
,
i
):
data
=
tests
[
i
]
src
=
getattr
(
torch
,
tensor
)(
data
[
'src'
])
...
...
@@ -33,7 +35,7 @@ def test_spline_basis_forward_cpu(tensor, i):
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
tests
))))
def
test_spline_
basis
_forward_gpu
(
tensor
,
i
):
def
test_spline_
weighting
_forward_gpu
(
tensor
,
i
):
data
=
tests
[
i
]
src
=
getattr
(
torch
.
cuda
,
tensor
)(
data
[
'src'
])
...
...
@@ -43,3 +45,19 @@ def test_spline_basis_forward_gpu(tensor, i):
output
=
spline_weighting
(
src
,
weight
,
basis
,
weight_index
)
assert
output
.
cpu
().
tolist
()
==
data
[
'output'
]
def
test_spline_basis_backward_cpu
():
src
=
torch
.
DoubleTensor
(
4
,
2
).
uniform_
(
0
,
1
)
weight
=
torch
.
DoubleTensor
(
25
,
2
,
4
).
uniform_
(
0
,
1
)
kernel_size
=
torch
.
LongTensor
([
5
,
5
])
is_open_spline
=
torch
.
ByteTensor
([
1
,
1
])
pseudo
=
torch
.
DoubleTensor
(
4
,
2
).
uniform_
(
0
,
1
)
basis
,
weight_index
=
spline_basis
(
1
,
pseudo
,
kernel_size
,
is_open_spline
)
src
=
Variable
(
src
,
requires_grad
=
True
)
weight
=
Variable
(
weight
,
requires_grad
=
True
)
basis
=
Variable
(
basis
,
requires_grad
=
True
)
op
=
SplineWeighting
(
weight_index
)
assert
gradcheck
(
op
,
(
src
,
weight
,
basis
),
eps
=
1e-6
,
atol
=
1e-4
)
is
True
torch_spline_conv/weighting.py
View file @
9362b2d3
...
...
@@ -15,7 +15,6 @@ def weighting_forward(src, weight, basis, weight_index):
def
weighting_backward_src
(
grad_output
,
weight
,
basis
,
weight_index
):
grad_src
=
grad_output
.
new
(
grad_output
.
size
(
0
),
weight
.
size
(
1
))
weight
=
weight
.
transpose
(
1
,
2
).
contiguous
()
# Coalesced memory access.
weighting_bw_src
(
grad_src
,
grad_output
,
weight
,
basis
,
weight_index
)
return
grad_src
...
...
@@ -49,8 +48,9 @@ class SplineWeighting(Function):
grad_src
=
weighting_backward_src
(
grad_output
,
weight
,
basis
,
self
.
weight_index
)
if
self
.
needs_input_grad
[
1
]:
K
=
weight
.
size
(
0
)
grad_weight
=
weighting_backward_weight
(
grad_output
,
src
,
basis
,
self
.
weight_index
)
self
.
weight_index
,
K
)
if
self
.
needs_input_grad
[
2
]:
grad_basis
=
weighting_backward_basis
(
grad_output
,
src
,
weight
,
self
.
weight_index
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment