Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
8e6635b3
Unverified
Commit
8e6635b3
authored
Nov 29, 2022
by
Matthias Fey
Committed by
GitHub
Nov 29, 2022
Browse files
fix test (#340)
parent
003abd58
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
18 additions
and
20 deletions
+18
-20
test/__init__.py
test/__init__.py
+0
-0
test/test_broadcasting.py
test/test_broadcasting.py
+1
-2
test/test_gather.py
test/test_gather.py
+2
-3
test/test_multi_gpu.py
test/test_multi_gpu.py
+1
-2
test/test_scatter.py
test/test_scatter.py
+1
-2
test/test_segment.py
test/test_segment.py
+2
-3
test/test_zero_tensors.py
test/test_zero_tensors.py
+3
-4
torch_scatter/testing.py
torch_scatter/testing.py
+8
-4
No files found.
test/__init__.py
deleted
100644 → 0
View file @
003abd58
test/test_broadcasting.py
View file @
8e6635b3
...
@@ -3,8 +3,7 @@ from itertools import product
...
@@ -3,8 +3,7 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch_scatter
import
scatter
from
torch_scatter
import
scatter
from
torch_scatter.testing
import
devices
,
reductions
from
.utils
import
reductions
,
devices
@
pytest
.
mark
.
parametrize
(
'reduce,device'
,
product
(
reductions
,
devices
))
@
pytest
.
mark
.
parametrize
(
'reduce,device'
,
product
(
reductions
,
devices
))
...
...
test/test_gather.py
View file @
8e6635b3
...
@@ -3,9 +3,8 @@ from itertools import product
...
@@ -3,9 +3,8 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch.autograd
import
gradcheck
from
torch.autograd
import
gradcheck
from
torch_scatter
import
gather_csr
,
gather_coo
from
torch_scatter
import
gather_coo
,
gather_csr
from
torch_scatter.testing
import
devices
,
dtypes
,
tensor
from
.utils
import
tensor
,
dtypes
,
devices
tests
=
[
tests
=
[
{
{
...
...
test/test_multi_gpu.py
View file @
8e6635b3
...
@@ -3,8 +3,7 @@ from itertools import product
...
@@ -3,8 +3,7 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
import
torch_scatter
import
torch_scatter
from
torch_scatter.testing
import
dtypes
,
reductions
,
tensor
from
.utils
import
reductions
,
tensor
,
dtypes
tests
=
[
tests
=
[
{
{
...
...
test/test_scatter.py
View file @
8e6635b3
...
@@ -4,8 +4,7 @@ import pytest
...
@@ -4,8 +4,7 @@ import pytest
import
torch
import
torch
import
torch_scatter
import
torch_scatter
from
torch.autograd
import
gradcheck
from
torch.autograd
import
gradcheck
from
torch_scatter.testing
import
devices
,
dtypes
,
reductions
,
tensor
from
.utils
import
devices
,
dtypes
,
reductions
,
tensor
reductions
=
reductions
+
[
'mul'
]
reductions
=
reductions
+
[
'mul'
]
...
...
test/test_segment.py
View file @
8e6635b3
...
@@ -2,10 +2,9 @@ from itertools import product
...
@@ -2,10 +2,9 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch.autograd
import
gradcheck
import
torch_scatter
import
torch_scatter
from
torch.autograd
import
gradcheck
from
.utils
import
reductions
,
tensor
,
dtypes
,
devices
from
torch_scatter.testing
import
devices
,
dtypes
,
reductions
,
tensor
tests
=
[
tests
=
[
{
{
...
...
test/test_zero_tensors.py
View file @
8e6635b3
...
@@ -2,10 +2,9 @@ from itertools import product
...
@@ -2,10 +2,9 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch_scatter
import
scatter
,
segment_coo
,
gather_coo
from
torch_scatter
import
(
gather_coo
,
gather_csr
,
scatter
,
segment_coo
,
from
torch_scatter
import
segment_csr
,
gather_csr
segment_csr
)
from
torch_scatter.testing
import
devices
,
grad_dtypes
,
reductions
,
tensor
from
.utils
import
reductions
,
tensor
,
grad_dtypes
,
devices
@
pytest
.
mark
.
parametrize
(
'reduce,dtype,device'
,
@
pytest
.
mark
.
parametrize
(
'reduce,dtype,device'
,
...
...
t
est/utils
.py
→
t
orch_scatter/testing
.py
View file @
8e6635b3
from
typing
import
Any
import
torch
import
torch
reductions
=
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
reductions
=
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
dtypes
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
,
torch
.
double
,
dtypes
=
[
torch
.
int
,
torch
.
long
]
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
,
torch
.
double
,
torch
.
int
,
torch
.
long
]
grad_dtypes
=
[
torch
.
float
,
torch
.
double
]
grad_dtypes
=
[
torch
.
float
,
torch
.
double
]
devices
=
[
torch
.
device
(
'cpu'
)]
devices
=
[
torch
.
device
(
'cpu'
)]
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
devices
+=
[
torch
.
device
(
f
'cuda:
{
torch
.
cuda
.
current_device
()
}
'
)]
devices
+=
[
torch
.
device
(
'cuda:
0
'
)]
def
tensor
(
x
,
dtype
,
device
):
def
tensor
(
x
:
Any
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
return
None
if
x
is
None
else
torch
.
tensor
(
x
,
device
=
device
).
to
(
dtype
)
return
None
if
x
is
None
else
torch
.
tensor
(
x
,
device
=
device
).
to
(
dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment