Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
8d6acb03
Commit
8d6acb03
authored
Aug 13, 2018
by
rusty1s
Browse files
clean up and bugfixes
parent
b46459f4
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
3 additions
and
70 deletions
+3
-70
aten/THCC/generic/THCCBasis.c
aten/THCC/generic/THCCBasis.c
+0
-41
aten/THCC/generic/THCCWeighting.c
aten/THCC/generic/THCCWeighting.c
+0
-25
cuda/basis_kernel.cu
cuda/basis_kernel.cu
+2
-2
test/test_conv.py
test/test_conv.py
+1
-1
test/test_weighting.py
test/test_weighting.py
+0
-1
No files found.
aten/THCC/generic/THCCBasis.c
deleted
100644 → 0
View file @
b46459f4
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCCBasis.c"
#else
void
THCCTensor_
(
linearBasisForward
)(
THCTensor
*
basis
,
THCudaLongTensor
*
weightIndex
,
THCTensor
*
pseudo
,
THCudaLongTensor
*
kernelSize
,
THCudaByteTensor
*
isOpenSpline
)
{
THCTensor_
(
linearBasisForward
)(
state
,
basis
,
weightIndex
,
pseudo
,
kernelSize
,
isOpenSpline
);
}
void
THCCTensor_
(
quadraticBasisForward
)(
THCTensor
*
basis
,
THCudaLongTensor
*
weightIndex
,
THCTensor
*
pseudo
,
THCudaLongTensor
*
kernelSize
,
THCudaByteTensor
*
isOpenSpline
)
{
THCTensor_
(
quadraticBasisForward
)(
state
,
basis
,
weightIndex
,
pseudo
,
kernelSize
,
isOpenSpline
);
}
void
THCCTensor_
(
cubicBasisForward
)(
THCTensor
*
basis
,
THCudaLongTensor
*
weightIndex
,
THCTensor
*
pseudo
,
THCudaLongTensor
*
kernelSize
,
THCudaByteTensor
*
isOpenSpline
)
{
THCTensor_
(
cubicBasisForward
)(
state
,
basis
,
weightIndex
,
pseudo
,
kernelSize
,
isOpenSpline
);
}
void
THCCTensor_
(
linearBasisBackward
)(
THCTensor
*
self
,
THCTensor
*
gradBasis
,
THCTensor
*
pseudo
,
THCudaLongTensor
*
kernelSize
,
THCudaByteTensor
*
isOpenSpline
)
{
THCTensor_
(
linearBasisBackward
)(
state
,
self
,
gradBasis
,
pseudo
,
kernelSize
,
isOpenSpline
);
}
void
THCCTensor_
(
quadraticBasisBackward
)(
THCTensor
*
self
,
THCTensor
*
gradBasis
,
THCTensor
*
pseudo
,
THCudaLongTensor
*
kernelSize
,
THCudaByteTensor
*
isOpenSpline
)
{
THCTensor_
(
quadraticBasisBackward
)(
state
,
self
,
gradBasis
,
pseudo
,
kernelSize
,
isOpenSpline
);
}
void
THCCTensor_
(
cubicBasisBackward
)(
THCTensor
*
self
,
THCTensor
*
gradBasis
,
THCTensor
*
pseudo
,
THCudaLongTensor
*
kernelSize
,
THCudaByteTensor
*
isOpenSpline
)
{
THCTensor_
(
cubicBasisBackward
)(
state
,
self
,
gradBasis
,
pseudo
,
kernelSize
,
isOpenSpline
);
}
#endif // THC_GENERIC_FILE
aten/THCC/generic/THCCWeighting.c
deleted
100644 → 0
View file @
b46459f4
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCCWeighting.c"
#else
void
THCCTensor_
(
weightingForward
)(
THCTensor
*
self
,
THCTensor
*
src
,
THCTensor
*
weight
,
THCTensor
*
basis
,
THCudaLongTensor
*
weightIndex
)
{
THCTensor_
(
weightingForward
)(
state
,
self
,
src
,
weight
,
basis
,
weightIndex
);
}
void
THCCTensor_
(
weightingBackwardSrc
)(
THCTensor
*
self
,
THCTensor
*
gradOutput
,
THCTensor
*
weight
,
THCTensor
*
basis
,
THCudaLongTensor
*
weightIndex
)
{
THCTensor_
(
weightingBackwardSrc
)(
state
,
self
,
gradOutput
,
weight
,
basis
,
weightIndex
);
}
void
THCCTensor_
(
weightingBackwardWeight
)(
THCTensor
*
self
,
THCTensor
*
gradOutput
,
THCTensor
*
src
,
THCTensor
*
basis
,
THCudaLongTensor
*
weightIndex
)
{
THCTensor_
(
weightingBackwardWeight
)(
state
,
self
,
gradOutput
,
src
,
basis
,
weightIndex
);
}
void
THCCTensor_
(
weightingBackwardBasis
)(
THCTensor
*
self
,
THCTensor
*
gradOutput
,
THCTensor
*
src
,
THCTensor
*
weight
,
THCudaLongTensor
*
weightIndex
)
{
THCTensor_
(
weightingBackwardBasis
)(
state
,
self
,
gradOutput
,
src
,
weight
,
weightIndex
);
}
#endif // THC_GENERIC_FILE
cuda/basis_kernel.cu
View file @
8d6acb03
...
...
@@ -193,7 +193,7 @@ template <typename scalar_t> struct BasisBackward {
auto v = PSEUDO.data[e * PSEUDO.strides[0] + d * PSEUDO.strides[1]]; \
v *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d]; \
v -= floor(v); \
v = CODE;
\
v =
GRAD_
CODE; \
tmp = v; \
\
for (ptrdiff_t d_it = 1; d_it < GRAD_PSEUDO.sizes[1]; d_it++) { \
...
...
@@ -202,7 +202,7 @@ template <typename scalar_t> struct BasisBackward {
v = PSEUDO.data[e * pseudo.strides[0] + d_new * PSEUDO.strides[1]]; \
v *= KERNEL_SIZE[d_new] - M * IS_OPEN_SPLINE[d_new]; \
v -= floor(v); \
v =
GRAD_
CODE; \
v = CODE;
\
tmp *= v; \
} \
g += tmp * \
...
...
test/test_conv.py
View file @
8d6acb03
...
...
@@ -7,7 +7,7 @@ from torch_spline_conv import SplineConv
from
torch_spline_conv.basis
import
implemented_degrees
as
degrees
from
.utils
import
dtypes
,
devices
,
tensor
devices
=
[
torch
.
device
(
'c
p
u'
)]
devices
=
[
torch
.
device
(
'cu
da
'
)]
tests
=
[{
'x'
:
[[
9
,
10
],
[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
]],
...
...
test/test_weighting.py
View file @
8d6acb03
...
...
@@ -7,7 +7,6 @@ from torch_spline_conv.weighting import SplineWeighting
from
torch_spline_conv.basis
import
SplineBasis
from
.utils
import
dtypes
,
devices
,
tensor
devices
=
[
torch
.
device
(
'cuda'
)]
tests
=
[{
'x'
:
[[
1
,
2
],
[
3
,
4
]],
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment