OpenDAS / torch-spline-conv · Commits

Commit ac26fc19 (parent d3169766), authored Feb 27, 2020 by rusty1s:

    prepare tracing
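For context on the commit message: custom C++ operators registered under torch.ops are visible to both tracing and scripting, while calls into Python-level torch.autograd.Function.apply are not, which is presumably what "prepare tracing" refers to. The commit therefore rewrites both entry points as annotated, module-level @torch.jit.script functions. A minimal sketch of that pattern, with made-up names unrelated to this repository:

```python
from typing import Optional

import torch


@torch.jit.script
def scale_and_shift(x: torch.Tensor, weight: torch.Tensor,
                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # TorchScript needs the explicit Optional[...] annotation; the
    # `is not None` check narrows the type inside the branch, exactly
    # as spline_conv does for root_weight and bias below.
    out = x * weight
    if bias is not None:
        out = out + bias
    return out


# A scripted function serializes and runs without its Python source, e.g.:
# torch.jit.save(scale_and_shift, 'op.pt')
```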
Showing 4 changed files with 63 additions and 70 deletions (+63 −70):
    torch_spline_conv/conv.py             +34 −31
    torch_spline_conv/utils/__init__.py    +0  −0
    torch_spline_conv/utils/degree.py      +0  −7
    torch_spline_conv/weighting.py        +29 −32
torch_spline_conv/conv.py:
+from typing import Optional
+
 import torch
-from .basis import SplineBasis
-from .weighting import SplineWeighting
-from .utils.degree import degree as node_degree
+
+from .basis import spline_basis
+from .weighting import spline_weighting
 
 
-class SplineConv(object):
+@torch.jit.script
+def spline_conv(x: torch.Tensor, edge_index: torch.Tensor,
+                pseudo: torch.Tensor, weight: torch.Tensor,
+                kernel_size: torch.Tensor, is_open_spline: torch.Tensor,
+                degree: int = 1, norm: bool = True,
+                root_weight: Optional[torch.Tensor] = None,
+                bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     r"""Applies the spline-based convolution operator :math:`(f \star g)(i) =
     \frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
     f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.
     ...
@@ -38,37 +44,34 @@ class SplineConv(object):
     :rtype: :class:`Tensor`
     """
-    @staticmethod
-    def apply(x, edge_index, pseudo, weight, kernel_size, is_open_spline,
-              degree=1, norm=True, root_weight=None, bias=None):
-        x = x.unsqueeze(-1) if x.dim() == 1 else x
-        pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
+    x = x.unsqueeze(-1) if x.dim() == 1 else x
+    pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
 
-        row, col = edge_index
-        n, m_out = x.size(0), weight.size(2)
+    row, col = edge_index
+    N, E, M_out = x.size(0), row.size(0), weight.size(2)
 
-        # Weight each node.
-        basis, weight_index = SplineBasis.apply(pseudo, kernel_size,
-                                                is_open_spline, degree)
-        weight_index = weight_index.detach()
-        out = SplineWeighting.apply(x[col], weight, basis, weight_index)
+    # Weight each node.
+    basis, weight_index = spline_basis(pseudo, kernel_size, is_open_spline,
+                                       degree)
+    out = spline_weighting(x[col], weight, basis, weight_index)
 
-        # Convert e x m_out to n x m_out features.
-        row_expand = row.unsqueeze(-1).expand_as(out)
-        out = x.new_zeros((n, m_out)).scatter_add_(0, row_expand, out)
+    # Convert E x M_out to N x M_out features.
+    row_expanded = row.unsqueeze(-1).expand_as(out)
+    out = x.new_zeros((N, M_out)).scatter_add_(0, row_expanded, out)
 
-        # Normalize out by node degree (if wished).
-        if norm:
-            deg = node_degree(row, n, out.dtype, out.device)
-            out = out / deg.unsqueeze(-1).clamp(min=1)
+    # Normalize out by node degree (if wished).
+    if norm:
+        deg = out.new_zeros(N).scatter_add_(0, row, out.new_ones(E))
+        out = out / deg.unsqueeze(-1).clamp_(min=1)
 
-        # Weight root node separately (if wished).
-        if root_weight is not None:
-            out = out + torch.mm(x, root_weight)
+    # Weight root node separately (if wished).
+    if root_weight is not None:
+        out = out + torch.matmul(x, root_weight)
 
-        # Add bias (if wished).
-        if bias is not None:
-            out = out + bias
+    # Add bias (if wished).
+    if bias is not None:
+        out = out + bias
 
-        return out
+    return out
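A usage sketch for the rewritten entry point. The shapes follow the package README's example (4 nodes, 6 edges, 2-dimensional pseudo-coordinates, a 5 × 5 open B-spline kernel); it assumes the compiled CPU/CUDA extensions are importable and that the package re-exports spline_conv at the top level:

```python
import torch
from torch_spline_conv import spline_conv  # assumed top-level re-export

x = torch.rand(4, 2)                            # N = 4 nodes, M_in = 2 features
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])  # E = 6 directed edges
pseudo = torch.rand(6, 2)                       # edge attributes u(i, j) in [0, 1]^2
weight = torch.rand(25, 2, 4)                   # K x M_in x M_out = (5*5) x 2 x 4
kernel_size = torch.tensor([5, 5])              # 5 basis functions per dimension
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8)  # open B-splines

out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
                  is_open_spline, degree=1, norm=True)
print(out.size())  # torch.Size([4, 4]) -- N x M_out
```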
torch_spline_conv/utils/__init__.py  deleted (100644 → 0, empty file):
torch_spline_conv/utils/degree.py  deleted (100644 → 0):
-import torch
-
-
-def degree(index, num_nodes=None, dtype=None, device=None):
-    num_nodes = index.max().item() + 1 if num_nodes is None else num_nodes
-    out = torch.zeros((num_nodes), dtype=dtype, device=device)
-    return out.scatter_add_(0, index, out.new_ones((index.size(0))))
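This helper was only used for the degree normalization in conv.py, and the new code inlines the same computation there (out.new_zeros(N).scatter_add_(0, row, out.new_ones(E))), so TorchScript has one less Python module to compile. A standalone sketch of that per-node edge count, on the same toy graph as above:

```python
import torch

row = torch.tensor([0, 1, 1, 2, 2, 3])  # target node of each edge
N = 4                                    # number of nodes
# Accumulate a 1 into deg[row[e]] for every edge e.
deg = torch.zeros(N).scatter_add_(0, row, torch.ones(row.size(0)))
print(deg)  # tensor([1., 2., 2., 1.])
```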
torch_spline_conv/weighting.py:
 import torch
 
 import torch_spline_conv.weighting_cpu
 if torch.cuda.is_available():
     import torch_spline_conv.weighting_cuda
 
 
+@torch.jit.script
+def spline_weighting(x: torch.Tensor, weight: torch.Tensor,
+                     basis: torch.Tensor,
+                     weight_index: torch.Tensor) -> torch.Tensor:
+    return torch.ops.spline_conv.spline_weighting(x, weight, basis,
+                                                  weight_index)
+
+
 def get_func(name, tensor):
     if tensor.is_cuda:
         return getattr(torch_spline_conv.weighting_cuda, name)
     else:
         return getattr(torch_spline_conv.weighting_cpu, name)
 
 
-class SplineWeighting(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, weight, basis, weight_index):
-        ctx.weight_index = weight_index
-        ctx.save_for_backward(x, weight, basis)
-        op = get_func('weighting_fw', x)
-        out = op(x, weight, basis, weight_index)
-        return out
+# class SplineWeighting(torch.autograd.Function):
+#     @staticmethod
+#     def forward(ctx, x, weight, basis, weight_index):
+#         ctx.weight_index = weight_index
+#         ctx.save_for_backward(x, weight, basis)
+#         op = get_func('weighting_fw', x)
+#         out = op(x, weight, basis, weight_index)
+#         return out
 
-    @staticmethod
-    def backward(ctx, grad_out):
-        x, weight, basis = ctx.saved_tensors
-        grad_x = grad_weight = grad_basis = None
+#     @staticmethod
+#     def backward(ctx, grad_out):
+#         x, weight, basis = ctx.saved_tensors
+#         grad_x = grad_weight = grad_basis = None
 
-        if ctx.needs_input_grad[0]:
-            op = get_func('weighting_bw_x', x)
-            grad_x = op(grad_out, weight, basis, ctx.weight_index)
+#         if ctx.needs_input_grad[0]:
+#             op = get_func('weighting_bw_x', x)
+#             grad_x = op(grad_out, weight, basis, ctx.weight_index)
 
-        if ctx.needs_input_grad[1]:
-            op = get_func('weighting_bw_w', x)
-            grad_weight = op(grad_out, x, basis, ctx.weight_index,
-                             weight.size(0))
+#         if ctx.needs_input_grad[1]:
+#             op = get_func('weighting_bw_w', x)
+#             grad_weight = op(grad_out, x, basis, ctx.weight_index,
+#                              weight.size(0))
 
-        if ctx.needs_input_grad[2]:
-            op = get_func('weighting_bw_b', x)
-            grad_basis = op(grad_out, x, weight, ctx.weight_index)
+#         if ctx.needs_input_grad[2]:
+#             op = get_func('weighting_bw_b', x)
+#             grad_basis = op(grad_out, x, weight, ctx.weight_index)
 
-        return grad_x, grad_weight, grad_basis, None
+#         return grad_x, grad_weight, grad_basis, None
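The class being retired follows the standard torch.autograd.Function recipe: forward saves what backward will need on ctx, and backward consults ctx.needs_input_grad to skip unneeded gradients and returns one entry per forward argument (None for the non-differentiable weight_index). A self-contained toy version of that recipe, unrelated to the actual CPU/CUDA kernels:

```python
import torch


class MulWeight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)  # stashed for backward
        return x * weight

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_weight = None
        if ctx.needs_input_grad[0]:
            grad_x = grad_out * weight    # d(x*w)/dx = w
        if ctx.needs_input_grad[1]:
            grad_weight = grad_out * x    # d(x*w)/dw = x
        return grad_x, grad_weight        # one entry per forward argument


x = torch.rand(3, requires_grad=True)
w = torch.rand(3, requires_grad=True)
MulWeight.apply(x, w).sum().backward()
print(torch.allclose(x.grad, w), torch.allclose(w.grad, x))  # True True
```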