OpenDAS / torch-spline-conv · Commit 06eac75e

autograd boilerplate

authored Feb 27, 2020 by rusty1s
parent ac26fc19
Showing 6 changed files with 223 additions and 57 deletions (+223 −57)
csrc/basis.cpp                  +80   −0
csrc/version.cpp                +21   −0
csrc/weighting.cpp              +121  −0
torch_spline_conv/__init__.py   +1    −1
torch_spline_conv/basis.py      +0    −26
torch_spline_conv/weighting.py  +0    −30
csrc/basis.cpp  (new file, 0 → 100644)
#include <Python.h>
#include <torch/script.h>

#include "cpu/basis_cpu.h"
#include "utils.h"

#ifdef WITH_CUDA
#include "cuda/basis_cuda.h"
#endif

#ifdef _WIN32
PyMODINIT_FUNC PyInit__basis(void) { return NULL; }
#endif

std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw(torch::Tensor pseudo, torch::Tensor kernel_size,
                torch::Tensor is_open_spline, int64_t degree) {
  if (pseudo.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_basis_fw_cuda(pseudo, kernel_size, is_open_spline, degree);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_basis_fw_cpu(pseudo, kernel_size, is_open_spline, degree);
  }
}

torch::Tensor spline_basis_bw(torch::Tensor grad_basis, torch::Tensor pseudo,
                              torch::Tensor kernel_size,
                              torch::Tensor is_open_spline, int64_t degree) {
  if (grad_basis.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_basis_bw_cuda(grad_basis, pseudo, kernel_size,
                                is_open_spline, degree);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_basis_bw_cpu(grad_basis, pseudo, kernel_size,
                               is_open_spline, degree);
  }
}

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

class SplineBasis : public torch::autograd::Function<SplineBasis> {
public:
  static variable_list forward(AutogradContext *ctx, Variable pseudo,
                               Variable kernel_size, Variable is_open_spline,
                               int64_t degree) {
    ctx->saved_data["degree"] = degree;
    auto result = spline_basis_fw(pseudo, kernel_size, is_open_spline, degree);
    auto basis = std::get<0>(result), weight_index = std::get<1>(result);
    ctx->save_for_backward({pseudo, kernel_size, is_open_spline});
    ctx->mark_non_differentiable({weight_index});
    return {basis, weight_index};
  }

  static variable_list backward(AutogradContext *ctx,
                                variable_list grad_outs) {
    auto grad_basis = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto pseudo = saved[0], kernel_size = saved[1], is_open_spline = saved[2];
    auto degree = ctx->saved_data["degree"].toInt();
    auto grad_pseudo = spline_basis_bw(grad_basis, pseudo, kernel_size,
                                       is_open_spline, degree);
    return {grad_pseudo, Variable(), Variable(), Variable()};
  }
};

std::tuple<torch::Tensor, torch::Tensor>
spline_basis(torch::Tensor pseudo, torch::Tensor kernel_size,
             torch::Tensor is_open_spline, int64_t degree) {
  // apply() returns a variable_list; unpack it into the declared tuple.
  auto result = SplineBasis::apply(pseudo, kernel_size, is_open_spline, degree);
  return std::make_tuple(result[0], result[1]);
}

static auto registry = torch::RegisterOperators().op(
    "torch_spline_conv::spline_basis", &spline_basis);
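Note: the RegisterOperators call above is what makes the new C++ autograd path callable from Python as torch.ops.torch_spline_conv.spline_basis. A minimal sketch of exercising it, assuming the compiled extension library has already been loaded (the package's __init__.py normally does this); shapes and values are illustrative only:

import torch

# Assumes the compiled extension has been loaded beforehand, e.g. via
# torch.ops.load_library(...).
pseudo = torch.rand(4, 2, requires_grad=True)  # 4 edges, 2-D pseudo-coords
kernel_size = torch.tensor([5, 5])             # kernel size per dimension
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8)
degree = 1                                     # linear B-splines

basis, weight_index = torch.ops.torch_spline_conv.spline_basis(
    pseudo, kernel_size, is_open_spline, degree)

# weight_index was marked non-differentiable in SplineBasis::forward,
# so gradients flow back to pseudo only.
basis.sum().backward()
print(pseudo.grad.shape)  # torch.Size([4, 2])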
csrc/version.cpp  (new file, 0 → 100644)
#include <Python.h>
#include <torch/script.h>

#ifdef WITH_CUDA
#include <cuda.h>
#endif

#ifdef _WIN32
PyMODINIT_FUNC PyInit__version(void) { return NULL; }
#endif

int64_t cuda_version() {
#ifdef WITH_CUDA
  return CUDA_VERSION;
#else
  return -1;
#endif
}

static auto registry = torch::RegisterOperators().op(
    "torch_spline_conv::cuda_version", &cuda_version);
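Note: cuda_version() simply reports the CUDA_VERSION macro the extension was built against, or -1 for CPU-only builds. A hedged sketch of decoding it from Python, under the same loading assumption as above (CUDA encodes the version as major * 1000 + minor * 10):

import torch

version = torch.ops.torch_spline_conv.cuda_version()
if version == -1:
    print('built without CUDA support')
else:
    # e.g. 10010 -> CUDA 10.1
    print(f'built against CUDA {version // 1000}.{version % 1000 // 10}')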
csrc/weighting.cpp  (new file, 0 → 100644)
#include <Python.h>
#include <torch/script.h>

#include "cpu/weighting_cpu.h"
#include "utils.h"

#ifdef WITH_CUDA
#include "cuda/weighting_cuda.h"
#endif

#ifdef _WIN32
PyMODINIT_FUNC PyInit__weighting(void) { return NULL; }
#endif

torch::Tensor spline_weighting_fw(torch::Tensor x, torch::Tensor weight,
                                  torch::Tensor basis,
                                  torch::Tensor weight_index) {
  if (x.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_fw_cuda(x, weight, basis, weight_index);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_weighting_fw_cpu(x, weight, basis, weight_index);
  }
}

torch::Tensor spline_weighting_bw_x(torch::Tensor grad_out,
                                    torch::Tensor weight, torch::Tensor basis,
                                    torch::Tensor weight_index) {
  if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_bw_x_cuda(grad_out, weight, basis, weight_index);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_weighting_bw_x_cpu(grad_out, weight, basis, weight_index);
  }
}

torch::Tensor spline_weighting_bw_weight(torch::Tensor grad_out,
                                         torch::Tensor x, torch::Tensor basis,
                                         torch::Tensor weight_index,
                                         int64_t kernel_size) {
  if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_bw_weight_cuda(grad_out, x, basis, weight_index,
                                           kernel_size);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_weighting_bw_weight_cpu(grad_out, x, basis, weight_index,
                                          kernel_size);
  }
}

torch::Tensor spline_weighting_bw_basis(torch::Tensor grad_out,
                                        torch::Tensor x, torch::Tensor weight,
                                        torch::Tensor weight_index) {
  if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
    return spline_weighting_bw_basis_cuda(grad_out, x, weight, weight_index);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_weighting_bw_basis_cpu(grad_out, x, weight, weight_index);
  }
}

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

class SplineWeighting : public torch::autograd::Function<SplineWeighting> {
public:
  static variable_list forward(AutogradContext *ctx, Variable x,
                               Variable weight, Variable basis,
                               Variable weight_index) {
    auto out = spline_weighting_fw(x, weight, basis, weight_index);
    ctx->save_for_backward({x, weight, basis, weight_index});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx,
                                variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto x = saved[0], weight = saved[1], basis = saved[2],
         weight_index = saved[3];

    auto grad_x = Variable();
    if (torch::autograd::any_variable_requires_grad({x})) {
      grad_x = spline_weighting_bw_x(grad_out, weight, basis, weight_index);
    }

    auto grad_weight = Variable();
    if (torch::autograd::any_variable_requires_grad({weight})) {
      grad_weight = spline_weighting_bw_weight(grad_out, x, basis,
                                               weight_index, weight.size(0));
    }

    auto grad_basis = Variable();
    if (torch::autograd::any_variable_requires_grad({basis})) {
      grad_basis = spline_weighting_bw_basis(grad_out, x, weight,
                                             weight_index);
    }

    return {grad_x, grad_weight, grad_basis, Variable()};
  }
};

torch::Tensor spline_weighting(torch::Tensor x, torch::Tensor weight,
                               torch::Tensor basis,
                               torch::Tensor weight_index) {
  // apply() returns a variable_list with a single output tensor.
  return SplineWeighting::apply(x, weight, basis, weight_index)[0];
}

static auto registry = torch::RegisterOperators().op(
    "torch_spline_conv::spline_weighting", &spline_weighting);
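Note: because backward computes each gradient only when the corresponding input requires grad, a double-precision gradcheck over x, weight and basis exercises all three branches at once. A sketch under the same loading assumption as above; the shapes are illustrative (E = 4 edges, S = 4 basis products per edge, K = 25 kernel weights):

import torch
from torch.autograd import gradcheck

op = torch.ops.torch_spline_conv.spline_weighting

# E = 4 edges, 3 input channels, 2 output channels.
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
weight = torch.randn(25, 3, 2, dtype=torch.double, requires_grad=True)
basis = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
weight_index = torch.randint(0, 25, (4, 4), dtype=torch.long)

# weight_index is non-differentiable, so it is closed over rather than
# passed as a checked input.
assert gradcheck(lambda x, w, b: op(x, w, b, weight_index),
                 (x, weight, basis))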
torch_spline_conv/__init__.py
...
@@ -20,7 +20,7 @@ except OSError as e:
     raise OSError(e)

 if torch.version.cuda is not None:  # pragma: no cover
-    cuda_version = torch.ops.torch_scatter.cuda_version()
+    cuda_version = torch.ops.torch_spline_conv.cuda_version()
     if cuda_version == -1:
         major = minor = 0
...
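Note: this hunk switches the version query from torch_scatter's op to this package's own. The surrounding check is collapsed on the page, but it presumably compares the CUDA version the extension was compiled with against the CUDA version of the running torch build; a hedged sketch of such a comparison, with everything beyond cuda_version() being an assumption:

import torch

if torch.version.cuda is not None:
    cuda_version = torch.ops.torch_spline_conv.cuda_version()
    if cuda_version == -1:
        major = minor = 0            # extension built without CUDA
    else:
        major = cuda_version // 1000
        minor = cuda_version % 1000 // 10
    t_major, t_minor = [int(v) for v in torch.version.cuda.split('.')[:2]]
    if (major, minor) != (t_major, t_minor):
        raise RuntimeError(
            f'torch was built with CUDA {t_major}.{t_minor}, but '
            f'torch-spline-conv was compiled with CUDA {major}.{minor}.')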
torch_spline_conv/basis.py
...
@@ -9,29 +9,3 @@ def spline_basis(pseudo: torch.Tensor, kernel_size: torch.Tensor,
                  degree: int) -> Tuple[torch.Tensor, torch.Tensor]:
     return torch.ops.torch_spline_conv.spline_basis(pseudo, kernel_size,
                                                     is_open_spline, degree)
-
-# class SplineBasis(torch.autograd.Function):
-#     @staticmethod
-#     def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
-#         ctx.save_for_backward(pseudo)
-#         ctx.kernel_size = kernel_size
-#         ctx.is_open_spline = is_open_spline
-#         ctx.degree = degree
-#         op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
-#         basis, weight_index = op(pseudo, kernel_size, is_open_spline)
-#         return basis, weight_index
-#
-#     @staticmethod
-#     def backward(ctx, grad_basis, grad_weight_index):
-#         pseudo, = ctx.saved_tensors
-#         kernel_size, is_open_spline = ctx.kernel_size, ctx.is_open_spline
-#         degree = ctx.degree
-#
-#         grad_pseudo = None
-#         if ctx.needs_input_grad[0]:
-#             grad_pseudo = op(grad_basis, pseudo, kernel_size, is_open_spline)
-#         return grad_pseudo, None, None, None
torch_spline_conv/weighting.py
...
@@ -7,33 +7,3 @@ def spline_weighting(x: torch.Tensor, weight: torch.Tensor,
                      weight_index: torch.Tensor) -> torch.Tensor:
     return torch.ops.torch_spline_conv.spline_weighting(x, weight, basis,
                                                         weight_index)
-
-# class SplineWeighting(torch.autograd.Function):
-#     @staticmethod
-#     def forward(ctx, x, weight, basis, weight_index):
-#         ctx.weight_index = weight_index
-#         ctx.save_for_backward(x, weight, basis)
-#         op = get_func('weighting_fw', x)
-#         out = op(x, weight, basis, weight_index)
-#         return out
-#
-#     @staticmethod
-#     def backward(ctx, grad_out):
-#         x, weight, basis = ctx.saved_tensors
-#
-#         grad_x = grad_weight = grad_basis = None
-#         if ctx.needs_input_grad[0]:
-#             op = get_func('weighting_bw_x', x)
-#             grad_x = op(grad_out, weight, basis, ctx.weight_index)
-#         if ctx.needs_input_grad[1]:
-#             op = get_func('weighting_bw_w', x)
-#             grad_weight = op(grad_out, x, basis, ctx.weight_index,
-#                              weight.size(0))
-#         if ctx.needs_input_grad[2]:
-#             op = get_func('weighting_bw_b', x)
-#             grad_basis = op(grad_out, x, weight, ctx.weight_index)
-#         return grad_x, grad_weight, grad_basis, None
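Note: the two deleted Python autograd.Function classes (kept as comments until this commit) are now fully replaced by the C++ SplineBasis and SplineWeighting functions. An end-to-end sketch of how the two wrappers compose into one spline-convolution step, assuming both are re-exported from the package root as the first changed file suggests; shapes are illustrative:

import torch
from torch_spline_conv import spline_basis, spline_weighting

pseudo = torch.rand(4, 2)                # 4 edges, 2-D pseudo-coordinates
kernel_size = torch.tensor([5, 5])
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8)

# Step 1: B-spline basis values and the kernel indices they belong to.
basis, weight_index = spline_basis(pseudo, kernel_size, is_open_spline, 1)

# Step 2: weight the source node features by the interpolated kernel.
x = torch.randn(4, 3)                    # per-edge input features
weight = torch.randn(25, 3, 2)           # [K, in_channels, out_channels]
out = spline_weighting(x, weight, basis, weight_index)
print(out.shape)  # torch.Size([4, 2])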