Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
cd7b1988
Unverified
Commit
cd7b1988
authored
Feb 29, 2020
by
Matthias Fey
Committed by
GitHub
Feb 29, 2020
Browse files
Merge pull request #14 from rusty1s/tracing
prepare tracing
parents
d3169766
32224979
Changes
42
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
46 deletions
+6
-46
torch_spline_conv/utils/degree.py
torch_spline_conv/utils/degree.py
+0
-7
torch_spline_conv/weighting.py
torch_spline_conv/weighting.py
+6
-39
No files found.
torch_spline_conv/utils/degree.py
deleted
100644 → 0
View file @
d3169766
import
torch
def
degree
(
index
,
num_nodes
=
None
,
dtype
=
None
,
device
=
None
):
num_nodes
=
index
.
max
().
item
()
+
1
if
num_nodes
is
None
else
num_nodes
out
=
torch
.
zeros
((
num_nodes
),
dtype
=
dtype
,
device
=
device
)
return
out
.
scatter_add_
(
0
,
index
,
out
.
new_ones
((
index
.
size
(
0
))))
torch_spline_conv/weighting.py
View file @
cd7b1988
import
torch
import
torch_spline_conv.weighting_cpu
if
torch
.
cuda
.
is_available
():
import
torch_spline_conv.weighting_cuda
def
get_func
(
name
,
tensor
):
if
tensor
.
is_cuda
:
return
getattr
(
torch_spline_conv
.
weighting_cuda
,
name
)
else
:
return
getattr
(
torch_spline_conv
.
weighting_cpu
,
name
)
class
SplineWeighting
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
x
,
weight
,
basis
,
weight_index
):
ctx
.
weight_index
=
weight_index
ctx
.
save_for_backward
(
x
,
weight
,
basis
)
op
=
get_func
(
'weighting_fw'
,
x
)
out
=
op
(
x
,
weight
,
basis
,
weight_index
)
return
out
@
staticmethod
def
backward
(
ctx
,
grad_out
):
x
,
weight
,
basis
=
ctx
.
saved_tensors
grad_x
=
grad_weight
=
grad_basis
=
None
if
ctx
.
needs_input_grad
[
0
]:
op
=
get_func
(
'weighting_bw_x'
,
x
)
grad_x
=
op
(
grad_out
,
weight
,
basis
,
ctx
.
weight_index
)
if
ctx
.
needs_input_grad
[
1
]:
op
=
get_func
(
'weighting_bw_w'
,
x
)
grad_weight
=
op
(
grad_out
,
x
,
basis
,
ctx
.
weight_index
,
weight
.
size
(
0
))
if
ctx
.
needs_input_grad
[
2
]:
op
=
get_func
(
'weighting_bw_b'
,
x
)
grad_basis
=
op
(
grad_out
,
x
,
weight
,
ctx
.
weight_index
)
return
grad_x
,
grad_weight
,
grad_basis
,
None
@
torch
.
jit
.
script
def
spline_weighting
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
basis
:
torch
.
Tensor
,
weight_index
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_spline_conv
.
spline_weighting
(
x
,
weight
,
basis
,
weight_index
)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment