Commit 5a2fc4bb authored by rusty1s's avatar rusty1s
Browse files

first cuda run

parent 05ecd667
from os import path as osp from os import path as osp
import torch
from torch.utils.ffi import create_extension from torch.utils.ffi import create_extension
abs_path = osp.join(osp.dirname(osp.realpath(__file__)), 'torch_scatter') abs_path = osp.join(osp.dirname(osp.realpath(__file__)), 'torch_scatter')
abs_path = 'torch_scatter'
headers = ['torch_scatter/src/cpu.h'] headers = ['torch_scatter/src/cpu.h']
sources = ['torch_scatter/src/cpu.c'] sources = ['torch_scatter/src/cpu.c']
...@@ -11,6 +13,12 @@ defines = [] ...@@ -11,6 +13,12 @@ defines = []
extra_objects = [] extra_objects = []
with_cuda = False with_cuda = False
if torch.cuda.is_available():
headers += ['torch_scatter/src/cuda.h']
sources += ['torch_scatter/src/cuda.c']
defines += [('WITH_CUDA', None)]
with_cuda = True
ffi = create_extension( ffi = create_extension(
name='torch_scatter._ext.ffi', name='torch_scatter._ext.ffi',
package=True, package=True,
......
...@@ -7,7 +7,7 @@ from .utils import tensor_strs, Tensor ...@@ -7,7 +7,7 @@ from .utils import tensor_strs, Tensor
@pytest.mark.parametrize('str', tensor_strs) @pytest.mark.parametrize('str', tensor_strs)
def test_scatter_mean(str): def test_scatter_max(str):
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]] input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]] index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
input = Tensor(str, input) input = Tensor(str, input)
...@@ -35,3 +35,19 @@ def test_scatter_mean(str): ...@@ -35,3 +35,19 @@ def test_scatter_mean(str):
output.backward(grad_output) output.backward(grad_output)
assert input.grad.data.tolist() == expected_grad_input assert input.grad.data.tolist() == expected_grad_input
@pytest.mark.parametrize('str', tensor_strs)
def test_scatter_cuda_max(str):
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
input = Tensor(str, input)
index = torch.LongTensor(index)
output = input.new(2, 6).fill_(0)
expected_output = [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
expected_arg_output = [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
output, index, input = output.cuda(), index.cuda(), input.cuda()
_, arg_output = scatter_max_(output, index, input, dim=1)
print(output)
...@@ -33,7 +33,8 @@ def _scatter(name, dim, *data): ...@@ -33,7 +33,8 @@ def _scatter(name, dim, *data):
'specified dimension') 'specified dimension')
typename = type(data[0]).__name__.replace('Tensor', '') typename = type(data[0]).__name__.replace('Tensor', '')
func = getattr(ffi, 'scatter_{}_{}'.format(name, typename)) cuda = 'cuda_' if data[0].is_cuda else ''
func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
func(dim, *data) func(dim, *data)
return (data[0], data[3]) if has_arg_output(name) else data[0] return (data[0], data[3]) if has_arg_output(name) else data[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment