Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
04d6ec40
Commit
04d6ec40
authored
Jan 12, 2024
by
limm
Browse files
push v1.2.1 version
parent
1d2126aa
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
149 additions
and
1 deletion
+149
-1
test/test_conv.py
test/test_conv.py
+78
-0
test/test_weighting.py
test/test_weighting.py
+50
-0
test/utils.py
test/utils.py
+11
-0
torch_spline_conv/__init__.py
torch_spline_conv/__init__.py
+10
-1
No files found.
test/test_conv.py
0 → 100644
View file @
04d6ec40
from
itertools
import
product
import
pytest
import
torch
from
torch.autograd
import
gradcheck
from
torch_spline_conv
import
spline_conv
from
.utils
import
dtypes
,
devices
,
tensor
# B-spline degrees exercised by the parametrized backward (gradcheck) test.
degrees = [1, 2, 3]
# Hand-worked fixture for the forward test.  The four edges connect node 0
# with nodes 1-4 (see 'edge_index'); per 'expected', node 0 is the only node
# that aggregates messages, while every node also gets a bias plus a
# root-weight contribution (x_i0 * 12.5 + x_i1 * 13).
tests = [{
    'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
    'edge_index': [[0, 0, 0, 0], [1, 2, 3, 4]],
    # Pseudo coordinates in [0, 1], one pair per edge.
    'pseudo': [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]],
    # 12 = 3 * 4 kernel positions (product of 'kernel_size'), each mapping
    # 2 input channels to 1 output channel.
    'weight': [
        [[0.5], [1]],
        [[1.5], [2]],
        [[2.5], [3]],
        [[3.5], [4]],
        [[4.5], [5]],
        [[5.5], [6]],
        [[6.5], [7]],
        [[7.5], [8]],
        [[8.5], [9]],
        [[9.5], [10]],
        [[10.5], [11]],
        [[11.5], [12]],
    ],
    'kernel_size': [3, 4],
    # One flag per pseudo dimension (1 = open spline, 0 = closed).
    'is_open_spline': [1, 0],
    'root_weight': [[12.5], [13]],
    'bias': [1],
    'expected': [
        # Node 0: bias + root term + mean of its four incoming messages.
        [
            1
            + 12.5 * 9
            + 13 * 10
            + (8.5 + 40.5 + 107.5 + 101.5) / 4
        ],
        # Nodes 1-4 receive no messages: bias + root term only.
        [1 + 12.5 * 1 + 13 * 2],
        [1 + 12.5 * 3 + 13 * 4],
        [1 + 12.5 * 5 + 13 * 6],
        [1 + 12.5 * 7 + 13 * 8],
    ]
}]
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spline_conv_forward(test, dtype, device):
    """Forward pass reproduces the hand-computed 'expected' fixture values."""
    # Floating-point inputs share the parametrized dtype.
    x, pseudo, weight, root_weight, bias = (
        tensor(test[key], dtype, device)
        for key in ('x', 'pseudo', 'weight', 'root_weight', 'bias'))
    # Index-like inputs use their fixed integer dtypes.
    edge_index = tensor(test['edge_index'], torch.long, device)
    kernel_size = tensor(test['kernel_size'], torch.long, device)
    is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)

    out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
                      is_open_spline, 1, True, root_weight, bias)

    assert out.tolist() == test['expected']
@pytest.mark.parametrize('degree,device', product(degrees, devices))
def test_spline_basis_backward(degree, device):
    """Run gradcheck over every differentiable input of spline_conv."""
    def rand_grad(*size):
        # Double precision keeps gradcheck's finite differences accurate.
        out = torch.rand(size, dtype=torch.double, device=device)
        return out.requires_grad_()

    # Same torch.rand call order as before, so the RNG stream is unchanged.
    x = rand_grad(3, 2)
    pseudo = rand_grad(4, 3)
    weight = rand_grad(125, 2, 4)  # 125 = 5 ** 3 kernel positions
    root_weight = rand_grad(2, 4)
    bias = rand_grad(4)

    edge_index = tensor([[0, 1, 1, 2], [1, 0, 2, 1]], torch.long, device)
    kernel_size = tensor([5, 5, 5], torch.long, device)
    is_open_spline = tensor([1, 0, 1], torch.uint8, device)

    data = (x, edge_index, pseudo, weight, kernel_size, is_open_spline,
            degree, True, root_weight, bias)
    assert gradcheck(spline_conv, data, eps=1e-6, atol=1e-4) is True
test/test_weighting.py
0 → 100644
View file @
04d6ec40
from
itertools
import
product
import
pytest
import
torch
from
torch.autograd
import
gradcheck
from
torch_spline_conv
import
spline_basis
,
spline_weighting
from
.utils
import
dtypes
,
devices
,
tensor
# Hand-worked fixture: per 'expected', output row i equals
# sum_j basis[i, j] * (x[i] . weight[weight_index[i, j]][:, 0]).
tests = [{
    'x': [[1, 2], [3, 4]],
    # Four kernel weights, each mapping 2 input channels to 1 output channel.
    'weight': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],
    # B-spline basis products; each row has exactly two non-zero 0.5 entries.
    'basis': [[0.5, 0, 0.5, 0], [0, 0, 0.5, 0.5]],
    'weight_index': [[0, 1, 2, 3], [0, 1, 2, 3]],
    'expected': [
        # Row 0: weights 0 and 2 are active (basis 0.5 each).
        [0.5 * ((1 * (1 + 5)) + (2 * (2 + 6)))],
        # Row 1: weights 2 and 3 are active (basis 0.5 each).
        [0.5 * ((3 * (5 + 7)) + (4 * (6 + 8)))],
    ]
}]
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spline_weighting_forward(test, dtype, device):
    """spline_weighting matches the hand-computed 'expected' fixture."""
    # Floating-point inputs share the parametrized dtype; indices stay long.
    x, weight, basis = (tensor(test[key], dtype, device)
                        for key in ('x', 'weight', 'basis'))
    weight_index = tensor(test['weight_index'], torch.long, device)

    result = spline_weighting(x, weight, basis, weight_index)

    assert result.tolist() == test['expected']
@pytest.mark.parametrize('device', devices)
def test_spline_weighting_backward(device):
    """gradcheck of spline_weighting w.r.t. x, weight and the basis."""
    # Build a degree-1 basis from random pseudo coordinates, then mark it
    # differentiable so its gradient is checked through spline_weighting.
    pseudo = torch.rand((4, 2), dtype=torch.double, device=device)
    basis, weight_index = spline_basis(pseudo,
                                       tensor([5, 5], torch.long, device),
                                       tensor([1, 1], torch.uint8, device),
                                       1)
    basis.requires_grad_()

    # Double precision keeps gradcheck's finite differences accurate.
    x = torch.rand((4, 2), dtype=torch.double, device=device)
    x.requires_grad_()
    weight = torch.rand((25, 2, 4), dtype=torch.double, device=device)
    weight.requires_grad_()

    assert gradcheck(spline_weighting, (x, weight, basis, weight_index),
                     eps=1e-6, atol=1e-4) is True
test/utils.py
0 → 100644
View file @
04d6ec40
import
torch
# Floating-point dtypes every parametrized test is run with.
dtypes = [torch.float, torch.double]

# Always test on CPU; add the current CUDA device when one is available.
devices = [torch.device('cpu')]
if torch.cuda.is_available():
    devices.append(torch.device(f'cuda:{torch.cuda.current_device()}'))
def tensor(x, dtype, device):
    """Convert *x* to a tensor with the given dtype/device, passing None through."""
    if x is None:
        return None
    return torch.tensor(x, dtype=dtype, device=device)
torch_spline_conv/__init__.py
View file @
04d6ec40
...
@@ -5,7 +5,7 @@ import torch
...
@@ -5,7 +5,7 @@ import torch
__version__
=
'1.2.1'
__version__
=
'1.2.1'
suffix
=
'
hip
'
if
torch
.
cuda
.
is_available
()
else
'cpu'
suffix
=
'
cuda
'
if
torch
.
cuda
.
is_available
()
else
'cpu'
for
library
in
[
'_version'
,
'_basis'
,
'_weighting'
]:
for
library
in
[
'_version'
,
'_basis'
,
'_weighting'
]:
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
...
@@ -20,6 +20,15 @@ if torch.cuda.is_available(): # pragma: no cover
...
@@ -20,6 +20,15 @@ if torch.cuda.is_available(): # pragma: no cover
major
,
minor
=
int
(
str
(
cuda_version
)[
0
]),
int
(
str
(
cuda_version
)[
2
])
major
,
minor
=
int
(
str
(
cuda_version
)[
0
]),
int
(
str
(
cuda_version
)[
2
])
else
:
else
:
major
,
minor
=
int
(
str
(
cuda_version
)[
0
:
2
]),
int
(
str
(
cuda_version
)[
3
])
major
,
minor
=
int
(
str
(
cuda_version
)[
0
:
2
]),
int
(
str
(
cuda_version
)[
3
])
t_major
,
t_minor
=
[
int
(
x
)
for
x
in
torch
.
version
.
cuda
.
split
(
'.'
)]
if
t_major
!=
major
:
raise
RuntimeError
(
f
'Detected that PyTorch and torch_spline_conv were compiled with '
f
'different CUDA versions. PyTorch has CUDA version '
f
'
{
t_major
}
.
{
t_minor
}
and torch_spline_conv has CUDA version '
f
'
{
major
}
.
{
minor
}
. Please reinstall the torch_spline_conv that '
f
'matches your PyTorch install.'
)
from
.basis
import
spline_basis
# noqa
from
.basis
import
spline_basis
# noqa
from
.weighting
import
spline_weighting
# noqa
from
.weighting
import
spline_weighting
# noqa
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment