Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
731f8cd2
Unverified
Commit
731f8cd2
authored
Nov 29, 2022
by
Matthias Fey
Committed by
GitHub
Nov 29, 2022
Browse files
Merge pull request #36 from rusty1s/fix_test
Remove `test/__init__.py`
parents
99f8b989
67b76d10
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
27 additions
and
17 deletions
+27
-17
test/__init__.py
test/__init__.py
+0
-0
test/test_basis.py
test/test_basis.py
+9
-4
test/test_conv.py
test/test_conv.py
+8
-8
test/test_weighting.py
test/test_weighting.py
+6
-3
torch_spline_conv/testing.py
torch_spline_conv/testing.py
+4
-2
No files found.
test/__init__.py
deleted
100644 → 0
View file @
99f8b989
test/test_basis.py
View file @
731f8cd2
...
@@ -3,8 +3,7 @@ from itertools import product
...
@@ -3,8 +3,7 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch_spline_conv
import
spline_basis
from
torch_spline_conv
import
spline_basis
from
torch_spline_conv.testing
import
devices
,
dtypes
,
tensor
from
.utils
import
dtypes
,
devices
,
tensor
tests
=
[{
tests
=
[{
'pseudo'
:
[[
0
],
[
0.0625
],
[
0.25
],
[
0.75
],
[
0.9375
],
[
1
]],
'pseudo'
:
[[
0
],
[
0.0625
],
[
0.25
],
[
0.75
],
[
0.9375
],
[
1
]],
...
@@ -29,12 +28,18 @@ tests = [{
...
@@ -29,12 +28,18 @@ tests = [{
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_spline_basis_forward
(
test
,
dtype
,
device
):
def
test_spline_basis_forward
(
test
,
dtype
,
device
):
if
dtype
==
torch
.
bfloat16
and
device
==
torch
.
device
(
'cuda:0'
):
return
pseudo
=
tensor
(
test
[
'pseudo'
],
dtype
,
device
)
pseudo
=
tensor
(
test
[
'pseudo'
],
dtype
,
device
)
kernel_size
=
tensor
(
test
[
'kernel_size'
],
torch
.
long
,
device
)
kernel_size
=
tensor
(
test
[
'kernel_size'
],
torch
.
long
,
device
)
is_open_spline
=
tensor
(
test
[
'is_open_spline'
],
torch
.
uint8
,
device
)
is_open_spline
=
tensor
(
test
[
'is_open_spline'
],
torch
.
uint8
,
device
)
basis
=
tensor
(
test
[
'basis'
],
dtype
,
device
)
weight_index
=
tensor
(
test
[
'weight_index'
],
dtype
,
device
)
degree
=
1
degree
=
1
basis
,
weight_index
=
spline_basis
(
pseudo
,
kernel_size
,
is_open_spline
,
basis
,
weight_index
=
spline_basis
(
pseudo
,
kernel_size
,
is_open_spline
,
degree
)
degree
)
assert
basis
.
tolist
()
==
test
[
'
basis
'
]
assert
torch
.
allclose
(
basis
,
basis
)
assert
weight_index
.
tolist
()
==
test
[
'
weight_index
'
]
assert
torch
.
allclose
(
weight_index
,
weight_index
)
test/test_conv.py
View file @
731f8cd2
...
@@ -4,8 +4,7 @@ import pytest
...
@@ -4,8 +4,7 @@ import pytest
import
torch
import
torch
from
torch.autograd
import
gradcheck
from
torch.autograd
import
gradcheck
from
torch_spline_conv
import
spline_conv
from
torch_spline_conv
import
spline_conv
from
torch_spline_conv.testing
import
devices
,
dtypes
,
tensor
from
.utils
import
dtypes
,
devices
,
tensor
degrees
=
[
1
,
2
,
3
]
degrees
=
[
1
,
2
,
3
]
...
@@ -43,6 +42,9 @@ tests = [{
...
@@ -43,6 +42,9 @@ tests = [{
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_spline_conv_forward
(
test
,
dtype
,
device
):
def
test_spline_conv_forward
(
test
,
dtype
,
device
):
if
dtype
==
torch
.
bfloat16
and
device
==
torch
.
device
(
'cuda:0'
):
return
x
=
tensor
(
test
[
'x'
],
dtype
,
device
)
x
=
tensor
(
test
[
'x'
],
dtype
,
device
)
edge_index
=
tensor
(
test
[
'edge_index'
],
torch
.
long
,
device
)
edge_index
=
tensor
(
test
[
'edge_index'
],
torch
.
long
,
device
)
pseudo
=
tensor
(
test
[
'pseudo'
],
dtype
,
device
)
pseudo
=
tensor
(
test
[
'pseudo'
],
dtype
,
device
)
...
@@ -51,15 +53,13 @@ def test_spline_conv_forward(test, dtype, device):
...
@@ -51,15 +53,13 @@ def test_spline_conv_forward(test, dtype, device):
is_open_spline
=
tensor
(
test
[
'is_open_spline'
],
torch
.
uint8
,
device
)
is_open_spline
=
tensor
(
test
[
'is_open_spline'
],
torch
.
uint8
,
device
)
root_weight
=
tensor
(
test
[
'root_weight'
],
dtype
,
device
)
root_weight
=
tensor
(
test
[
'root_weight'
],
dtype
,
device
)
bias
=
tensor
(
test
[
'bias'
],
dtype
,
device
)
bias
=
tensor
(
test
[
'bias'
],
dtype
,
device
)
expected
=
tensor
(
test
[
'expected'
],
dtype
,
device
)
out
=
spline_conv
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
out
=
spline_conv
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
1
,
True
,
root_weight
,
bias
)
is_open_spline
,
1
,
True
,
root_weight
,
bias
)
if
dtype
==
torch
.
bfloat16
:
target
=
torch
.
tensor
(
test
[
'expected'
])
error
=
1e-2
if
dtype
==
torch
.
bfloat16
else
1e-7
assert
torch
.
allclose
(
out
.
to
(
torch
.
float
),
target
,
assert
torch
.
allclose
(
out
,
expected
,
rtol
=
error
,
atol
=
error
)
rtol
=
1e-2
,
atol
=
1e-2
)
else
:
assert
out
.
tolist
()
==
test
[
'expected'
]
@
pytest
.
mark
.
parametrize
(
'degree,device'
,
product
(
degrees
,
devices
))
@
pytest
.
mark
.
parametrize
(
'degree,device'
,
product
(
degrees
,
devices
))
...
...
test/test_weighting.py
View file @
731f8cd2
...
@@ -4,8 +4,7 @@ import pytest
...
@@ -4,8 +4,7 @@ import pytest
import
torch
import
torch
from
torch.autograd
import
gradcheck
from
torch.autograd
import
gradcheck
from
torch_spline_conv
import
spline_basis
,
spline_weighting
from
torch_spline_conv
import
spline_basis
,
spline_weighting
from
torch_spline_conv.testing
import
devices
,
dtypes
,
tensor
from
.utils
import
dtypes
,
devices
,
tensor
tests
=
[{
tests
=
[{
'x'
:
[[
1
,
2
],
[
3
,
4
]],
'x'
:
[[
1
,
2
],
[
3
,
4
]],
...
@@ -21,13 +20,17 @@ tests = [{
...
@@ -21,13 +20,17 @@ tests = [{
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_spline_weighting_forward
(
test
,
dtype
,
device
):
def
test_spline_weighting_forward
(
test
,
dtype
,
device
):
if
dtype
==
torch
.
bfloat16
and
device
==
torch
.
device
(
'cuda:0'
):
return
x
=
tensor
(
test
[
'x'
],
dtype
,
device
)
x
=
tensor
(
test
[
'x'
],
dtype
,
device
)
weight
=
tensor
(
test
[
'weight'
],
dtype
,
device
)
weight
=
tensor
(
test
[
'weight'
],
dtype
,
device
)
basis
=
tensor
(
test
[
'basis'
],
dtype
,
device
)
basis
=
tensor
(
test
[
'basis'
],
dtype
,
device
)
weight_index
=
tensor
(
test
[
'weight_index'
],
torch
.
long
,
device
)
weight_index
=
tensor
(
test
[
'weight_index'
],
torch
.
long
,
device
)
expected
=
tensor
(
test
[
'expected'
],
dtype
,
device
)
out
=
spline_weighting
(
x
,
weight
,
basis
,
weight_index
)
out
=
spline_weighting
(
x
,
weight
,
basis
,
weight_index
)
assert
out
.
tolist
()
==
test
[
'
expected
'
]
assert
torch
.
allclose
(
out
,
expected
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
...
...
t
est/utils
.py
→
t
orch_spline_conv/testing
.py
View file @
731f8cd2
from
typing
import
Any
import
torch
import
torch
dtypes
=
[
torch
.
float
,
torch
.
double
,
torch
.
bfloat16
]
dtypes
=
[
torch
.
float
,
torch
.
double
,
torch
.
bfloat16
]
devices
=
[
torch
.
device
(
'cpu'
)]
devices
=
[
torch
.
device
(
'cpu'
)]
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
devices
+=
[
torch
.
device
(
f
'cuda:
{
torch
.
cuda
.
current_device
()
}
'
)]
devices
+=
[
torch
.
device
(
'cuda:
0
'
)]
def
tensor
(
x
,
dtype
,
device
):
def
tensor
(
x
:
Any
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
return
None
if
x
is
None
else
torch
.
tensor
(
x
,
dtype
=
dtype
,
device
=
device
)
return
None
if
x
is
None
else
torch
.
tensor
(
x
,
dtype
=
dtype
,
device
=
device
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment