Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
07804abc
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "d774acad5cef7a538da33d39207f9e2bc51474eb"
Commit
07804abc
authored
Mar 02, 2018
by
rusty1s
Browse files
python impl
parent
5c2b664b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
99 additions
and
0 deletions
+99
-0
torch_spline_conv/functions/degree.py
torch_spline_conv/functions/degree.py
+7
-0
torch_spline_conv/functions/spline_conv.py
torch_spline_conv/functions/spline_conv.py
+48
-0
torch_spline_conv/functions/utils.py
torch_spline_conv/functions/utils.py
+44
-0
No files found.
torch_spline_conv/functions/degree.py
0 → 100644
View file @
07804abc
import
torch
def node_degree(index, out=None):
    """Count, for every node, how many edges originate from it.

    Args:
        index: LongTensor of edge indices; row 0 holds the source node of
            each edge (presumably shape ``(2, |E|)`` — confirm against
            callers in ``spline_conv``).
        out: optional tensor that receives the result (resized as needed).

    Returns:
        A tensor ``d`` with ``d[i]`` = number of edges whose source is node
        ``i``, sized by the number of edges (matching the original
        allocation; nodes beyond the last edge index stay 0).
    """
    num_edges = index.size(1)
    # NOTE(fix): the original passed ``out`` positionally to both
    # ``torch.ones`` and ``torch.zeros`` — ``out`` must be given as a
    # keyword, and sharing one ``out`` tensor for both would make ``one``
    # and ``zero`` alias the same storage, yielding an all-zero degree.
    # Only the accumulator uses ``out`` now.
    zero = torch.zeros(num_edges, out=out)
    # Build the per-edge ones from the accumulator's type so dtypes match.
    one = zero.new(num_edges).fill_(1)
    return zero.scatter_add_(0, index[0], one)
torch_spline_conv/functions/spline_conv.py
0 → 100644
View file @
07804abc
import
torch
# from torch.autograd import Variable as Var
from
.degree
import
node_degree
from
.utils
import
spline_bases
,
spline_weighting
def spline_conv(x, index, pseudo, weight, kernel_size, is_open_spline,
                root_weight=None, degree=1, bias=None):
    """Spline-based graph convolution (pure-python implementation).

    Args:
        x: node feature tensor; a 1-D tensor is promoted to ``N x 1``.
        index: edge index tensor; row 0 = source nodes, row 1 = target
            nodes (presumably shape ``(2, |E|)`` — confirm against callers).
        pseudo: edge-wise pseudo-coordinates fed to the B-spline bases.
        weight: trainable kernel weights.
        kernel_size: number of kernels per pseudo-coordinate dimension.
        is_open_spline: open/closed flag per pseudo-coordinate dimension.
        root_weight: optional weight matrix applied to each node's own
            features and added to the aggregated output.
        degree: B-spline degree (default 1).
        bias: optional bias added to the result.

    Returns:
        The convolved node feature tensor.

    NOTE(review): ``spline_bases`` and ``spline_weighting`` are still
    stubs in this commit, so this function cannot run end-to-end yet.
    """
    # Promote 1-D features to a column matrix so the 2-D indexing below works.
    x = x.unsqueeze(-1) if x.dim() == 1 else x

    # Get features for every target node => |E| x M_in
    output = x[index[1]]

    # Get B-spline basis products and weight indices for each edge.
    basis, weight_index = spline_bases(pseudo, kernel_size, is_open_spline,
                                       degree)

    # Weight gathered features based on B-spline basis and trainable weights.
    output = spline_weighting(output, weight, basis, weight_index)

    # Perform the real convolution => Convert |E| x M_out to N x M_out output.
    # Broadcast each edge's source index across every output channel so it
    # can drive scatter_add_ below.
    row = index[0].unsqueeze(-1).expand(-1, output.size(1))
    # zero = x if torch.is_tensor(x) else x.data
    # Fresh zero accumulator of the same type as x.
    zero = x.new(row.size()).fill_(0)
    # row, zero = row, zero if torch.is_tensor(x) else Var(row), Var(zero)
    output = zero.scatter_add_(0, row, output)

    # Normalize output by node degree.
    # clamp_(min=1) avoids division by zero for nodes without edges.
    output /= node_degree(index, out=x.new()).unsqueeze(-1).clamp_(min=1)

    # Weight root node separately (if wished).
    if root_weight is not None:
        output += torch.mm(x, root_weight)

    # Add bias (if wished).
    if bias is not None:
        output += bias

    return output
torch_spline_conv/functions/utils.py
0 → 100644
View file @
07804abc
import
torch
from
torch.autograd
import
Function
from
.._ext
import
ffi
def get_func(name, tensor):
    """Resolve the ffi kernel matching *tensor*'s dtype and device.

    Kernels are exported as ``spline_<name>_[cuda_]<Type>`` (e.g.
    ``spline_basis_cuda_Float``); the dtype tag is the tensor's class name
    with the ``Tensor`` suffix stripped.
    """
    dtype_tag = type(tensor).__name__.replace('Tensor', '')
    device_tag = 'cuda_' if tensor.is_cuda else ''
    return getattr(ffi, 'spline_{}_{}{}'.format(name, device_tag, dtype_tag))
def spline_bases(pseudo, kernel_size, is_open_spline, degree):
    """Compute B-spline basis products and weight indices per edge.

    Stub — not implemented in this commit. Callers (``spline_conv``)
    expect a ``(basis, weight_index)`` pair.
    """
    # raise NotImplementedError for degree > 3
    pass
def spline_weighting_forward(x, weight, basis, weight_index):
    """Forward pass of the spline weighting op (stub, not implemented yet)."""
    pass
def spline_weighting_backward(x, weight, basis, weight_index):
    """Backward pass of the spline weighting op (stub, not implemented yet)."""
    pass
class SplineWeighting(Function):
    """Autograd Function wrapping the spline weighting op.

    ``basis`` and ``weight_index`` are fixed at construction (legacy
    pre-0.4 Function style with a stateful instance); ``forward`` and
    ``backward`` are still stubs in this commit.
    """

    def __init__(self, basis, weight_index):
        super(SplineWeighting, self).__init__()
        # Cache the edge-wise basis products and kernel indices for use by
        # forward/backward.
        self.basis = basis
        self.weight_index = weight_index

    def forward(self, x, weight):
        # TODO: delegate to spline_weighting_forward.
        pass

    def backward(self, grad_output):
        # TODO: delegate to spline_weighting_backward.
        pass
def spline_weighting(x, weight, basis, weight_index):
    """Apply spline weighting to *x*, dispatching on the input kind.

    Plain tensors take the direct functional path; anything else (a
    Variable) is routed through the ``SplineWeighting`` autograd Function
    so gradients can flow.
    """
    if torch.is_tensor(x):
        # No autograd needed — call the functional forward directly.
        return spline_weighting_forward(x, weight, basis, weight_index)
    # Variable input: wrap in the autograd Function.
    return SplineWeighting(basis, weight_index)(x, weight)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment