"vscode:/vscode.git/clone" did not exist on "6a5ba1b719f53110e2d7ab652aa7aa799d97f4a8"
Commit 5a2fc4bb authored by rusty1s's avatar rusty1s
Browse files

first cuda run

parent 05ecd667
from os import path as osp
import torch
from torch.utils.ffi import create_extension
abs_path = osp.join(osp.dirname(osp.realpath(__file__)), 'torch_scatter')
abs_path = 'torch_scatter'
headers = ['torch_scatter/src/cpu.h']
sources = ['torch_scatter/src/cpu.c']
......@@ -11,6 +13,12 @@ defines = []
extra_objects = []
with_cuda = False
if torch.cuda.is_available():
headers += ['torch_scatter/src/cuda.h']
sources += ['torch_scatter/src/cuda.c']
defines += [('WITH_CUDA', None)]
with_cuda = True
ffi = create_extension(
name='torch_scatter._ext.ffi',
package=True,
......
......@@ -7,7 +7,7 @@ from .utils import tensor_strs, Tensor
@pytest.mark.parametrize('str', tensor_strs)
def test_scatter_mean(str):
def test_scatter_max(str):
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
input = Tensor(str, input)
......@@ -35,3 +35,19 @@ def test_scatter_mean(str):
output.backward(grad_output)
assert input.grad.data.tolist() == expected_grad_input
@pytest.mark.parametrize('str', tensor_strs)
def test_scatter_cuda_max(str):
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
input = Tensor(str, input)
index = torch.LongTensor(index)
output = input.new(2, 6).fill_(0)
expected_output = [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
expected_arg_output = [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
output, index, input = output.cuda(), index.cuda(), input.cuda()
_, arg_output = scatter_max_(output, index, input, dim=1)
print(output)
......@@ -33,7 +33,8 @@ def _scatter(name, dim, *data):
'specified dimension')
typename = type(data[0]).__name__.replace('Tensor', '')
func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
cuda = 'cuda_' if data[0].is_cuda else ''
func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
func(dim, *data)
return (data[0], data[3]) if has_arg_output(name) else data[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment