Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
20a7cd3c
Commit
20a7cd3c
authored
Mar 06, 2019
by
rusty1s
Browse files
multi gpu update
parent
b1072a59
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
2 deletions
+19
-2
cuda/scatter_kernel.cu
cuda/scatter_kernel.cu
+5
-0
setup.py
setup.py
+1
-1
test/test_multi_gpu.py
test/test_multi_gpu.py
+12
-0
torch_scatter/__init__.py
torch_scatter/__init__.py
+1
-1
No files found.
cuda/scatter_kernel.cu
View file @
20a7cd3c
...
...
@@ -43,6 +43,7 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void
scatter_mul_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
int64_t
dim
)
{
cudaSetDevice
(
src
.
get_device
());
AT_DISPATCH_ALL_TYPES
(
src
.
type
(),
"scatter_mul_kernel"
,
[
&
]
{
KERNEL_RUN
(
scatter_mul_kernel
,
index
.
dim
(),
index
.
numel
(),
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
),
...
...
@@ -69,6 +70,7 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void
scatter_div_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
int64_t
dim
)
{
cudaSetDevice
(
src
.
get_device
());
AT_DISPATCH_ALL_TYPES
(
src
.
type
(),
"scatter_div_kernel"
,
[
&
]
{
KERNEL_RUN
(
scatter_div_kernel
,
index
.
dim
(),
index
.
numel
(),
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
),
...
...
@@ -114,6 +116,7 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void
scatter_max_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
at
::
Tensor
arg
,
int64_t
dim
)
{
cudaSetDevice
(
src
.
get_device
());
AT_DISPATCH_ALL_TYPES
(
src
.
type
(),
"scatter_max_kernel"
,
[
&
]
{
auto
src_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
);
auto
index_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
int64_t
,
int64_t
>
(
index
);
...
...
@@ -144,6 +147,7 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void
scatter_min_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
at
::
Tensor
arg
,
int64_t
dim
)
{
cudaSetDevice
(
src
.
get_device
());
AT_DISPATCH_ALL_TYPES
(
src
.
type
(),
"scatter_min_kernel"
,
[
&
]
{
auto
src_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
);
auto
index_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
int64_t
,
int64_t
>
(
index
);
...
...
@@ -179,6 +183,7 @@ index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
void
index_backward_cuda
(
at
::
Tensor
grad
,
at
::
Tensor
index
,
at
::
Tensor
arg
,
at
::
Tensor
out
,
int64_t
dim
)
{
cudaSetDevice
(
grad
.
get_device
());
AT_DISPATCH_ALL_TYPES
(
grad
.
type
(),
"index_backward_kernel"
,
[
&
]
{
KERNEL_RUN
(
index_backward_kernel
,
index
.
dim
(),
index
.
numel
(),
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
grad
),
...
...
setup.py
View file @
20a7cd3c
...
...
@@ -20,7 +20,7 @@ if CUDA_HOME is not None:
[
'cuda/scatter.cpp'
,
'cuda/scatter_kernel.cu'
])
]
__version__
=
'1.1.
1
'
__version__
=
'1.1.
2
'
url
=
'https://github.com/rusty1s/pytorch_scatter'
install_requires
=
[]
...
...
test/test_multi_gpu.py
0 → 100644
View file @
20a7cd3c
import
pytest
import
torch
from
torch_scatter
import
scatter_max
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
'No multiple GPUS'
)
def
test_multi_gpu
():
device
=
torch
.
device
(
'cuda:1'
)
src
=
torch
.
tensor
([
2.0
,
3.0
,
4.0
,
5.0
],
device
=
device
)
index
=
torch
.
tensor
([
0
,
0
,
1
,
1
],
device
=
device
)
assert
scatter_max
(
src
,
index
)[
0
].
tolist
()
==
[
3
,
5
]
torch_scatter/__init__.py
View file @
20a7cd3c
...
...
@@ -7,7 +7,7 @@ from .std import scatter_std
from
.max
import
scatter_max
from
.min
import
scatter_min
__version__
=
'1.1.
1
'
__version__
=
'1.1.
2
'
__all__
=
[
'scatter_add'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment