Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
676aa072
Commit
676aa072
authored
Oct 14, 2019
by
rusty1s
Browse files
pytorch 1.3 support
parent
0dd1186d
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
77 additions
and
56 deletions
+77
-56
.travis.yml
.travis.yml
+1
-1
cpu/basis.cpp
cpu/basis.cpp
+12
-10
cpu/compat.h
cpu/compat.h
+5
-0
cpu/weighting.cpp
cpu/weighting.cpp
+22
-20
cuda/basis_kernel.cu
cuda/basis_kernel.cu
+6
-4
cuda/compat.cuh
cuda/compat.cuh
+5
-0
setup.py
setup.py
+16
-5
test/test_weighting.py
test/test_weighting.py
+1
-1
torch_spline_conv/__init__.py
torch_spline_conv/__init__.py
+1
-1
torch_spline_conv/basis.py
torch_spline_conv/basis.py
+2
-1
torch_spline_conv/conv.py
torch_spline_conv/conv.py
+6
-13
No files found.
.travis.yml
View file @
676aa072
...
@@ -17,7 +17,7 @@ before_install:
...
@@ -17,7 +17,7 @@ before_install:
-
export CXX="g++-4.9"
-
export CXX="g++-4.9"
install
:
install
:
-
pip install numpy
-
pip install numpy
-
pip install -
q torch
-f https://download.pytorch.org/whl/nightly/cpu/torch.html
-
pip install -
-pre torch torchvision
-f https://download.pytorch.org/whl/nightly/cpu/torch
_nightly
.html
-
pip install pycodestyle
-
pip install pycodestyle
-
pip install flake8
-
pip install flake8
-
pip install codecov
-
pip install codecov
...
...
cpu/basis.cpp
View file @
676aa072
#include <torch/extension.h>
#include <torch/extension.h>
#include "compat.h"
template
<
typename
scalar_t
>
inline
scalar_t
linear
(
scalar_t
v
,
int64_t
k_mod
)
{
template
<
typename
scalar_t
>
inline
scalar_t
linear
(
scalar_t
v
,
int64_t
k_mod
)
{
return
1
-
v
-
k_mod
+
2
*
v
*
k_mod
;
return
1
-
v
-
k_mod
+
2
*
v
*
k_mod
;
}
}
...
@@ -34,11 +36,11 @@ template <typename scalar_t> inline scalar_t cubic(scalar_t v, int64_t k_mod) {
...
@@ -34,11 +36,11 @@ template <typename scalar_t> inline scalar_t cubic(scalar_t v, int64_t k_mod) {
\
\
AT_DISPATCH_FLOATING_TYPES( \
AT_DISPATCH_FLOATING_TYPES( \
PSEUDO.scalar_type(), "basis_forward_##M", [&] { \
PSEUDO.scalar_type(), "basis_forward_##M", [&] { \
auto pseudo_data = PSEUDO.
data
<scalar_t>();
\
auto pseudo_data = PSEUDO.
DATA_PTR
<scalar_t>(); \
auto kernel_size_data = KERNEL_SIZE.
data
<int64_t>();
\
auto kernel_size_data = KERNEL_SIZE.
DATA_PTR
<int64_t>(); \
auto is_open_spline_data = IS_OPEN_SPLINE.
data
<uint8_t>();
\
auto is_open_spline_data = IS_OPEN_SPLINE.
DATA_PTR
<uint8_t>(); \
auto basis_data = basis.
data
<scalar_t>();
\
auto basis_data = basis.
DATA_PTR
<scalar_t>(); \
auto weight_index_data = weight_index.
data
<int64_t>();
\
auto weight_index_data = weight_index.
DATA_PTR
<int64_t>(); \
\
\
int64_t k, wi, wi_offset; \
int64_t k, wi, wi_offset; \
scalar_t b; \
scalar_t b; \
...
@@ -126,11 +128,11 @@ inline scalar_t grad_cubic(scalar_t v, int64_t k_mod) {
...
@@ -126,11 +128,11 @@ inline scalar_t grad_cubic(scalar_t v, int64_t k_mod) {
\
\
AT_DISPATCH_FLOATING_TYPES( \
AT_DISPATCH_FLOATING_TYPES( \
PSEUDO.scalar_type(), "basis_backward_##M", [&] { \
PSEUDO.scalar_type(), "basis_backward_##M", [&] { \
auto grad_basis_data = GRAD_BASIS.
data
<scalar_t>();
\
auto grad_basis_data = GRAD_BASIS.
DATA_PTR
<scalar_t>(); \
auto pseudo_data = PSEUDO.
data
<scalar_t>();
\
auto pseudo_data = PSEUDO.
DATA_PTR
<scalar_t>(); \
auto kernel_size_data = KERNEL_SIZE.
data
<int64_t>();
\
auto kernel_size_data = KERNEL_SIZE.
DATA_PTR
<int64_t>(); \
auto is_open_spline_data = IS_OPEN_SPLINE.
data
<uint8_t>();
\
auto is_open_spline_data = IS_OPEN_SPLINE.
DATA_PTR
<uint8_t>(); \
auto grad_pseudo_data = grad_pseudo.
data
<scalar_t>();
\
auto grad_pseudo_data = grad_pseudo.
DATA_PTR
<scalar_t>(); \
\
\
scalar_t g, tmp; \
scalar_t g, tmp; \
\
\
...
...
cpu/compat.h
0 → 100644
View file @
676aa072
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
cpu/weighting.cpp
View file @
676aa072
#include <torch/extension.h>
#include <torch/extension.h>
#include "compat.h"
at
::
Tensor
weighting_fw
(
at
::
Tensor
x
,
at
::
Tensor
weight
,
at
::
Tensor
basis
,
at
::
Tensor
weighting_fw
(
at
::
Tensor
x
,
at
::
Tensor
weight
,
at
::
Tensor
basis
,
at
::
Tensor
weight_index
)
{
at
::
Tensor
weight_index
)
{
auto
E
=
x
.
size
(
0
),
M_in
=
x
.
size
(
1
),
M_out
=
weight
.
size
(
2
);
auto
E
=
x
.
size
(
0
),
M_in
=
x
.
size
(
1
),
M_out
=
weight
.
size
(
2
);
...
@@ -7,11 +9,11 @@ at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
...
@@ -7,11 +9,11 @@ at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
auto
out
=
at
::
empty
({
E
,
M_out
},
x
.
options
());
auto
out
=
at
::
empty
({
E
,
M_out
},
x
.
options
());
AT_DISPATCH_FLOATING_TYPES
(
out
.
scalar_type
(),
"weighting_fw"
,
[
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
out
.
scalar_type
(),
"weighting_fw"
,
[
&
]
{
auto
x_data
=
x
.
data
<
scalar_t
>
();
auto
x_data
=
x
.
DATA_PTR
<
scalar_t
>
();
auto
weight_data
=
weight
.
data
<
scalar_t
>
();
auto
weight_data
=
weight
.
DATA_PTR
<
scalar_t
>
();
auto
basis_data
=
basis
.
data
<
scalar_t
>
();
auto
basis_data
=
basis
.
DATA_PTR
<
scalar_t
>
();
auto
weight_index_data
=
weight_index
.
data
<
int64_t
>
();
auto
weight_index_data
=
weight_index
.
DATA_PTR
<
int64_t
>
();
auto
out_data
=
out
.
data
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
scalar_t
v
;
scalar_t
v
;
...
@@ -44,11 +46,11 @@ at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
...
@@ -44,11 +46,11 @@ at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
auto
grad_x
=
at
::
zeros
({
E
,
M_in
},
grad_out
.
options
());
auto
grad_x
=
at
::
zeros
({
E
,
M_in
},
grad_out
.
options
());
AT_DISPATCH_FLOATING_TYPES
(
grad_out
.
scalar_type
(),
"weighting_bw_x"
,
[
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
grad_out
.
scalar_type
(),
"weighting_bw_x"
,
[
&
]
{
auto
grad_out_data
=
grad_out
.
data
<
scalar_t
>
();
auto
grad_out_data
=
grad_out
.
DATA_PTR
<
scalar_t
>
();
auto
weight_data
=
weight
.
data
<
scalar_t
>
();
auto
weight_data
=
weight
.
DATA_PTR
<
scalar_t
>
();
auto
basis_data
=
basis
.
data
<
scalar_t
>
();
auto
basis_data
=
basis
.
DATA_PTR
<
scalar_t
>
();
auto
weight_index_data
=
weight_index
.
data
<
int64_t
>
();
auto
weight_index_data
=
weight_index
.
DATA_PTR
<
int64_t
>
();
auto
grad_x_data
=
grad_x
.
data
<
scalar_t
>
();
auto
grad_x_data
=
grad_x
.
DATA_PTR
<
scalar_t
>
();
for
(
ptrdiff_t
e
=
0
;
e
<
E
;
e
++
)
{
for
(
ptrdiff_t
e
=
0
;
e
<
E
;
e
++
)
{
for
(
ptrdiff_t
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
for
(
ptrdiff_t
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
...
@@ -78,11 +80,11 @@ at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
...
@@ -78,11 +80,11 @@ at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
auto
grad_weight
=
at
::
zeros
({
K
,
M_in
,
M_out
},
grad_out
.
options
());
auto
grad_weight
=
at
::
zeros
({
K
,
M_in
,
M_out
},
grad_out
.
options
());
AT_DISPATCH_FLOATING_TYPES
(
grad_out
.
scalar_type
(),
"weighting_bw_w"
,
[
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
grad_out
.
scalar_type
(),
"weighting_bw_w"
,
[
&
]
{
auto
grad_out_data
=
grad_out
.
data
<
scalar_t
>
();
auto
grad_out_data
=
grad_out
.
DATA_PTR
<
scalar_t
>
();
auto
x_data
=
x
.
data
<
scalar_t
>
();
auto
x_data
=
x
.
DATA_PTR
<
scalar_t
>
();
auto
basis_data
=
basis
.
data
<
scalar_t
>
();
auto
basis_data
=
basis
.
DATA_PTR
<
scalar_t
>
();
auto
weight_index_data
=
weight_index
.
data
<
int64_t
>
();
auto
weight_index_data
=
weight_index
.
DATA_PTR
<
int64_t
>
();
auto
grad_weight_data
=
grad_weight
.
data
<
scalar_t
>
();
auto
grad_weight_data
=
grad_weight
.
DATA_PTR
<
scalar_t
>
();
for
(
ptrdiff_t
e
=
0
;
e
<
E
;
e
++
)
{
for
(
ptrdiff_t
e
=
0
;
e
<
E
;
e
++
)
{
for
(
ptrdiff_t
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
for
(
ptrdiff_t
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
...
@@ -110,11 +112,11 @@ at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
...
@@ -110,11 +112,11 @@ at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
auto
grad_basis
=
at
::
zeros
({
E
,
S
},
grad_out
.
options
());
auto
grad_basis
=
at
::
zeros
({
E
,
S
},
grad_out
.
options
());
AT_DISPATCH_FLOATING_TYPES
(
grad_out
.
scalar_type
(),
"weighting_bw_b"
,
[
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
grad_out
.
scalar_type
(),
"weighting_bw_b"
,
[
&
]
{
auto
grad_out_data
=
grad_out
.
data
<
scalar_t
>
();
auto
grad_out_data
=
grad_out
.
DATA_PTR
<
scalar_t
>
();
auto
x_data
=
x
.
data
<
scalar_t
>
();
auto
x_data
=
x
.
DATA_PTR
<
scalar_t
>
();
auto
weight_data
=
weight
.
data
<
scalar_t
>
();
auto
weight_data
=
weight
.
DATA_PTR
<
scalar_t
>
();
auto
weight_index_data
=
weight_index
.
data
<
int64_t
>
();
auto
weight_index_data
=
weight_index
.
DATA_PTR
<
int64_t
>
();
auto
grad_basis_data
=
grad_basis
.
data
<
scalar_t
>
();
auto
grad_basis_data
=
grad_basis
.
DATA_PTR
<
scalar_t
>
();
for
(
ptrdiff_t
e
=
0
;
e
<
E
;
e
++
)
{
for
(
ptrdiff_t
e
=
0
;
e
<
E
;
e
++
)
{
for
(
ptrdiff_t
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
for
(
ptrdiff_t
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
...
...
cuda/basis_kernel.cu
View file @
676aa072
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "compat.cuh"
#define THREADS 1024
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define BLOCKS(N) (N + THREADS - 1) / THREADS
...
@@ -45,8 +47,8 @@ template <typename scalar_t> struct BasisForward {
...
@@ -45,8 +47,8 @@ template <typename scalar_t> struct BasisForward {
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis), \
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis), \
at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index), \
at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index), \
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO), \
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO), \
KERNEL_SIZE.
data
<int64_t>(),
IS_OPEN_SPLINE.data<uint8_t>(),
\
KERNEL_SIZE.
DATA_PTR
<int64_t>(),
\
basis.numel());
\
IS_OPEN_SPLINE.DATA_PTR<uint8_t>(), basis.numel());
\
}); \
}); \
\
\
return std::make_tuple(basis, weight_index); \
return std::make_tuple(basis, weight_index); \
...
@@ -176,8 +178,8 @@ template <typename scalar_t> struct BasisBackward {
...
@@ -176,8 +178,8 @@ template <typename scalar_t> struct BasisBackward {
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_pseudo), \
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_pseudo), \
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(GRAD_BASIS), \
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(GRAD_BASIS), \
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO), \
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO), \
KERNEL_SIZE.
data
<int64_t>(),
IS_OPEN_SPLINE.data<uint8_t>(),
\
KERNEL_SIZE.
DATA_PTR
<int64_t>(),
\
grad_pseudo.numel());
\
IS_OPEN_SPLINE.DATA_PTR<uint8_t>(), grad_pseudo.numel());
\
}); \
}); \
\
\
return grad_pseudo; \
return grad_pseudo; \
...
...
cuda/compat.cuh
0 → 100644
View file @
676aa072
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
setup.py
View file @
676aa072
...
@@ -2,21 +2,32 @@ from setuptools import setup, find_packages
...
@@ -2,21 +2,32 @@ from setuptools import setup, find_packages
import
torch
import
torch
from
torch.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
CUDA_HOME
from
torch.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
CUDA_HOME
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
extra_compile_args
=
[]
if
(
TORCH_MAJOR
>
1
)
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
>
2
):
extra_compile_args
+=
[
'-DVERSION_GE_1_3'
]
ext_modules
=
[
ext_modules
=
[
CppExtension
(
'torch_spline_conv.basis_cpu'
,
[
'cpu/basis.cpp'
]),
CppExtension
(
'torch_spline_conv.basis_cpu'
,
[
'cpu/basis.cpp'
],
CppExtension
(
'torch_spline_conv.weighting_cpu'
,
[
'cpu/weighting.cpp'
]),
extra_compile_args
=
extra_compile_args
),
CppExtension
(
'torch_spline_conv.weighting_cpu'
,
[
'cpu/weighting.cpp'
],
extra_compile_args
=
extra_compile_args
),
]
]
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
if
CUDA_HOME
is
not
None
:
if
CUDA_HOME
is
not
None
:
ext_modules
+=
[
ext_modules
+=
[
CUDAExtension
(
'torch_spline_conv.basis_cuda'
,
CUDAExtension
(
'torch_spline_conv.basis_cuda'
,
[
'cuda/basis.cpp'
,
'cuda/basis_kernel.cu'
]),
[
'cuda/basis.cpp'
,
'cuda/basis_kernel.cu'
],
extra_compile_args
=
extra_compile_args
),
CUDAExtension
(
'torch_spline_conv.weighting_cuda'
,
CUDAExtension
(
'torch_spline_conv.weighting_cuda'
,
[
'cuda/weighting.cpp'
,
'cuda/weighting_kernel.cu'
]),
[
'cuda/weighting.cpp'
,
'cuda/weighting_kernel.cu'
],
extra_compile_args
=
extra_compile_args
),
]
]
__version__
=
'1.1.
0
'
__version__
=
'1.1.
1
'
url
=
'https://github.com/rusty1s/pytorch_spline_conv'
url
=
'https://github.com/rusty1s/pytorch_spline_conv'
install_requires
=
[]
install_requires
=
[]
...
...
test/test_weighting.py
View file @
676aa072
...
@@ -32,7 +32,7 @@ def test_spline_weighting_forward(test, dtype, device):
...
@@ -32,7 +32,7 @@ def test_spline_weighting_forward(test, dtype, device):
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_spline_
basis
_backward
(
device
):
def
test_spline_
weighting
_backward
(
device
):
pseudo
=
torch
.
rand
((
4
,
2
),
dtype
=
torch
.
double
,
device
=
device
)
pseudo
=
torch
.
rand
((
4
,
2
),
dtype
=
torch
.
double
,
device
=
device
)
kernel_size
=
tensor
([
5
,
5
],
torch
.
long
,
device
)
kernel_size
=
tensor
([
5
,
5
],
torch
.
long
,
device
)
is_open_spline
=
tensor
([
1
,
1
],
torch
.
uint8
,
device
)
is_open_spline
=
tensor
([
1
,
1
],
torch
.
uint8
,
device
)
...
...
torch_spline_conv/__init__.py
View file @
676aa072
...
@@ -2,6 +2,6 @@ from .basis import SplineBasis
...
@@ -2,6 +2,6 @@ from .basis import SplineBasis
from
.weighting
import
SplineWeighting
from
.weighting
import
SplineWeighting
from
.conv
import
SplineConv
from
.conv
import
SplineConv
__version__
=
'1.1.
0
'
__version__
=
'1.1.
1
'
__all__
=
[
'SplineBasis'
,
'SplineWeighting'
,
'SplineConv'
,
'__version__'
]
__all__
=
[
'SplineBasis'
,
'SplineWeighting'
,
'SplineConv'
,
'__version__'
]
torch_spline_conv/basis.py
View file @
676aa072
...
@@ -18,7 +18,8 @@ class SplineBasis(torch.autograd.Function):
...
@@ -18,7 +18,8 @@ class SplineBasis(torch.autograd.Function):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
pseudo
,
kernel_size
,
is_open_spline
,
degree
):
def
forward
(
ctx
,
pseudo
,
kernel_size
,
is_open_spline
,
degree
):
ctx
.
save_for_backward
(
pseudo
)
ctx
.
save_for_backward
(
pseudo
)
ctx
.
kernel_size
,
ctx
.
is_open_spline
=
kernel_size
,
is_open_spline
ctx
.
kernel_size
=
kernel_size
ctx
.
is_open_spline
=
is_open_spline
ctx
.
degree
=
degree
ctx
.
degree
=
degree
op
=
get_func
(
'{}_fw'
.
format
(
implemented_degrees
[
degree
]),
pseudo
)
op
=
get_func
(
'{}_fw'
.
format
(
implemented_degrees
[
degree
]),
pseudo
)
...
...
torch_spline_conv/conv.py
View file @
676aa072
...
@@ -38,18 +38,9 @@ class SplineConv(object):
...
@@ -38,18 +38,9 @@ class SplineConv(object):
:rtype: :class:`Tensor`
:rtype: :class:`Tensor`
"""
"""
@
staticmethod
@
staticmethod
def
apply
(
x
,
def
apply
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
edge_index
,
degree
=
1
,
norm
=
True
,
root_weight
=
None
,
bias
=
None
):
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
degree
=
1
,
norm
=
True
,
root_weight
=
None
,
bias
=
None
):
x
=
x
.
unsqueeze
(
-
1
)
if
x
.
dim
()
==
1
else
x
x
=
x
.
unsqueeze
(
-
1
)
if
x
.
dim
()
==
1
else
x
pseudo
=
pseudo
.
unsqueeze
(
-
1
)
if
pseudo
.
dim
()
==
1
else
pseudo
pseudo
=
pseudo
.
unsqueeze
(
-
1
)
if
pseudo
.
dim
()
==
1
else
pseudo
...
@@ -58,8 +49,10 @@ class SplineConv(object):
...
@@ -58,8 +49,10 @@ class SplineConv(object):
n
,
m_out
=
x
.
size
(
0
),
weight
.
size
(
2
)
n
,
m_out
=
x
.
size
(
0
),
weight
.
size
(
2
)
# Weight each node.
# Weight each node.
data
=
SplineBasis
.
apply
(
pseudo
,
kernel_size
,
is_open_spline
,
degree
)
basis
,
weight_index
=
SplineBasis
.
apply
(
pseudo
,
kernel_size
,
out
=
SplineWeighting
.
apply
(
x
[
col
],
weight
,
*
data
)
is_open_spline
,
degree
)
weight_index
=
weight_index
.
detach
()
out
=
SplineWeighting
.
apply
(
x
[
col
],
weight
,
basis
,
weight_index
)
# Convert e x m_out to n x m_out features.
# Convert e x m_out to n x m_out features.
row_expand
=
row
.
unsqueeze
(
-
1
).
expand_as
(
out
)
row_expand
=
row
.
unsqueeze
(
-
1
).
expand_as
(
out
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment