Commit 5a2fc4bb authored by rusty1s

first cuda run

parent 05ecd667
from os import path as osp
import torch
from torch.utils.ffi import create_extension
abs_path = osp.join(osp.dirname(osp.realpath(__file__)), 'torch_scatter')
abs_path = 'torch_scatter'
headers = ['torch_scatter/src/cpu.h']
sources = ['torch_scatter/src/cpu.c']
@@ -11,6 +13,12 @@ defines = []
extra_objects = []
with_cuda = False
if torch.cuda.is_available():
    headers += ['torch_scatter/src/cuda.h']
    sources += ['torch_scatter/src/cuda.c']
    defines += [('WITH_CUDA', None)]
    with_cuda = True
ffi = create_extension(
    name='torch_scatter._ext.ffi',
    package=True,
...
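For orientation, here is a sketch of how the complete build script might look once the CUDA sources are wired in. torch.utils.ffi.create_extension shipped with early PyTorch releases and has since been removed; the keyword arguments beyond name, package, headers, sources and with_cuda (define_macros, extra_objects, relative_to) and the final ffi.build() call are assumptions based on typical FFI build scripts of that era, not lines taken from this commit.

# Sketch only (assumptions noted above), in the style of the diffed build script.
from os import path as osp

import torch
from torch.utils.ffi import create_extension

headers = ['torch_scatter/src/cpu.h']
sources = ['torch_scatter/src/cpu.c']
defines = []
extra_objects = []
with_cuda = False

if torch.cuda.is_available():
    # Compile the CUDA bindings and define WITH_CUDA for the C sources.
    headers += ['torch_scatter/src/cuda.h']
    sources += ['torch_scatter/src/cuda.c']
    defines += [('WITH_CUDA', None)]
    with_cuda = True

ffi = create_extension(
    name='torch_scatter._ext.ffi',
    package=True,
    headers=headers,
    sources=sources,
    define_macros=defines,
    extra_objects=extra_objects,
    relative_to=__file__,
    with_cuda=with_cuda,
)

if __name__ == '__main__':
    ffi.build()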
@@ -7,7 +7,7 @@ from .utils import tensor_strs, Tensor
@pytest.mark.parametrize('str', tensor_strs)
def test_scatter_mean(str):
def test_scatter_max(str):
    input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
    index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
    input = Tensor(str, input)
@@ -35,3 +35,19 @@ def test_scatter_mean(str):
    output.backward(grad_output)
    assert input.grad.data.tolist() == expected_grad_input


@pytest.mark.parametrize('str', tensor_strs)
def test_scatter_cuda_max(str):
    input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
    index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
    input = Tensor(str, input)
    index = torch.LongTensor(index)
    output = input.new(2, 6).fill_(0)
    expected_output = [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
    expected_arg_output = [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
    output, index, input = output.cuda(), index.cuda(), input.cuda()
    _, arg_output = scatter_max_(output, index, input, dim=1)
    print(output)
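To make the expected values in test_scatter_cuda_max concrete, here is a small pure-Python reference. It is not code from the repository, and the >=-based tie handling against the pre-filled output is an assumption that happens to reproduce exactly these numbers: each input value is scattered into the column given by index along dim=1, the running maximum is kept, and the source column is recorded in the argmax output, with -1 wherever nothing was scattered.

# Hypothetical reference, used only to illustrate the expected test values.
def reference_scatter_max(inp, index, num_targets, fill=0):
    out = [[fill] * num_targets for _ in inp]
    arg = [[-1] * num_targets for _ in inp]
    for row, (values, targets) in enumerate(zip(inp, index)):
        for col, (value, target) in enumerate(zip(values, targets)):
            if value >= out[row][target]:  # assumed tie handling
                out[row][target] = value
                arg[row][target] = col
    return out, arg

out, arg = reference_scatter_max([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
                                 [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]], 6)
assert out == [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
assert arg == [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]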
@@ -33,7 +33,8 @@ def _scatter(name, dim, *data):
        'specified dimension')
    typename = type(data[0]).__name__.replace('Tensor', '')
    func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
    cuda = 'cuda_' if data[0].is_cuda else ''
    func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
    func(dim, *data)
    return (data[0], data[3]) if has_arg_output(name) else data[0]
...
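The change in _scatter routes CUDA tensors to a separately named FFI symbol. A minimal sketch of the resulting name construction follows; the _dispatch_name helper is hypothetical and only illustrates the lookup string, while the real code passes that string to getattr on the compiled ffi module.

import torch

def _dispatch_name(name, tensor):
    # Mirrors the lookup in _scatter: the tensor's type name ('FloatTensor' on the
    # pre-0.4 PyTorch this code targets) loses the 'Tensor' suffix, and a 'cuda_'
    # infix is inserted whenever the tensor lives on the GPU.
    typename = type(tensor).__name__.replace('Tensor', '')
    cuda = 'cuda_' if tensor.is_cuda else ''
    return 'scatter_{}_{}{}'.format(name, cuda, typename)

# On the PyTorch versions this code targets:
#   _dispatch_name('max', torch.FloatTensor(5))       -> 'scatter_max_Float'
#   _dispatch_name('max', torch.cuda.FloatTensor(5))  -> 'scatter_max_cuda_Float'
# i.e. the CPU and CUDA kernels are expected to be exported from the FFI module
# under those two distinct symbol names.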