Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
ab4d22e0
Commit
ab4d22e0
authored
Feb 29, 2020
by
rusty1s
Browse files
cuda fix
parent
26c0b866
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
48 deletions
+48
-48
.travis.yml
.travis.yml
+11
-11
csrc/cuda/basis_cuda.cu
csrc/cuda/basis_cuda.cu
+13
-13
csrc/cuda/weighting_cuda.cu
csrc/cuda/weighting_cuda.cu
+24
-24
No files found.
.travis.yml
View file @
ab4d22e0
...
...
@@ -9,18 +9,18 @@ env:
global
:
-
CUDA_HOME=/usr/local/cuda
jobs
:
#
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cpu
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cpu
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
#
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
#
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
#
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cpu
#
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
#
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
#
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101
#
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cpu
#
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
#
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
#
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cpu
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cpu
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
-
TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101
jobs
:
exclude
:
# Exclude *all* macOS CUDA jobs and Windows CUDA 9.2/10.0 jobs.
...
...
csrc/cuda/basis_cuda.cu
View file @
ab4d22e0
...
...
@@ -98,7 +98,7 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
CHECK_CUDA
(
pseudo
);
CHECK_CUDA
(
kernel_size
);
CHECK_CUDA
(
is_open_spline
);
//
cudaSetDevice(pseudo.get_device());
cudaSetDevice
(
pseudo
.
get_device
());
CHECK_INPUT
(
kernel_size
.
dim
()
==
1
);
CHECK_INPUT
(
pseudo
.
size
(
1
)
==
kernel_size
.
numel
());
...
...
@@ -116,16 +116,16 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
auto
is_open_spline_data
=
is_open_spline
.
data_ptr
<
uint8_t
>
();
auto
weight_index_data
=
weight_index
.
data_ptr
<
int64_t
>
();
//
auto stream = at::cuda::getCurrentCUDAStream();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES
(
pseudo
.
scalar_type
(),
"basis_fw"
,
[
&
]
{
auto
pseudo_data
=
pseudo
.
data_ptr
<
scalar_t
>
();
auto
basis_data
=
basis
.
data_ptr
<
scalar_t
>
();
AT_DISPATCH_DEGREE_TYPES
(
degree
,
[
&
]
{
//
spline_basis_fw_kernel<scalar_t, DEGREE>
//
<<<BLOCKS(basis.numel()), THREADS, 0, stream>>>(
//
pseudo_data, kernel_size_data, is_open_spline_data, basis_data,
//
weight_index_data, E, D, S, basis.numel());
spline_basis_fw_kernel
<
scalar_t
,
DEGREE
>
<<<
BLOCKS
(
basis
.
numel
()),
THREADS
,
0
,
stream
>>>
(
pseudo_data
,
kernel_size_data
,
is_open_spline_data
,
basis_data
,
weight_index_data
,
E
,
D
,
S
,
basis
.
numel
());
});
});
...
...
@@ -180,7 +180,7 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
CHECK_CUDA
(
pseudo
);
CHECK_CUDA
(
kernel_size
);
CHECK_CUDA
(
is_open_spline
);
//
cudaSetDevice(grad_basis.get_device());
cudaSetDevice
(
grad_basis
.
get_device
());
CHECK_INPUT
(
grad_basis
.
size
(
0
)
==
pseudo
.
size
(
0
));
CHECK_INPUT
(
kernel_size
.
dim
()
==
1
);
...
...
@@ -197,18 +197,18 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
auto
kernel_size_data
=
kernel_size
.
data_ptr
<
int64_t
>
();
auto
is_open_spline_data
=
is_open_spline
.
data_ptr
<
uint8_t
>
();
//
auto stream = at::cuda::getCurrentCUDAStream();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES
(
pseudo
.
scalar_type
(),
"basis_bw"
,
[
&
]
{
auto
grad_basis_data
=
grad_basis
.
data_ptr
<
scalar_t
>
();
auto
pseudo_data
=
pseudo
.
data_ptr
<
scalar_t
>
();
auto
grad_pseudo_data
=
grad_pseudo
.
data_ptr
<
scalar_t
>
();
AT_DISPATCH_DEGREE_TYPES
(
degree
,
[
&
]
{
//
spline_basis_bw_kernel<scalar_t, DEGREE>
//
<<<BLOCKS(grad_pseudo.numel()), THREADS, 0, stream>>>(
//
grad_basis_data, pseudo_data, kernel_size_data,
//
is_open_spline_data, grad_pseudo_data, E, D, S,
//
grad_pseudo.numel());
spline_basis_bw_kernel
<
scalar_t
,
DEGREE
>
<<<
BLOCKS
(
grad_pseudo
.
numel
()),
THREADS
,
0
,
stream
>>>
(
grad_basis_data
,
pseudo_data
,
kernel_size_data
,
is_open_spline_data
,
grad_pseudo_data
,
E
,
D
,
S
,
grad_pseudo
.
numel
());
});
});
...
...
csrc/cuda/weighting_cuda.cu
View file @
ab4d22e0
...
...
@@ -41,7 +41,7 @@ torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
CHECK_CUDA
(
weight
);
CHECK_CUDA
(
basis
);
CHECK_CUDA
(
weight_index
);
//
cudaSetDevice(x.get_device());
cudaSetDevice
(
x
.
get_device
());
CHECK_INPUT
(
x
.
size
(
1
)
==
weight
.
size
(
1
));
...
...
@@ -54,17 +54,17 @@ torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
auto
weight_index_data
=
weight_index
.
data_ptr
<
int64_t
>
();
//
auto stream = at::cuda::getCurrentCUDAStream();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES
(
x
.
scalar_type
(),
"weighting_fw"
,
[
&
]
{
auto
x_data
=
x
.
data_ptr
<
scalar_t
>
();
auto
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
auto
basis_data
=
basis
.
data_ptr
<
scalar_t
>
();
auto
out_data
=
out
.
data_ptr
<
scalar_t
>
();
//
spline_weighting_fw_kernel<scalar_t>
//
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
//
x_data, weight_data, basis_data, weight_index_data, out_data, E,
//
M_in, M_out, S, out.numel());
spline_weighting_fw_kernel
<
scalar_t
>
<<<
BLOCKS
(
out
.
numel
()),
THREADS
,
0
,
stream
>>>
(
x_data
,
weight_data
,
basis_data
,
weight_index_data
,
out_data
,
E
,
M_in
,
M_out
,
S
,
out
.
numel
());
});
return
out
;
...
...
@@ -106,7 +106,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
CHECK_CUDA
(
weight
);
CHECK_CUDA
(
basis
);
CHECK_CUDA
(
weight_index
);
//
cudaSetDevice(grad_out.get_device());
cudaSetDevice
(
grad_out
.
get_device
());
CHECK_INPUT
(
grad_out
.
size
(
1
)
==
weight
.
size
(
2
));
...
...
@@ -120,17 +120,17 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
auto
weight_index_data
=
weight_index
.
data_ptr
<
int64_t
>
();
//
auto stream = at::cuda::getCurrentCUDAStream();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES
(
grad_out
.
scalar_type
(),
"weighting_bw_x"
,
[
&
]
{
auto
grad_out_data
=
grad_out
.
data_ptr
<
scalar_t
>
();
auto
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
auto
basis_data
=
basis
.
data_ptr
<
scalar_t
>
();
auto
grad_x_data
=
grad_x
.
data_ptr
<
scalar_t
>
();
//
spline_weighting_bw_x_kernel<scalar_t>
//
<<<BLOCKS(grad_x.numel()), THREADS, 0, stream>>>(
//
grad_out_data, weight_data, basis_data, weight_index_data,
//
grad_x_data, E, M_in, M_out, S, grad_x.numel());
spline_weighting_bw_x_kernel
<
scalar_t
>
<<<
BLOCKS
(
grad_x
.
numel
()),
THREADS
,
0
,
stream
>>>
(
grad_out_data
,
weight_data
,
basis_data
,
weight_index_data
,
grad_x_data
,
E
,
M_in
,
M_out
,
S
,
grad_x
.
numel
());
});
return
grad_x
;
...
...
@@ -169,7 +169,7 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
CHECK_CUDA
(
x
);
CHECK_CUDA
(
basis
);
CHECK_CUDA
(
weight_index
);
//
cudaSetDevice(grad_out.get_device());
cudaSetDevice
(
grad_out
.
get_device
());
auto
E
=
grad_out
.
size
(
0
);
auto
M_in
=
x
.
size
(
1
);
...
...
@@ -180,17 +180,17 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
auto
weight_index_data
=
weight_index
.
data_ptr
<
int64_t
>
();
//
auto stream = at::cuda::getCurrentCUDAStream();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES
(
x
.
scalar_type
(),
"weighting_bw_weight"
,
[
&
]
{
auto
grad_out_data
=
grad_out
.
data_ptr
<
scalar_t
>
();
auto
x_data
=
x
.
data_ptr
<
scalar_t
>
();
auto
basis_data
=
basis
.
data_ptr
<
scalar_t
>
();
auto
grad_weight_data
=
grad_weight
.
data_ptr
<
scalar_t
>
();
//
spline_weighting_bw_weight_kernel<scalar_t>
//
<<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
//
grad_out_data, x_data, basis_data, weight_index_data,
//
grad_weight_data, E, M_in, M_out, S, grad_out.numel());
spline_weighting_bw_weight_kernel
<
scalar_t
>
<<<
BLOCKS
(
grad_out
.
numel
()),
THREADS
,
0
,
stream
>>>
(
grad_out_data
,
x_data
,
basis_data
,
weight_index_data
,
grad_weight_data
,
E
,
M_in
,
M_out
,
S
,
grad_out
.
numel
());
});
return
grad_weight
;
...
...
@@ -230,7 +230,7 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
CHECK_CUDA
(
x
);
CHECK_CUDA
(
weight
);
CHECK_CUDA
(
weight_index
);
//
cudaSetDevice(grad_out.get_device());
cudaSetDevice
(
grad_out
.
get_device
());
CHECK_INPUT
(
x
.
size
(
1
)
==
weight
.
size
(
1
));
CHECK_INPUT
(
grad_out
.
size
(
1
)
==
weight
.
size
(
2
));
...
...
@@ -244,17 +244,17 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
auto
weight_index_data
=
weight_index
.
data_ptr
<
int64_t
>
();
//
auto stream = at::cuda::getCurrentCUDAStream();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES
(
x
.
scalar_type
(),
"weighting_bw_basis"
,
[
&
]
{
auto
grad_out_data
=
grad_out
.
data_ptr
<
scalar_t
>
();
auto
x_data
=
x
.
data_ptr
<
scalar_t
>
();
auto
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
auto
grad_basis_data
=
grad_basis
.
data_ptr
<
scalar_t
>
();
//
spline_weighting_bw_basis_kernel<scalar_t>
//
<<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
//
grad_out_data, x_data, weight_data, weight_index_data,
//
grad_basis_data, E, M_in, M_out, S, grad_out.numel());
spline_weighting_bw_basis_kernel
<
scalar_t
>
<<<
BLOCKS
(
grad_out
.
numel
()),
THREADS
,
0
,
stream
>>>
(
grad_out_data
,
x_data
,
weight_data
,
weight_index_data
,
grad_basis_data
,
E
,
M_in
,
M_out
,
S
,
grad_out
.
numel
());
});
return
grad_basis
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment