Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
ac26fc19
Commit
ac26fc19
authored
Feb 27, 2020
by
rusty1s
Browse files
prepare tracing
parent
d3169766
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
199 additions
and
113 deletions
+199
-113
.coveragerc
.coveragerc
+3
-2
.gitignore
.gitignore
+0
-1
LICENSE
LICENSE
+1
-1
MANIFEST.in
MANIFEST.in
+4
-2
README.md
README.md
+42
-19
csrc/cpu/basis.cpp
csrc/cpu/basis.cpp
+0
-0
csrc/cpu/compat.h
csrc/cpu/compat.h
+0
-0
csrc/cpu/weighting.cpp
csrc/cpu/weighting.cpp
+0
-0
csrc/cuda/basis.cpp
csrc/cuda/basis.cpp
+0
-0
csrc/cuda/basis_kernel.cu
csrc/cuda/basis_kernel.cu
+0
-0
csrc/cuda/compat.cuh
csrc/cuda/compat.cuh
+0
-0
csrc/cuda/weighting.cpp
csrc/cuda/weighting.cpp
+0
-0
csrc/cuda/weighting_kernel.cu
csrc/cuda/weighting_kernel.cu
+0
-0
setup.py
setup.py
+58
-36
test/test_basis.py
test/test_basis.py
+3
-3
test/test_conv.py
test/test_conv.py
+7
-6
test/test_weighting.py
test/test_weighting.py
+5
-6
test/utils.py
test/utils.py
+1
-1
torch_spline_conv/__init__.py
torch_spline_conv/__init__.py
+48
-5
torch_spline_conv/basis.py
torch_spline_conv/basis.py
+27
-31
No files found.
.coveragerc
View file @
ac26fc19
...
...
@@ -3,5 +3,6 @@ source=torch_spline_conv
[report]
exclude_lines =
pragma: no cover
cuda
backward
torch.jit.script
raise
except
.gitignore
View file @
ac26fc19
__pycache__/
_ext/
build/
dist/
.cache/
...
...
LICENSE
View file @
ac26fc19
Copyright (c) 20
19
Matthias Fey <matthias.fey@tu-dortmund.de>
Copyright (c) 20
20
Matthias Fey <matthias.fey@tu-dortmund.de>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
...
...
MANIFEST.in
View file @
ac26fc19
include README.md
include LICENSE
recursive-include cpu *
recursive-include cuda *
recursive-exclude test *
recursive-include csrc *
README.md
View file @
ac26fc19
...
...
@@ -21,11 +21,30 @@ The operator works on all floating point data types and is implemented both for
## Installation
Ensure that at least PyTorch 1.1.0 is installed and verify that
`cuda/bin`
and
`cuda/include`
are in your
`$PATH`
and
`$CPATH`
respectively,
*e.g.*
:
### Binaries
We provide pip wheels for all major OS/PyTorch/CUDA combinations, see
[
here
](
https://pytorch-geometric.com/whl
)
.
To install the binaries for PyTorch 1.4.0, simply run
```
pip install torch-spline-conv==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.4.0.html
```
where
`${CUDA}`
should be replaced by either
`cpu`
,
`cu92`
,
`cu100`
or
`cu101`
depending on your PyTorch installation.
| |
`cpu`
|
`cu92`
|
`cu100`
|
`cu101`
|
|-------------|-------|--------|---------|---------|
|
**Linux**
| ✅ | ✅ | ✅ | ✅ |
|
**Windows**
| ✅ | ❌ | ❌ | ✅ |
|
**macOS**
| ✅ | | | |
### From source
Ensure that at least PyTorch 1.4.0 is installed and verify that
`cuda/bin`
and
`cuda/include`
are in your
`$PATH`
and
`$CPATH`
respectively,
*e.g.*
:
```
$ python -c "import torch; print(torch.__version__)"
>>> 1.
1
.0
>>> 1.
4
.0
$ echo $PATH
>>> /usr/local/cuda/bin:...
...
...
@@ -40,24 +59,28 @@ Then run:
pip install torch-spline-conv
```
If you are running into any installation problems, please create an
[
issue
](
https://github.com/rusty1s/pytorch_spline_conv/issues
)
.
Be sure to import
`torch`
first before using this package to resolve symbols the dynamic linker must see.
When running in a docker container without NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail.
In this case, ensure that the compute capabilities are set via
`TORCH_CUDA_ARCH_LIST`
,
*e.g.*
:
```
export TORCH_CUDA_ARCH_LIST = "6.0 6.1 7.2+PTX 7.5+PTX"
```
## Usage
```
python
from
torch_spline_conv
import
S
pline
C
onv
out
=
S
pline
C
onv
.
apply
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
=
1
,
norm
=
True
,
root_weight
=
None
,
bias
=
None
)
from
torch_spline_conv
import
s
pline
_c
onv
out
=
s
pline
_c
onv
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
=
1
,
norm
=
True
,
root_weight
=
None
,
bias
=
None
)
```
Applies the spline-based convolution operator
...
...
@@ -93,7 +116,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
```
python
import
torch
from
torch_spline_conv
import
S
pline
C
onv
from
torch_spline_conv
import
s
pline
_c
onv
x
=
torch
.
rand
((
4
,
2
),
dtype
=
torch
.
float
)
# 4 nodes with 2 features each
edge_index
=
torch
.
tensor
([[
0
,
1
,
1
,
2
,
2
,
3
],
[
1
,
0
,
2
,
1
,
3
,
2
]])
# 6 edges
...
...
@@ -106,8 +129,8 @@ norm = True # Normalize output by node degree.
root_weight
=
torch
.
rand
((
2
,
4
),
dtype
=
torch
.
float
)
# separately weight root nodes
bias
=
None
# do not apply an additional bias
out
=
S
pline
C
onv
.
apply
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
,
norm
,
root_weight
,
bias
)
out
=
s
pline
_c
onv
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
,
norm
,
root_weight
,
bias
)
print
(
out
.
size
())
torch
.
Size
([
4
,
4
])
# 4 nodes with 4 features each
...
...
cpu/basis.cpp
→
csrc/
cpu/basis.cpp
View file @
ac26fc19
File moved
cpu/compat.h
→
csrc/
cpu/compat.h
View file @
ac26fc19
File moved
cpu/weighting.cpp
→
csrc/
cpu/weighting.cpp
View file @
ac26fc19
File moved
cuda/basis.cpp
→
csrc/
cuda/basis.cpp
View file @
ac26fc19
File moved
cuda/basis_kernel.cu
→
csrc/
cuda/basis_kernel.cu
View file @
ac26fc19
File moved
cuda/compat.cuh
→
csrc/
cuda/compat.cuh
View file @
ac26fc19
File moved
cuda/weighting.cpp
→
csrc/
cuda/weighting.cpp
View file @
ac26fc19
File moved
cuda/weighting_kernel.cu
→
csrc/
cuda/weighting_kernel.cu
View file @
ac26fc19
File moved
setup.py
View file @
ac26fc19
import
os
import
os.path
as
osp
import
glob
from
setuptools
import
setup
,
find_packages
from
sys
import
argv
import
torch
from
torch.utils.cpp_extension
import
BuildExtension
from
torch.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
CUDA_HOME
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
WITH_CUDA
=
torch
.
cuda
.
is_available
()
and
CUDA_HOME
is
not
None
if
os
.
getenv
(
'FORCE_CUDA'
,
'0'
)
==
'1'
:
WITH_CUDA
=
True
if
os
.
getenv
(
'FORCE_CPU'
,
'0'
)
==
'1'
:
WITH_CUDA
=
False
BUILD_DOCS
=
os
.
getenv
(
'BUILD_DOCS'
,
'0'
)
==
'1'
def
get_extensions
():
Extension
=
CppExtension
define_macros
=
[]
extra_compile_args
=
{
'cxx'
:
[]}
if
WITH_CUDA
:
Extension
=
CUDAExtension
define_macros
+=
[(
'WITH_CUDA'
,
None
)]
nvcc_flags
=
os
.
getenv
(
'NVCC_FLAGS'
,
''
)
nvcc_flags
=
[]
if
nvcc_flags
==
''
else
nvcc_flags
.
split
(
' '
)
extra_compile_args
[
'nvcc'
]
=
nvcc_flags
extensions_dir
=
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'csrc'
)
main_files
=
glob
.
glob
(
osp
.
join
(
extensions_dir
,
'*.cpp'
))
extensions
=
[]
for
main
in
main_files
:
name
=
main
.
split
(
os
.
sep
)[
-
1
][:
-
4
]
extra_compile_args
=
[]
if
(
TORCH_MAJOR
>
1
)
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
>
2
):
extra_compile_args
+=
[
'-DVERSION_GE_1_3'
]
sources
=
[
main
]
ext_modules
=
[
CppExtension
(
'torch_spline_conv.basis_cpu'
,
[
'cpu/basis.cpp'
],
extra_compile_args
=
extra_compile_args
),
CppExtension
(
'torch_spline_conv.weighting_cpu'
,
[
'cpu/weighting.cpp'
],
extra_compile_args
=
extra_compile_args
),
]
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
path
=
osp
.
join
(
extensions_dir
,
'cpu'
,
f
'
{
name
}
_cpu.cpp'
)
if
osp
.
exists
(
path
):
sources
+=
[
path
]
GPU
=
True
for
arg
in
argv
:
if
arg
==
'--cpu'
:
GPU
=
False
argv
.
remove
(
arg
)
path
=
osp
.
join
(
extensions_dir
,
'cuda'
,
f
'
{
name
}
_cuda.cu'
)
if
WITH_CUDA
and
osp
.
exists
(
path
):
sources
+=
[
path
]
if
CUDA_HOME
is
not
None
and
GPU
:
ext_modules
+=
[
CUDAExtension
(
'torch_spline_conv.basis_cuda'
,
[
'cuda/basis.cpp'
,
'cuda/basis_kernel.cu'
],
extra_compile_args
=
extra_compile_args
),
CUDAExtension
(
'torch_spline_conv.weighting_cuda'
,
[
'cuda/weighting.cpp'
,
'cuda/weighting_kernel.cu'
],
extra_compile_args
=
extra_compile_args
),
]
extension
=
Extension
(
'torch_scatter._'
+
name
,
sources
,
include_dirs
=
[
extensions_dir
],
define_macros
=
define_macros
,
extra_compile_args
=
extra_compile_args
,
)
extensions
+=
[
extension
]
return
extensions
__version__
=
'1.1.1'
url
=
'https://github.com/rusty1s/pytorch_spline_conv'
install_requires
=
[]
setup_requires
=
[
'pytest-runner'
]
...
...
@@ -43,23 +62,26 @@ tests_require = ['pytest', 'pytest-cov']
setup
(
name
=
'torch_spline_conv'
,
version
=
__version__
,
description
=
(
'Implementation of the Spline-Based Convolution Operator of '
'SplineCNN in PyTorch'
),
version
=
'1.2.0'
,
author
=
'Matthias Fey'
,
author_email
=
'matthias.fey@tu-dortmund.de'
,
url
=
url
,
download_url
=
'{}/archive/{}.tar.gz'
.
format
(
url
,
__version__
),
url
=
'https://github.com/rusty1s/pytorch_spline_conv'
,
description
=
(
'Implementation of the Spline-Based Convolution Operator of '
'SplineCNN in PyTorch'
),
keywords
=
[
'pytorch'
,
'geometric-deep-learning'
,
'graph-neural-networks'
,
'spline-cnn'
,
],
license
=
'MIT'
,
python_requires
=
'>=3.6'
,
install_requires
=
install_requires
,
setup_requires
=
setup_requires
,
tests_require
=
tests_require
,
ext_modules
=
ext_modules
,
cmdclass
=
cmdclass
,
ext_modules
=
get_extensions
()
if
not
BUILD_DOCS
else
[],
cmdclass
=
{
'build_ext'
:
BuildExtension
.
with_options
(
no_python_abi_suffix
=
True
)
},
packages
=
find_packages
(),
)
test/test_basis.py
View file @
ac26fc19
...
...
@@ -2,7 +2,7 @@ from itertools import product
import
pytest
import
torch
from
torch_spline_conv
.basis
import
S
pline
B
asis
from
torch_spline_conv
import
s
pline
_b
asis
from
.utils
import
dtypes
,
devices
,
tensor
...
...
@@ -34,7 +34,7 @@ def test_spline_basis_forward(test, dtype, device):
is_open_spline
=
tensor
(
test
[
'is_open_spline'
],
torch
.
uint8
,
device
)
degree
=
1
op
=
S
pline
B
asis
.
apply
basis
,
weight_index
=
op
(
pseudo
,
kernel_size
,
is_open_spline
,
degree
)
basis
,
weight_index
=
s
pline
_b
asis
(
pseudo
,
kernel_size
,
is_open_spline
,
degree
)
assert
basis
.
tolist
()
==
test
[
'basis'
]
assert
weight_index
.
tolist
()
==
test
[
'weight_index'
]
test/test_conv.py
View file @
ac26fc19
...
...
@@ -3,11 +3,12 @@ from itertools import product
import
pytest
import
torch
from
torch.autograd
import
gradcheck
from
torch_spline_conv
import
SplineConv
from
torch_spline_conv.basis
import
implemented_degrees
as
degrees
from
torch_spline_conv
import
spline_conv
from
.utils
import
dtypes
,
devices
,
tensor
degrees
=
[
1
,
2
,
3
]
tests
=
[{
'x'
:
[[
9
,
10
],
[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
]],
'edge_index'
:
[[
0
,
0
,
0
,
0
],
[
1
,
2
,
3
,
4
]],
...
...
@@ -51,12 +52,12 @@ def test_spline_conv_forward(test, dtype, device):
root_weight
=
tensor
(
test
[
'root_weight'
],
dtype
,
device
)
bias
=
tensor
(
test
[
'bias'
],
dtype
,
device
)
out
=
S
pline
C
onv
.
apply
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
1
,
True
,
root_weight
,
bias
)
out
=
s
pline
_c
onv
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
1
,
True
,
root_weight
,
bias
)
assert
out
.
tolist
()
==
test
[
'expected'
]
@
pytest
.
mark
.
parametrize
(
'degree,device'
,
product
(
degrees
.
keys
()
,
devices
))
@
pytest
.
mark
.
parametrize
(
'degree,device'
,
product
(
degrees
,
devices
))
def
test_spline_basis_backward
(
degree
,
device
):
x
=
torch
.
rand
((
3
,
2
),
dtype
=
torch
.
double
,
device
=
device
)
x
.
requires_grad_
()
...
...
@@ -74,4 +75,4 @@ def test_spline_basis_backward(degree, device):
data
=
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
,
True
,
root_weight
,
bias
)
assert
gradcheck
(
S
pline
C
onv
.
apply
,
data
,
eps
=
1e-6
,
atol
=
1e-4
)
is
True
assert
gradcheck
(
s
pline
_c
onv
,
data
,
eps
=
1e-6
,
atol
=
1e-4
)
is
True
test/test_weighting.py
View file @
ac26fc19
...
...
@@ -3,8 +3,7 @@ from itertools import product
import
pytest
import
torch
from
torch.autograd
import
gradcheck
from
torch_spline_conv.weighting
import
SplineWeighting
from
torch_spline_conv.basis
import
SplineBasis
from
torch_spline_conv
import
spline_weighting
,
spline_basis
from
.utils
import
dtypes
,
devices
,
tensor
...
...
@@ -27,7 +26,7 @@ def test_spline_weighting_forward(test, dtype, device):
basis
=
tensor
(
test
[
'basis'
],
dtype
,
device
)
weight_index
=
tensor
(
test
[
'weight_index'
],
torch
.
long
,
device
)
out
=
S
pline
W
eighting
.
apply
(
x
,
weight
,
basis
,
weight_index
)
out
=
s
pline
_w
eighting
(
x
,
weight
,
basis
,
weight_index
)
assert
out
.
tolist
()
==
test
[
'expected'
]
...
...
@@ -38,8 +37,8 @@ def test_spline_weighting_backward(device):
is_open_spline
=
tensor
([
1
,
1
],
torch
.
uint8
,
device
)
degree
=
1
op
=
S
pline
B
asis
.
apply
basis
,
weight_index
=
op
(
pseudo
,
kernel_size
,
is_open_spline
,
degree
)
basis
,
weight_index
=
s
pline
_b
asis
(
pseudo
,
kernel_size
,
is_open_spline
,
degree
)
basis
.
requires_grad_
()
x
=
torch
.
rand
((
4
,
2
),
dtype
=
torch
.
double
,
device
=
device
)
...
...
@@ -48,4 +47,4 @@ def test_spline_weighting_backward(device):
weight
.
requires_grad_
()
data
=
(
x
,
weight
,
basis
,
weight_index
)
assert
gradcheck
(
S
pline
W
eighting
.
apply
,
data
,
eps
=
1e-6
,
atol
=
1e-4
)
is
True
assert
gradcheck
(
s
pline
_w
eighting
,
data
,
eps
=
1e-6
,
atol
=
1e-4
)
is
True
test/utils.py
View file @
ac26fc19
...
...
@@ -4,7 +4,7 @@ dtypes = [torch.float, torch.double]
devices
=
[
torch
.
device
(
'cpu'
)]
if
torch
.
cuda
.
is_available
():
devices
+=
[
torch
.
device
(
'cuda:{
}'
.
format
(
torch
.
cuda
.
current_device
()
)
)]
devices
+=
[
torch
.
device
(
f
'cuda:
{
torch
.
cuda
.
current_device
()
}
'
)]
def
tensor
(
x
,
dtype
,
device
):
...
...
torch_spline_conv/__init__.py
View file @
ac26fc19
from
.basis
import
SplineBasis
from
.weighting
import
SplineWeighting
from
.conv
import
SplineConv
import
importlib
import
os.path
as
osp
__version__
=
'1.1.1'
import
torch
__all__
=
[
'SplineBasis'
,
'SplineWeighting'
,
'SplineConv'
,
'__version__'
]
__version__
=
'1.2.0'
expected_torch_version
=
(
1
,
4
)
try
:
for
library
in
[
'_version'
,
'_basis'
,
'_weighting'
]:
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
library
,
[
osp
.
dirname
(
__file__
)]).
origin
)
except
OSError
as
e
:
major
,
minor
=
[
int
(
x
)
for
x
in
torch
.
__version__
.
split
(
'.'
)[:
2
]]
t_major
,
t_minor
=
expected_torch_version
if
major
!=
t_major
or
(
major
==
t_major
and
minor
!=
t_minor
):
raise
RuntimeError
(
f
'Expected PyTorch version
{
t_major
}
.
{
t_minor
}
but found '
f
'version
{
major
}
.
{
minor
}
.'
)
raise
OSError
(
e
)
if
torch
.
version
.
cuda
is
not
None
:
# pragma: no cover
cuda_version
=
torch
.
ops
.
torch_scatter
.
cuda_version
()
if
cuda_version
==
-
1
:
major
=
minor
=
0
elif
cuda_version
<
10000
:
major
,
minor
=
int
(
str
(
cuda_version
)[
0
]),
int
(
str
(
cuda_version
)[
2
])
else
:
major
,
minor
=
int
(
str
(
cuda_version
)[
0
:
2
]),
int
(
str
(
cuda_version
)[
3
])
t_major
,
t_minor
=
[
int
(
x
)
for
x
in
torch
.
version
.
cuda
.
split
(
'.'
)]
if
t_major
!=
major
or
t_minor
!=
minor
:
raise
RuntimeError
(
f
'Detected that PyTorch and torch_spline_conv were compiled with '
f
'different CUDA versions. PyTorch has CUDA version '
f
'
{
t_major
}
.
{
t_minor
}
and torch_spline_conv has CUDA version '
f
'
{
major
}
.
{
minor
}
. Please reinstall the torch_spline_conv that '
f
'matches your PyTorch install.'
)
from
.basis
import
spline_basis
# noqa
from
.weighting
import
spline_weighting
# noqa
from
.conv
import
spline_conv
# noqa
__all__
=
[
'spline_basis'
,
'spline_weighting'
,
'spline_conv'
,
'__version__'
,
]
torch_spline_conv/basis.py
View file @
ac26fc19
import
torch
import
torch_spline_conv.basis_cpu
if
torch
.
cuda
.
is_available
():
import
torch_spline_conv.basis_cuda
from
typing
import
Tuple
imp
lemented_degrees
=
{
1
:
'linear'
,
2
:
'quadratic'
,
3
:
'cubic'
}
imp
ort
torch
def
get_func
(
name
,
tensor
):
if
tensor
.
is_cuda
:
return
getattr
(
torch_spline_conv
.
basis_cuda
,
name
)
else
:
return
getattr
(
torch_spline_conv
.
basis_cpu
,
name
)
@
torch
.
jit
.
script
def
spline_basis
(
pseudo
:
torch
.
Tensor
,
kernel_size
:
torch
.
Tensor
,
is_open_spline
:
torch
.
Tensor
,
degree
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
torch_spline_conv
.
spline_basis
(
pseudo
,
kernel_size
,
is_open_spline
,
degree
)
class
SplineBasis
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
pseudo
,
kernel_size
,
is_open_spline
,
degree
):
ctx
.
save_for_backward
(
pseudo
)
ctx
.
kernel_size
=
kernel_size
ctx
.
is_open_spline
=
is_open_spline
ctx
.
degree
=
degree
#
class SplineBasis(torch.autograd.Function):
#
@staticmethod
#
def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
#
ctx.save_for_backward(pseudo)
#
ctx.kernel_size = kernel_size
#
ctx.is_open_spline = is_open_spline
#
ctx.degree = degree
op
=
get_func
(
'{}_fw'
.
format
(
implemented_degrees
[
degree
]),
pseudo
)
basis
,
weight_index
=
op
(
pseudo
,
kernel_size
,
is_open_spline
)
#
op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
#
basis, weight_index = op(pseudo, kernel_size, is_open_spline)
return
basis
,
weight_index
#
return basis, weight_index
@
staticmethod
def
backward
(
ctx
,
grad_basis
,
grad_weight_index
):
pseudo
,
=
ctx
.
saved_tensors
kernel_size
,
is_open_spline
=
ctx
.
kernel_size
,
ctx
.
is_open_spline
degree
=
ctx
.
degree
grad_pseudo
=
None
#
@staticmethod
#
def backward(ctx, grad_basis, grad_weight_index):
#
pseudo, = ctx.saved_tensors
#
kernel_size, is_open_spline = ctx.kernel_size, ctx.is_open_spline
#
degree = ctx.degree
#
grad_pseudo = None
if
ctx
.
needs_input_grad
[
0
]:
op
=
get_func
(
'{}_bw'
.
format
(
implemented_degrees
[
degree
]),
pseudo
)
grad_pseudo
=
op
(
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
)
# if ctx.needs_input_grad[0]:
# grad_pseudo = op(grad_basis, pseudo, kernel_size, is_open_spline)
return
grad_pseudo
,
None
,
None
,
None
#
return grad_pseudo, None, None, None
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment