Commit fb380737 authored by rusty1s's avatar rusty1s
Browse files

parameterize tests

parent 2a571e28
......@@ -2,5 +2,7 @@ __pycache__/
_ext/
build/
dist/
.cache/
.eggs/
*.egg-info/
*.so
[aliases]
test=pytest
[tool:pytest]
addopts = --capture=no
......@@ -3,6 +3,11 @@ from setuptools import setup, find_packages
import build # noqa
install_requires = ['cffi']
setup_requires = ['pytest-runner', 'cffi']
tests_require = ['pytest']
docs_require = ['Sphinx', 'sphinx_rtd_theme']
setup(
name='torch_scatter',
version='0.1',
......@@ -10,8 +15,10 @@ setup(
url='https://github.com/rusty1s/pytorch_scatter',
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
install_requires=['cffi>=1.0.0'],
setup_requires=['cffi>=1.0.0'],
install_requires=install_requires,
setup_requires=setup_requires,
tests_require=tests_require,
docs_require=docs_require,
packages=find_packages(exclude=['build']),
ext_package='',
cffi_modules=[osp.join(osp.dirname(__file__), 'build.py:ffi')],
......
from nose.tools import assert_equal
import pytest
import torch
from torch.autograd import Variable
from torch_scatter import scatter_add_, scatter_add
from .utils import tensor_strs, Tensor
def test_scatter_add():
@pytest.mark.parametrize('str', tensor_strs)
def test_scatter_add(str):
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
input = torch.FloatTensor(input)
input = Tensor(str, input)
index = torch.LongTensor(index)
output = input.new(2, 6).fill_(0)
expected_output = [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
scatter_add_(output, index, input, dim=1)
assert_equal(output.tolist(), expected_output)
assert output.tolist() == expected_output
output = scatter_add(index, input, dim=1)
assert_equal(output.tolist(), expected_output)
assert output.tolist(), expected_output
output = Variable(output).fill_(0)
index = Variable(index)
......@@ -25,7 +27,7 @@ def test_scatter_add():
scatter_add_(output, index, input, dim=1)
grad_output = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]
grad_output = torch.FloatTensor(grad_output)
grad_output = Tensor(str, grad_output)
output.backward(grad_output)
assert_equal(index.data.tolist(), input.grad.data.tolist())
assert index.data.tolist() == input.grad.data.tolist()
import torch
from torch._tensor_docs import tensor_classes
tensor_strs = [t[:-4] for t in tensor_classes]
def Tensor(str, x):
tensor = getattr(torch, str)
return tensor(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment