Commit c1cd9753 authored by rusty1s's avatar rusty1s
Browse files

multi gpu update

parent e6a8f8c4
......@@ -30,6 +30,7 @@ static void init_cusparse() {
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, int m, int k, int n) {
cudaSetDevice(indexA.get_device());
init_cusparse();
indexA = indexA.contiguous();
......
......@@ -16,6 +16,7 @@ __global__ void unique_cuda_kernel(scalar_t *__restrict__ src, uint8_t *mask,
}
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
cudaSetDevice(src.get_device());
at::Tensor perm;
std::tie(src, perm) = src.sort();
......
......@@ -2,7 +2,7 @@ import platform
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
__version__ = '0.2.3'
__version__ = '0.2.4'
url = 'https://github.com/rusty1s/pytorch_sparse'
install_requires = ['scipy']
......
......@@ -4,7 +4,7 @@ from .eye import eye
from .spmm import spmm
from .spspmm import spspmm
__version__ = '0.2.3'
__version__ = '0.2.4'
__all__ = [
'__version__',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment