Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
10e5ee86
Commit
10e5ee86
authored
Aug 07, 2018
by
rusty1s
Browse files
cuda basis
parent
7a86525b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
341 additions
and
51 deletions
+341
-51
cpu/basis.cpp
cpu/basis.cpp
+50
-47
cuda/basis.cpp
cuda/basis.cpp
+86
-0
cuda/basis_kernel.cu
cuda/basis_kernel.cu
+196
-0
setup.py
setup.py
+5
-0
test/test_basis.py
test/test_basis.py
+0
-2
torch_spline_conv/basis.py
torch_spline_conv/basis.py
+4
-2
No files found.
cpu/basis.cpp
View file @
10e5ee86
#include <torch/torch.h>
template
<
typename
scalar
>
inline
scalar
linear
(
scalar
v
,
int64_t
k_mod
)
{
// Linear (degree-1) B-spline basis value at fractional position `v` for the
// local knot offset `k_mod` (0 or 1): yields (1 - v) for k_mod == 0 and v for
// k_mod == 1, folded into a single branch-free expression.
template <typename scalar_t>
inline scalar_t linear(scalar_t v, int64_t k_mod) {
  scalar_t value = 1 - v - k_mod + 2 * v * k_mod;
  return value;
}
template
<
typename
scalar
>
inline
scalar
quadratic
(
scalar
v
,
int64_t
k_mod
)
{
template
<
typename
scalar_t
>
inline
scalar_t
quadratic
(
scalar_t
v
,
int64_t
k_mod
)
{
if
(
k_mod
==
0
)
return
0.5
*
v
*
v
-
v
+
0.5
;
else
if
(
k_mod
==
1
)
...
...
@@ -13,10 +14,10 @@ template <typename scalar> inline scalar quadratic(scalar v, int64_t k_mod) {
return
0.5
*
v
*
v
;
}
template
<
typename
scalar
>
inline
scalar
cubic
(
scalar
v
,
int64_t
k_mod
)
{
if
(
k_mod
==
0
)
{
template
<
typename
scalar
_t
>
inline
scalar
_t
cubic
(
scalar
_t
v
,
int64_t
k_mod
)
{
if
(
k_mod
==
0
)
return
(
1
-
v
)
*
(
1
-
v
)
*
(
1
-
v
)
/
6.0
;
}
else
if
(
k_mod
==
1
)
else
if
(
k_mod
==
1
)
return
(
3
*
v
*
v
*
v
-
6
*
v
*
v
+
4
)
/
6
;
else
if
(
k_mod
==
2
)
return
(
-
3
*
v
*
v
*
v
+
3
*
v
*
v
+
3
*
v
+
1
)
/
6
;
...
...
@@ -24,31 +25,6 @@ template <typename scalar> inline scalar cubic(scalar v, int64_t k_mod) {
return
v
*
v
*
v
/
6
;
}
template
<
typename
scalar
>
inline
scalar
grad_linear
(
scalar
v
,
int64_t
k_mod
)
{
return
2
*
k_mod
-
1
;
}
template
<
typename
scalar
>
inline
scalar
grad_quadratic
(
scalar
v
,
int64_t
k_mod
)
{
if
(
k_mod
==
0
)
return
v
-
1
;
else
if
(
k_mod
==
1
)
return
-
2
*
v
+
1
;
else
return
v
;
}
template
<
typename
scalar
>
inline
scalar
grad_cubic
(
scalar
v
,
int64_t
k_mod
)
{
if
(
k_mod
==
0
)
return
(
-
v
*
v
+
2
*
v
-
1
)
/
2
;
else
if
(
k_mod
==
1
)
return
(
3
*
v
*
v
-
4
*
v
)
/
2
;
else
if
(
k_mod
==
2
)
return
(
-
3
*
v
*
v
+
2
*
v
+
1
)
/
2
;
else
return
v
*
v
/
2
;
}
#define BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, FUNC) \
[&]() -> std::tuple<at::Tensor, at::Tensor> { \
auto E = PSEUDO.size(0), D = PSEUDO.size(1); \
...
...
@@ -94,6 +70,50 @@ template <typename scalar> inline scalar grad_cubic(scalar v, int64_t k_mod) {
return std::make_tuple(basis, weight_index); \
}()
// CPU forward pass for the linear (degree M = 1) B-spline basis.
// Expands BASIS_FORWARD with the `linear` basis function and returns the
// (basis, weight_index) tensor pair it produces.
std::tuple<at::Tensor, at::Tensor> linear_fw(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_FORWARD(1, pseudo, kernel_size, is_open_spline, linear);
}
// CPU forward pass for the quadratic (degree M = 2) B-spline basis.
// Same shape as linear_fw, dispatching to the `quadratic` basis function.
std::tuple<at::Tensor, at::Tensor> quadratic_fw(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_FORWARD(2, pseudo, kernel_size, is_open_spline, quadratic);
}
// CPU forward pass for the cubic (degree M = 3) B-spline basis.
// Same shape as linear_fw, dispatching to the `cubic` basis function.
std::tuple<at::Tensor, at::Tensor> cubic_fw(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_FORWARD(3, pseudo, kernel_size, is_open_spline, cubic);
}
// Derivative of the linear basis w.r.t. v: -1 when k_mod == 0, +1 when
// k_mod == 1. The position `v` does not appear in the derivative of a
// degree-1 polynomial, so the parameter is intentionally unused.
template <typename scalar_t>
inline scalar_t grad_linear(scalar_t v, int64_t k_mod) {
  scalar_t slope = 2 * k_mod - 1;
  return slope;
}
// Derivative of the quadratic B-spline basis w.r.t. v, selected by the
// local knot offset `k_mod` (0, 1, or anything else = 2).
template <typename scalar_t>
inline scalar_t grad_quadratic(scalar_t v, int64_t k_mod) {
  switch (k_mod) {
  case 0:
    return v - 1;
  case 1:
    return -2 * v + 1;
  default:
    return v;
  }
}
// Derivative of the cubic B-spline basis w.r.t. v, selected by the local
// knot offset `k_mod` (0, 1, 2, or anything else = 3).
template <typename scalar_t>
inline scalar_t grad_cubic(scalar_t v, int64_t k_mod) {
  switch (k_mod) {
  case 0:
    return (-v * v + 2 * v - 1) / 2;
  case 1:
    return (3 * v * v - 4 * v) / 2;
  case 2:
    return (-3 * v * v + 2 * v + 1) / 2;
  default:
    return v * v / 2;
  }
}
#define BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, \
FUNC, GRAD_FUNC) \
[&]() -> at::Tensor { \
...
...
@@ -143,23 +163,6 @@ template <typename scalar> inline scalar grad_cubic(scalar v, int64_t k_mod) {
return grad_pseudo; \
}()
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
linear_fw
(
at
::
Tensor
pseudo
,
at
::
Tensor
kernel_size
,
at
::
Tensor
is_open_spline
)
{
return
BASIS_FORWARD
(
1
,
pseudo
,
kernel_size
,
is_open_spline
,
linear
);
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
quadratic_fw
(
at
::
Tensor
pseudo
,
at
::
Tensor
kernel_size
,
at
::
Tensor
is_open_spline
)
{
return
BASIS_FORWARD
(
2
,
pseudo
,
kernel_size
,
is_open_spline
,
quadratic
);
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
cubic_fw
(
at
::
Tensor
pseudo
,
at
::
Tensor
kernel_size
,
at
::
Tensor
is_open_spline
)
{
return
BASIS_FORWARD
(
3
,
pseudo
,
kernel_size
,
is_open_spline
,
cubic
);
}
at
::
Tensor
linear_bw
(
at
::
Tensor
grad_basis
,
at
::
Tensor
pseudo
,
at
::
Tensor
kernel_size
,
at
::
Tensor
is_open_spline
)
{
return
BASIS_BACKWARD
(
1
,
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
,
...
...
cuda/basis.cpp
0 → 100644
View file @
10e5ee86
#include <torch/torch.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
linear_fw_cuda
(
at
::
Tensor
pseudo
,
at
::
Tensor
kernel_size
,
at
::
Tensor
is_open_spline
);
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
quadratic_fw_cuda
(
at
::
Tensor
pseudo
,
at
::
Tensor
kernel_size
,
at
::
Tensor
is_open_spline
);
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
cubic_fw_cuda
(
at
::
Tensor
pseudo
,
at
::
Tensor
kernel_size
,
at
::
Tensor
is_open_spline
);
at
::
Tensor
linear_bw_cuda
(
at
::
Tensor
grad_basis
,
at
::
Tensor
pseudo
,
at
::
Tensor
kernel_size
,
at
::
Tensor
is_open_spline
);
at
::
Tensor
quadratic_bw_cuda
(
at
::
Tensor
grad_basis
,
at
::
Tensor
pseudo
,
at
::
Tensor
kernel_size
,
at
::
Tensor
is_open_spline
);
at
::
Tensor
cubic_bw_cuda
(
at
::
Tensor
grad_basis
,
at
::
Tensor
pseudo
,
at
::
Tensor
kernel_size
,
at
::
Tensor
is_open_spline
);
// CUDA entry point for the linear basis forward pass: asserts every input
// tensor lives on the GPU (CHECK_CUDA), then dispatches to the kernel
// launcher declared above.
std::tuple<at::Tensor, at::Tensor> linear_fw(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return linear_fw_cuda(pseudo, kernel_size, is_open_spline);
}
// CUDA entry point for the quadratic basis forward pass; validates device
// placement, then dispatches to quadratic_fw_cuda.
std::tuple<at::Tensor, at::Tensor> quadratic_fw(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return quadratic_fw_cuda(pseudo, kernel_size, is_open_spline);
}
// CUDA entry point for the cubic basis forward pass; validates device
// placement, then dispatches to cubic_fw_cuda.
std::tuple<at::Tensor, at::Tensor> cubic_fw(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return cubic_fw_cuda(pseudo, kernel_size, is_open_spline);
}
// CUDA entry point for the linear basis backward pass; validates device
// placement of all four inputs, then dispatches to linear_bw_cuda.
at::Tensor linear_bw(at::Tensor grad_basis, at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  CHECK_CUDA(grad_basis);
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return linear_bw_cuda(grad_basis, pseudo, kernel_size, is_open_spline);
}
// CUDA entry point for the quadratic basis backward pass; validates device
// placement, then dispatches to quadratic_bw_cuda.
at::Tensor quadratic_bw(at::Tensor grad_basis, at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  CHECK_CUDA(grad_basis);
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return quadratic_bw_cuda(grad_basis, pseudo, kernel_size, is_open_spline);
}
// CUDA entry point for the cubic basis backward pass; validates device
// placement, then dispatches to cubic_bw_cuda.
at::Tensor cubic_bw(at::Tensor grad_basis, at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  CHECK_CUDA(grad_basis);
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  return cubic_bw_cuda(grad_basis, pseudo, kernel_size, is_open_spline);
}
// Python bindings for the CUDA extension: exposes the six forward/backward
// entry points defined above under their C++ names.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_fw", &linear_fw, "Linear Basis Forward (CUDA)");
  m.def("quadratic_fw", &quadratic_fw, "Quadratic Basis Forward (CUDA)");
  m.def("cubic_fw", &cubic_fw, "Cubic Basis Forward (CUDA)");
  m.def("linear_bw", &linear_bw, "Linear Basis Backward (CUDA)");
  m.def("quadratic_bw", &quadratic_bw, "Quadratic Basis Backward (CUDA)");
  m.def("cubic_bw", &cubic_bw, "Cubic Basis Backward (CUDA)");
}
cuda/basis_kernel.cu
0 → 100644
View file @
10e5ee86
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
// Device-side B-spline basis evaluation, templated on the floating scalar
// type. Each function takes the fractional position `v` and the local knot
// offset `k_mod` and returns one factor of the tensor-product basis.
template <typename scalar_t> struct BasisForward {
  // Degree 1: (1 - v) for k_mod == 0, v for k_mod == 1, as one expression.
  static inline __device__ scalar_t linear(scalar_t v, int64_t k_mod) {
    return 1 - v - k_mod + 2 * v * k_mod;
  }
  // Degree 2: the three quadratic B-spline segments.
  static inline __device__ scalar_t quadratic(scalar_t v, int64_t k_mod) {
    if (k_mod == 0)
      return 0.5 * v * v - v + 0.5;
    else if (k_mod == 1)
      return -v * v + v + 0.5;
    else
      return 0.5 * v * v;
  }
  // Degree 3: the four cubic B-spline segments.
  static inline __device__ scalar_t cubic(scalar_t v, int64_t k_mod) {
    if (k_mod == 0)
      return (1 - v) * (1 - v) * (1 - v) / 6.0;
    else if (k_mod == 1)
      return (3 * v * v * v - 6 * v * v + 4) / 6;
    else if (k_mod == 2)
      return (-3 * v * v * v + 3 * v * v + 3 * v + 1) / 6;
    else
      return v * v * v / 6;
  }
};
// Launch helper for the forward pass. Allocates the output tensors
// (S = (M + 1)^D basis products per edge, computed via pow and rounded with
// + 0.5), dispatches on the floating dtype, and launches KERNEL_NAME over
// basis.numel() elements.
// NOTE(review): "basis_forward_##M" is inside a string literal, so the ##
// token-paste does not occur — the dispatch debug name is the literal text,
// not e.g. "basis_forward_1". Confirm whether a per-degree name was intended.
#define BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, KERNEL_NAME)     \
  [&]() -> std::tuple<at::Tensor, at::Tensor> {                                \
    auto E = PSEUDO.size(0);                                                   \
    auto S = (int64_t)(pow(M + 1, KERNEL_SIZE.size(0)) + 0.5);                 \
    auto basis = at::empty({E, S}, PSEUDO.type());                             \
    auto weight_index = at::empty({E, S}, KERNEL_SIZE.type());                 \
                                                                               \
    AT_DISPATCH_FLOATING_TYPES(PSEUDO.type(), "basis_forward_##M", [&] {       \
      KERNEL_NAME<scalar_t><<<BLOCKS(basis.numel()), THREADS>>>(               \
          at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),           \
          at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),     \
          at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO),          \
          KERNEL_SIZE.data<int64_t>(), IS_OPEN_SPLINE.data<uint8_t>(),         \
          basis.numel());                                                      \
    });                                                                        \
                                                                               \
    return std::make_tuple(basis, weight_index);                               \
  }()
// Kernel body shared by the per-degree forward kernels. Each thread walks
// elements i of the E x S basis tensor with stride blockDim.x * gridDim.x;
// for element (e, s) it decomposes s into per-dimension knot offsets k_mod
// (base M + 1), accumulates the product of the per-dimension basis values
// (CODE is evaluated with `v` and `k_mod` in scope), and builds the flat
// weight index wi from the integer part of the scaled pseudo coordinate.
// PSEUDO is read through its strides, but BASIS and WEIGHT_INDEX are indexed
// as e * sizes[1] + s — this assumes they are contiguous, which holds for
// the at::empty allocations in BASIS_FORWARD above.
#define BASIS_FORWARD_KERNEL(M, BASIS, WEIGHT_INDEX, PSEUDO, KERNEL_SIZE,      \
                             IS_OPEN_SPLINE, NUMEL, CODE)                      \
  [&] {                                                                        \
    const size_t index = blockIdx.x * blockDim.x + threadIdx.x;                \
    const size_t stride = blockDim.x * gridDim.x;                              \
    for (ptrdiff_t i = index; i < NUMEL; i += stride) {                        \
      int64_t e = i / BASIS.sizes[1], s = i % BASIS.sizes[1];                  \
      int64_t k = s, wi = 0, wi_offset = 1;                                    \
      scalar_t b = 1;                                                          \
                                                                               \
      for (ptrdiff_t d = 0; d < PSEUDO.sizes[1]; d++) {                        \
        auto k_mod = k % (M + 1);                                              \
        k /= M + 1;                                                            \
                                                                               \
        auto v = PSEUDO.data[e * PSEUDO.strides[0] + d * PSEUDO.strides[1]];   \
        v *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d];                           \
                                                                               \
        wi += (((int64_t)v + k_mod) % KERNEL_SIZE[d]) * wi_offset;             \
        wi_offset *= KERNEL_SIZE[d];                                           \
                                                                               \
        v -= floor(v);                                                         \
        v = CODE;                                                              \
        b *= v;                                                                \
      }                                                                        \
                                                                               \
      BASIS.data[e * BASIS.sizes[1] + s] = b;                                  \
      WEIGHT_INDEX.data[e * WEIGHT_INDEX.sizes[1] + s] = wi;                   \
    }                                                                          \
  }()
// CUDA kernel for the linear (M = 1) basis forward pass: expands the shared
// kernel body with BasisForward<scalar_t>::linear as the per-dimension code.
template <typename scalar_t>
__global__ void
linear_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
                 at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
                 at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
                 int64_t *kernel_size, uint8_t *is_open_spline, size_t numel) {
  BASIS_FORWARD_KERNEL(1, basis, weight_index, pseudo, kernel_size,
                       is_open_spline, numel,
                       BasisForward<scalar_t>::linear(v, k_mod));
}
// Host-side launcher for the linear basis forward pass; allocates outputs
// and launches linear_fw_kernel via the BASIS_FORWARD macro.
std::tuple<at::Tensor, at::Tensor> linear_fw_cuda(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_FORWARD(1, pseudo, kernel_size, is_open_spline, linear_fw_kernel);
}
// CUDA kernel for the quadratic (M = 2) basis forward pass.
template <typename scalar_t>
__global__ void
quadratic_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
                    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
                    int64_t *kernel_size, uint8_t *is_open_spline,
                    size_t numel) {
  BASIS_FORWARD_KERNEL(2, basis, weight_index, pseudo, kernel_size,
                       is_open_spline, numel,
                       BasisForward<scalar_t>::quadratic(v, k_mod));
}
// Host-side launcher for the quadratic basis forward pass.
std::tuple<at::Tensor, at::Tensor> quadratic_fw_cuda(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_FORWARD(2, pseudo, kernel_size, is_open_spline, quadratic_fw_kernel);
}
// CUDA kernel for the cubic (M = 3) basis forward pass.
template <typename scalar_t>
__global__ void
cubic_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
                at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
                at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
                int64_t *kernel_size, uint8_t *is_open_spline, size_t numel) {
  BASIS_FORWARD_KERNEL(3, basis, weight_index, pseudo, kernel_size,
                       is_open_spline, numel,
                       BasisForward<scalar_t>::cubic(v, k_mod));
}
// Host-side launcher for the cubic basis forward pass.
std::tuple<at::Tensor, at::Tensor> cubic_fw_cuda(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_FORWARD(3, pseudo, kernel_size, is_open_spline, cubic_fw_kernel);
}
// Device-side derivatives of the B-spline basis functions w.r.t. v,
// mirroring BasisForward segment by segment.
template <typename scalar_t> struct BasisBackward {
  // Degree 1: constant slope, -1 (k_mod == 0) or +1 (k_mod == 1); `v` is
  // intentionally unused.
  static inline __device__ scalar_t linear(scalar_t v, int64_t k_mod) {
    return 2 * k_mod - 1;
  }
  // Degree 2 derivative segments.
  static inline __device__ scalar_t quadratic(scalar_t v, int64_t k_mod) {
    if (k_mod == 0)
      return v - 1;
    else if (k_mod == 1)
      return -2 * v + 1;
    else
      return v;
  }
  // Degree 3 derivative segments.
  static inline __device__ scalar_t cubic(scalar_t v, int64_t k_mod) {
    if (k_mod == 0)
      return (-v * v + 2 * v - 1) / 2;
    else if (k_mod == 1)
      return (3 * v * v - 4 * v) / 2;
    else if (k_mod == 2)
      return (-3 * v * v + 2 * v + 1) / 2;
    else
      return v * v / 2;
  }
};
// Launch helper for the backward pass: allocates grad_pseudo (E x D),
// dispatches on the floating dtype, launches KERNEL_NAME, and returns the
// gradient tensor.
// Fix: the immediately-invoked lambda was never invoked — the macro ended
// with `}` instead of `}()` (compare BASIS_FORWARD above), so expanding it
// where an at::Tensor is expected would yield a lambda, not a tensor.
// NOTE(review): as in BASIS_FORWARD, "basis_backward_##M" is a string
// literal, so ## does not paste — confirm whether a per-degree name was
// intended.
#define BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE,     \
                       KERNEL_NAME)                                            \
  [&]() -> at::Tensor {                                                        \
    auto E = PSEUDO.size(0);                                                   \
    auto D = PSEUDO.size(1);                                                   \
    auto grad_pseudo = at::empty({E, D}, PSEUDO.type());                       \
                                                                               \
    AT_DISPATCH_FLOATING_TYPES(GRAD_BASIS.type(), "basis_backward_##M", [&] {  \
      KERNEL_NAME<scalar_t><<<BLOCKS(grad_pseudo.numel()), THREADS>>>(         \
          at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_pseudo),     \
          at::cuda::detail::getTensorInfo<scalar_t, int64_t>(GRAD_BASIS),      \
          at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO),          \
          KERNEL_SIZE.data<int64_t>(), IS_OPEN_SPLINE.data<uint8_t>(),         \
          grad_pseudo.numel());                                                \
    });                                                                        \
                                                                               \
    return grad_pseudo;                                                        \
  }()
// TODO(review): placeholder — returns grad_basis unchanged instead of the
// gradient w.r.t. pseudo. No backward kernel exists yet in this file and
// BASIS_BACKWARD is unused; the real implementation is still to come.
at::Tensor linear_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  return grad_basis;
}
// TODO(review): placeholder — returns grad_basis unchanged; the quadratic
// backward kernel is not implemented yet (see note on linear_bw_cuda's
// counterpart: no *_bw_kernel exists in this file).
at::Tensor quadratic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  return grad_basis;
}
// TODO(review): placeholder — returns grad_basis unchanged; the cubic
// backward kernel is not implemented yet (no *_bw_kernel exists in this
// file).
at::Tensor cubic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  return grad_basis;
}
setup.py
View file @
10e5ee86
...
...
@@ -8,6 +8,11 @@ ext_modules = [
]
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
# Build the CUDA extension only when a CUDA runtime is available at
# install time; otherwise only the CPU extension is compiled.
if torch.cuda.is_available():
    ext_modules += [
        CUDAExtension('basis_cuda', ['cuda/basis.cpp', 'cuda/basis_kernel.cu'])
    ]
__version__
=
'1.0.4'
url
=
'https://github.com/rusty1s/pytorch_spline_conv'
...
...
test/test_basis.py
View file @
10e5ee86
...
...
@@ -6,8 +6,6 @@ from torch_spline_conv.basis import SplineBasis
from
.utils
import
dtypes
,
devices
,
tensor
devices
=
[
torch
.
device
(
'cpu'
)]
tests
=
[{
'pseudo'
:
[[
0
],
[
0.0625
],
[
0.25
],
[
0.75
],
[
0.9375
],
[
1
]],
'kernel_size'
:
[
5
],
...
...
torch_spline_conv/basis.py
View file @
10e5ee86
import
torch
import
basis_cpu
if
torch
.
cuda
.
is_available
():
import
basis_cuda
# Supported B-spline degree -> name prefix of its extension functions
# (e.g. degree 1 resolves to `linear_fw` / `linear_bw`).
implemented_degrees = dict(zip((1, 2, 3), ('linear', 'quadratic', 'cubic')))
def get_func(name, tensor):
    """Return the extension function ``name`` from the module matching the
    device of ``tensor`` (CUDA extension for GPU tensors, CPU otherwise)."""
    if tensor.is_cuda:
        module = basis_cuda
    else:
        module = basis_cpu
    return getattr(module, name)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment