Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
fb380737
Commit
fb380737
authored
Dec 16, 2017
by
rusty1s
Browse files
parameterize tests
parent
2a571e28
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
35 additions
and
10 deletions
+35
-10
.gitignore
.gitignore
+2
-0
setup.cfg
setup.cfg
+5
-0
setup.py
setup.py
+9
-2
test/__init__.py
test/__init__.py
+0
-0
test/test_add.py
test/test_add.py
+10
-8
test/utils.py
test/utils.py
+9
-0
No files found.
.gitignore
View file @
fb380737
...
@@ -2,5 +2,7 @@ __pycache__/
...
@@ -2,5 +2,7 @@ __pycache__/
_ext/
_ext/
build/
build/
dist/
dist/
.cache/
.eggs/
*.egg-info/
*.egg-info/
*.so
*.so
setup.cfg
0 → 100644
View file @
fb380737
[aliases]
test=pytest
[tool:pytest]
addopts = --capture=no
setup.py
View file @
fb380737
...
@@ -3,6 +3,11 @@ from setuptools import setup, find_packages
...
@@ -3,6 +3,11 @@ from setuptools import setup, find_packages
import
build
# noqa
import
build
# noqa
install_requires
=
[
'cffi'
]
setup_requires
=
[
'pytest-runner'
,
'cffi'
]
tests_require
=
[
'pytest'
]
docs_require
=
[
'Sphinx'
,
'sphinx_rtd_theme'
]
setup
(
setup
(
name
=
'torch_scatter'
,
name
=
'torch_scatter'
,
version
=
'0.1'
,
version
=
'0.1'
,
...
@@ -10,8 +15,10 @@ setup(
...
@@ -10,8 +15,10 @@ setup(
url
=
'https://github.com/rusty1s/pytorch_scatter'
,
url
=
'https://github.com/rusty1s/pytorch_scatter'
,
author
=
'Matthias Fey'
,
author
=
'Matthias Fey'
,
author_email
=
'matthias.fey@tu-dortmund.de'
,
author_email
=
'matthias.fey@tu-dortmund.de'
,
install_requires
=
[
'cffi>=1.0.0'
],
install_requires
=
install_requires
,
setup_requires
=
[
'cffi>=1.0.0'
],
setup_requires
=
setup_requires
,
tests_require
=
tests_require
,
docs_require
=
docs_require
,
packages
=
find_packages
(
exclude
=
[
'build'
]),
packages
=
find_packages
(
exclude
=
[
'build'
]),
ext_package
=
''
,
ext_package
=
''
,
cffi_modules
=
[
osp
.
join
(
osp
.
dirname
(
__file__
),
'build.py:ffi'
)],
cffi_modules
=
[
osp
.
join
(
osp
.
dirname
(
__file__
),
'build.py:ffi'
)],
...
...
test/__init__.py
0 → 100644
View file @
fb380737
test/test_add.py
View file @
fb380737
from
nose.tools
import
assert_equal
import
pytest
import
torch
import
torch
from
torch.autograd
import
Variable
from
torch.autograd
import
Variable
from
torch_scatter
import
scatter_add_
,
scatter_add
from
torch_scatter
import
scatter_add_
,
scatter_add
from
.utils
import
tensor_strs
,
Tensor
def
test_scatter_add
():
@
pytest
.
mark
.
parametrize
(
'str'
,
tensor_strs
)
def
test_scatter_add
(
str
):
input
=
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]]
input
=
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]]
index
=
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]]
index
=
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]]
input
=
torch
.
Float
Tensor
(
input
)
input
=
Tensor
(
str
,
input
)
index
=
torch
.
LongTensor
(
index
)
index
=
torch
.
LongTensor
(
index
)
output
=
input
.
new
(
2
,
6
).
fill_
(
0
)
output
=
input
.
new
(
2
,
6
).
fill_
(
0
)
expected_output
=
[[
0
,
0
,
4
,
3
,
3
,
0
],
[
2
,
4
,
4
,
0
,
0
,
0
]]
expected_output
=
[[
0
,
0
,
4
,
3
,
3
,
0
],
[
2
,
4
,
4
,
0
,
0
,
0
]]
scatter_add_
(
output
,
index
,
input
,
dim
=
1
)
scatter_add_
(
output
,
index
,
input
,
dim
=
1
)
assert
_equal
(
output
.
tolist
()
,
expected_output
)
assert
output
.
tolist
()
==
expected_output
output
=
scatter_add
(
index
,
input
,
dim
=
1
)
output
=
scatter_add
(
index
,
input
,
dim
=
1
)
assert
_equal
(
output
.
tolist
(),
expected_output
)
assert
output
.
tolist
(),
expected_output
output
=
Variable
(
output
).
fill_
(
0
)
output
=
Variable
(
output
).
fill_
(
0
)
index
=
Variable
(
index
)
index
=
Variable
(
index
)
...
@@ -25,7 +27,7 @@ def test_scatter_add():
...
@@ -25,7 +27,7 @@ def test_scatter_add():
scatter_add_
(
output
,
index
,
input
,
dim
=
1
)
scatter_add_
(
output
,
index
,
input
,
dim
=
1
)
grad_output
=
[[
0
,
1
,
2
,
3
,
4
,
5
],
[
0
,
1
,
2
,
3
,
4
,
5
]]
grad_output
=
[[
0
,
1
,
2
,
3
,
4
,
5
],
[
0
,
1
,
2
,
3
,
4
,
5
]]
grad_output
=
torch
.
Float
Tensor
(
grad_output
)
grad_output
=
Tensor
(
str
,
grad_output
)
output
.
backward
(
grad_output
)
output
.
backward
(
grad_output
)
assert
_equal
(
index
.
data
.
tolist
()
,
input
.
grad
.
data
.
tolist
()
)
assert
index
.
data
.
tolist
()
==
input
.
grad
.
data
.
tolist
()
test/utils.py
0 → 100644
View file @
fb380737
import
torch
from
torch._tensor_docs
import
tensor_classes
tensor_strs
=
[
t
[:
-
4
]
for
t
in
tensor_classes
]
def
Tensor
(
str
,
x
):
tensor
=
getattr
(
torch
,
str
)
return
tensor
(
x
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment