OpenDAS / torch-spline-conv
Commit 2862a818 authored Mar 06, 2019 by rusty1s
multi gpu update
parent 46fe125c

Showing 4 changed files with 8 additions and 2 deletions (+8 -2)

cuda/basis_kernel.cu           +2 -0
cuda/weighting_kernel.cu       +4 -0
setup.py                       +1 -1
torch_spline_conv/__init__.py  +1 -1

cuda/basis_kernel.cu
@@ -33,6 +33,7 @@ template <typename scalar_t> struct BasisForward {
 #define BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, KERNEL_NAME)  \
   [&]() -> std::tuple<at::Tensor, at::Tensor> {                             \
+    cudaSetDevice(PSEUDO.get_device());                                     \
     auto E = PSEUDO.size(0);                                                \
     auto S = (int64_t)(powf(M + 1, KERNEL_SIZE.size(0)) + 0.5);             \
     auto basis = at::empty({E, S}, PSEUDO.options());                       \
@@ -163,6 +164,7 @@ template <typename scalar_t> struct BasisBackward {
 #define BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE,  \
                        KERNEL_NAME)                                         \
   [&]() -> at::Tensor {                                                     \
+    cudaSetDevice(GRAD_BASIS.get_device());                                 \
     auto E = PSEUDO.size(0);                                                \
     auto D = PSEUDO.size(1);                                                \
     auto grad_pseudo = at::empty({E, D}, PSEUDO.options());                 \
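Both macros gain the same one-line guard: cudaSetDevice() switches the calling thread's current CUDA device to the device that holds the input tensor, so the kernel launched a few lines further down runs on the GPU that actually owns the data instead of on the default device 0. As a rough sketch of that pattern (not code from this repository: the kernel and wrapper names are made up, a contiguous float CUDA tensor is assumed, and the data_ptr accessor assumes a recent ATen), a minimal wrapper could look like this:

// Hypothetical sketch, not from this repository: the device-guard pattern
// added by this commit, reduced to a trivial elementwise kernel.
#include <ATen/ATen.h>
#include <cuda_runtime.h>

__global__ void scale_kernel(const float *src, float *dst, float alpha,
                             int64_t numel) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < numel)
    dst[i] = alpha * src[i];
}

at::Tensor scale_cuda(at::Tensor x, float alpha) {
  // Without this call the launch below goes to the thread's current device
  // (usually GPU 0), even when x lives on another GPU.
  cudaSetDevice(x.get_device());
  auto out = at::empty(x.sizes(), x.options());  // allocated on x's device
  int64_t numel = x.numel();
  int threads = 256;
  int blocks = (int)((numel + threads - 1) / threads);
  scale_kernel<<<blocks, threads>>>(x.data_ptr<float>(), out.data_ptr<float>(),
                                    alpha, numel);
  return out;
}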
cuda/weighting_kernel.cu
@@ -39,6 +39,7 @@ weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
 at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
                              at::Tensor weight_index) {
+  cudaSetDevice(x.get_device());
   auto E = x.size(0), M_out = weight.size(2);
   auto out = at::empty({E, M_out}, x.options());
   AT_DISPATCH_FLOATING_TYPES(out.type(), "weighting_fw", [&] {
@@ -86,6 +87,7 @@ __global__ void weighting_bw_x_kernel(
 at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
                                at::Tensor basis, at::Tensor weight_index) {
+  cudaSetDevice(grad_out.get_device());
   auto E = grad_out.size(0), M_in = weight.size(1);
   auto grad_x = at::empty({E, M_in}, grad_out.options());
   weight = weight.transpose(1, 2).contiguous();
@@ -131,6 +133,7 @@ __global__ void weighting_bw_w_kernel(
 at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
                                at::Tensor basis, at::Tensor weight_index,
                                int64_t K) {
+  cudaSetDevice(grad_out.get_device());
   auto M_in = x.size(1), M_out = grad_out.size(1);
   auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.options());
   AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_w", [&] {
@@ -175,6 +178,7 @@ __global__ void weighting_bw_b_kernel(
 at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
                                at::Tensor weight, at::Tensor weight_index) {
+  cudaSetDevice(grad_out.get_device());
   auto E = x.size(0), S = weight_index.size(1);
   auto grad_basis = at::zeros({E, S}, grad_out.options());
   AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_b", [&] {
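All four weighting wrappers repeat the same guard, always reading the device from one of their tensor arguments. A bare cudaSetDevice() call does leave the caller's current device changed after the wrapper returns; in later PyTorch/ATen code the same effect is typically obtained with an RAII guard that restores the previous device on scope exit. The sketch below is not part of this commit and assumes the c10::cuda::CUDAGuard class from <c10/cuda/CUDAGuard.h>; the wrapper name is hypothetical.

// Hypothetical alternative, not from this commit: the same device selection
// expressed as an RAII guard (assumes <c10/cuda/CUDAGuard.h> is available).
#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>

at::Tensor forward_on_input_device(at::Tensor x) {
  // Equivalent to cudaSetDevice(x.get_device()), but the previous current
  // device is restored automatically when the guard goes out of scope.
  const c10::cuda::CUDAGuard guard(x.device());
  auto out = at::empty_like(x);  // allocated on x's device
  // ... dispatch and launch kernels here; they target x's device ...
  return out;
}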
setup.py
@@ -16,7 +16,7 @@ if CUDA_HOME is not None:
             ['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
     ]

-__version__ = '1.0.5'
+__version__ = '1.0.6'
 url = 'https://github.com/rusty1s/pytorch_spline_conv'
 install_requires = []
torch_spline_conv/__init__.py
@@ -2,6 +2,6 @@ from .basis import SplineBasis
 from .weighting import SplineWeighting
 from .conv import SplineConv

-__version__ = '1.0.5'
+__version__ = '1.0.6'

 __all__ = ['SplineBasis', 'SplineWeighting', 'SplineConv', '__version__']