Commit 20a7cd3c authored by rusty1s's avatar rusty1s
Browse files

multi gpu update

parent b1072a59
......@@ -43,6 +43,7 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul_kernel", [&] {
KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
......@@ -69,6 +70,7 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_div_kernel", [&] {
KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
......@@ -114,6 +116,7 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_max_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
......@@ -144,6 +147,7 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
cudaSetDevice(src.get_device());
AT_DISPATCH_ALL_TYPES(src.type(), "scatter_min_kernel", [&] {
auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
......@@ -179,6 +183,7 @@ index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
at::Tensor out, int64_t dim) {
cudaSetDevice(grad.get_device());
AT_DISPATCH_ALL_TYPES(grad.type(), "index_backward_kernel", [&] {
KERNEL_RUN(index_backward_kernel, index.dim(), index.numel(),
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad),
......
......@@ -20,7 +20,7 @@ if CUDA_HOME is not None:
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
]
__version__ = '1.1.1'
__version__ = '1.1.2'
url = 'https://github.com/rusty1s/pytorch_scatter'
install_requires = []
......
import pytest
import torch
from torch_scatter import scatter_max
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUS')
def test_multi_gpu():
    # Run scatter_max on a non-default GPU to verify that the CUDA kernels
    # honor the tensor's device (the multi-GPU fix in this commit).
    second_gpu = torch.device('cuda:1')
    index = torch.tensor([0, 0, 1, 1], device=second_gpu)
    src = torch.tensor([2.0, 3.0, 4.0, 5.0], device=second_gpu)
    out, _ = scatter_max(src, index)
    # Segment 0 holds {2, 3} -> max 3; segment 1 holds {4, 5} -> max 5.
    assert out.tolist() == [3, 5]
......@@ -7,7 +7,7 @@ from .std import scatter_std
from .max import scatter_max
from .min import scatter_min
__version__ = '1.1.1'
__version__ = '1.1.2'
__all__ = [
'scatter_add',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment