Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
1038a59e
Commit
1038a59e
authored
May 24, 2018
by
rusty1s
Browse files
degree normalize after root_weight
parent
1b96f53b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
8 deletions
+11
-8
test/test_conv.py
test/test_conv.py
+7
-6
torch_spline_conv/conv.py
torch_spline_conv/conv.py
+4
-2
No files found.
test/test_conv.py
View file @
1038a59e
...
@@ -31,11 +31,11 @@ tests = [{
...
@@ -31,11 +31,11 @@ tests = [{
'root_weight'
:
[[
12.5
],
[
13
]],
'root_weight'
:
[[
12.5
],
[
13
]],
'bias'
:
[
1
],
'bias'
:
[
1
],
'expected'
:
[
'expected'
:
[
[
1
+
12.5
*
9
+
13
*
10
+
(
8.5
+
40.5
+
107.5
+
101.5
)
/
4
]
,
1
+
(
12.5
*
9
+
13
*
10
+
8.5
+
40.5
+
107.5
+
101.5
)
/
5
,
[
1
+
12.5
*
1
+
13
*
2
]
,
1
+
12.5
*
1
+
13
*
2
,
[
1
+
12.5
*
3
+
13
*
4
]
,
1
+
12.5
*
3
+
13
*
4
,
[
1
+
12.5
*
5
+
13
*
6
]
,
1
+
12.5
*
5
+
13
*
6
,
[
1
+
12.5
*
7
+
13
*
8
]
,
1
+
12.5
*
7
+
13
*
8
,
]
]
}]
}]
...
@@ -53,7 +53,8 @@ def test_spline_conv_forward(test, dtype, device):
...
@@ -53,7 +53,8 @@ def test_spline_conv_forward(test, dtype, device):
out
=
SplineConv
.
apply
(
src
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
out
=
SplineConv
.
apply
(
src
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
1
,
root_weight
,
bias
)
is_open_spline
,
1
,
root_weight
,
bias
)
assert
out
.
tolist
()
==
test
[
'expected'
]
assert
list
(
out
.
size
())
==
[
5
,
1
]
assert
pytest
.
approx
(
out
.
view
(
-
1
).
tolist
())
==
test
[
'expected'
]
@
pytest
.
mark
.
parametrize
(
'degree,device'
,
product
(
degrees
.
keys
(),
devices
))
@
pytest
.
mark
.
parametrize
(
'degree,device'
,
product
(
degrees
.
keys
(),
devices
))
...
...
torch_spline_conv/conv.py
View file @
1038a59e
...
@@ -62,13 +62,15 @@ class SplineConv(object):
...
@@ -62,13 +62,15 @@ class SplineConv(object):
row_expand
=
row
.
unsqueeze
(
-
1
).
expand_as
(
out
)
row_expand
=
row
.
unsqueeze
(
-
1
).
expand_as
(
out
)
out
=
src
.
new_zeros
((
n
,
m_out
)).
scatter_add_
(
0
,
row_expand
,
out
)
out
=
src
.
new_zeros
((
n
,
m_out
)).
scatter_add_
(
0
,
row_expand
,
out
)
# Normalize out by node degree.
deg
=
node_degree
(
row
,
n
,
out
.
dtype
,
out
.
device
)
deg
=
node_degree
(
row
,
n
,
out
.
dtype
,
out
.
device
)
out
/=
deg
.
unsqueeze
(
-
1
).
clamp
(
min
=
1
)
# Weight root node separately (if wished).
# Weight root node separately (if wished).
if
root_weight
is
not
None
:
if
root_weight
is
not
None
:
out
+=
torch
.
mm
(
src
,
root_weight
)
out
+=
torch
.
mm
(
src
,
root_weight
)
deg
+=
1
# Normalize out by node degree.
out
/=
deg
.
unsqueeze
(
-
1
).
clamp
(
min
=
1
)
# Add bias (if wished).
# Add bias (if wished).
if
bias
is
not
None
:
if
bias
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment