"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "365e8461ac9045da3210b79522302d7706d943ad"
Commit fb380737 authored by rusty1s's avatar rusty1s
Browse files

parameterize tests

parent 2a571e28
...@@ -2,5 +2,7 @@ __pycache__/ ...@@ -2,5 +2,7 @@ __pycache__/
_ext/ _ext/
build/ build/
dist/ dist/
.cache/
.eggs/
*.egg-info/ *.egg-info/
*.so *.so
[aliases]
test=pytest
[tool:pytest]
addopts = --capture=no
...@@ -3,6 +3,11 @@ from setuptools import setup, find_packages ...@@ -3,6 +3,11 @@ from setuptools import setup, find_packages
import build # noqa import build # noqa
install_requires = ['cffi']
setup_requires = ['pytest-runner', 'cffi']
tests_require = ['pytest']
docs_require = ['Sphinx', 'sphinx_rtd_theme']
setup( setup(
name='torch_scatter', name='torch_scatter',
version='0.1', version='0.1',
...@@ -10,8 +15,10 @@ setup( ...@@ -10,8 +15,10 @@ setup(
url='https://github.com/rusty1s/pytorch_scatter', url='https://github.com/rusty1s/pytorch_scatter',
author='Matthias Fey', author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de', author_email='matthias.fey@tu-dortmund.de',
install_requires=['cffi>=1.0.0'], install_requires=install_requires,
setup_requires=['cffi>=1.0.0'], setup_requires=setup_requires,
tests_require=tests_require,
docs_require=docs_require,
packages=find_packages(exclude=['build']), packages=find_packages(exclude=['build']),
ext_package='', ext_package='',
cffi_modules=[osp.join(osp.dirname(__file__), 'build.py:ffi')], cffi_modules=[osp.join(osp.dirname(__file__), 'build.py:ffi')],
......
from nose.tools import assert_equal import pytest
import torch import torch
from torch.autograd import Variable from torch.autograd import Variable
from torch_scatter import scatter_add_, scatter_add from torch_scatter import scatter_add_, scatter_add
from .utils import tensor_strs, Tensor
def test_scatter_add(): @pytest.mark.parametrize('str', tensor_strs)
def test_scatter_add(str):
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]] input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]] index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
input = torch.FloatTensor(input) input = Tensor(str, input)
index = torch.LongTensor(index) index = torch.LongTensor(index)
output = input.new(2, 6).fill_(0) output = input.new(2, 6).fill_(0)
expected_output = [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]] expected_output = [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
scatter_add_(output, index, input, dim=1) scatter_add_(output, index, input, dim=1)
assert_equal(output.tolist(), expected_output) assert output.tolist() == expected_output
output = scatter_add(index, input, dim=1) output = scatter_add(index, input, dim=1)
assert_equal(output.tolist(), expected_output) assert output.tolist(), expected_output
output = Variable(output).fill_(0) output = Variable(output).fill_(0)
index = Variable(index) index = Variable(index)
...@@ -25,7 +27,7 @@ def test_scatter_add(): ...@@ -25,7 +27,7 @@ def test_scatter_add():
scatter_add_(output, index, input, dim=1) scatter_add_(output, index, input, dim=1)
grad_output = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]] grad_output = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]
grad_output = torch.FloatTensor(grad_output) grad_output = Tensor(str, grad_output)
output.backward(grad_output) output.backward(grad_output)
assert_equal(index.data.tolist(), input.grad.data.tolist()) assert index.data.tolist() == input.grad.data.tolist()
import torch
from torch._tensor_docs import tensor_classes
tensor_strs = [t[:-4] for t in tensor_classes]
def Tensor(str, x):
tensor = getattr(torch, str)
return tensor(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment