Commit 06eac75e authored by rusty1s's avatar rusty1s
Browse files

autograd boilerplate

parent ac26fc19
#include <Python.h>
#include <torch/script.h>
#include "cpu/basis_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "cuda/basis_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__basis(void) { return NULL; }
#endif
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
if (pseudo.device().is_cuda()) {
#ifdef WITH_CUDA
return spline_basis_fw_cuda(pseudo, kernel_size, is_open_spline, degree);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spline_basis_fw_cpu(pseudo, kernel_size, is_open_spline, degree);
}
}
torch::Tensor spline_basis_bw(torch::Tensor grad_basis, torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
if (grad_basis.device().is_cuda()) {
#ifdef WITH_CUDA
return spline_basis_bw_cuda(grad_basis, pseudo, kernel_size, is_open_spline,
degree);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spline_basis_bw_cpu(grad_basis, pseudo, kernel_size, is_open_spline,
degree);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SplineBasis : public torch::autograd::Function<SplineBasis> {
public:
static variable_list forward(AutogradContext *ctx, Variable pseudo,
Variable kernel_size, Variable is_open_spline,
int64_t degree) {
ctx->saved_data["degree"] = degree;
auto result = spline_basis_fw(pseudo, kernel_size, is_open_spline, degree);
auto basis = std::get<0>(result), weight_index = std::get<1>(result);
ctx->save_for_backward({pseudo, kernel_size, is_open_spline});
ctx->mark_non_differentiable({weight_index});
return {basis, weight_index};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_basis = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto pseudo = saved[0], kernel_size = saved[1], is_open_spline = saved[2];
auto gree = ctx->saved_data["degree"].toInt();
auto grad_pseudo = spline_basis_bw(grad_basis, pseudo, kernel_size,
is_open_spline, degree);
return {grad_pseudo, Variable(), Variable(), Variable()};
}
};
std::tuple<torch::Tensor, torch::Tensor>
spline_basis(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
return SplineBasis::apply(pseudo, kernel_size, is_open_spline, degree);
}
static auto registry = torch::RegisterOperators().op(
"torch_spline_conv::spline_basis", &spline_basis);
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#include <cuda.h>
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__version(void) { return NULL; }
#endif
int64_t cuda_version() {
#ifdef WITH_CUDA
return CUDA_VERSION;
#else
return -1;
#endif
}
static auto registry = torch::RegisterOperators().op(
"torch_spline_conv::cuda_version", &cuda_version);
#include <Python.h>
#include <torch/script.h>
#include "cpu/weighting_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "cuda/weighting_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__weighting(void) { return NULL; }
#endif
torch::Tensor spline_weighting_fw(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
if (x.device().is_cuda()) {
#ifdef WITH_CUDA
return spline_weighting_fw_cuda(x, weight, basis, weight_index);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spline_weighting_fw_cpu(x, weight, basis, weight_index);
}
}
torch::Tensor spline_weighting_bw_x(torch::Tensor grad_out,
torch::Tensor weight, torch::Tensor basis,
torch::Tensor weight_index) {
if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
return spline_weighting_bw_x_cuda(grad_out, weight, basis, weight_index);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spline_weighting_bw_x_cpu(grad_out, weight, basis, weight_index);
}
}
torch::Tensor spline_weighting_bw_weight(torch::Tensor grad_out,
torch::Tensor x, torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size) {
if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
return spline_weighting_bw_weight_cuda(grad_out, x, basis, weight_index,
kernel_size);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spline_weighting_bw_weight_cpu(grad_out, x, basis, weight_index,
kernel_size);
}
}
torch::Tensor spline_weighting_bw_basis(torch::Tensor grad_out, torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index) {
if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
return spline_weighting_bw_basis_cuda(grad_out, x, weight, weight_index);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spline_weighting_bw_basis_cpu(grad_out, x, weight, weight_index);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SplineWeighting : public torch::autograd::Function<SplineWeighting> {
public:
static variable_list forward(AutogradContext *ctx, Variable x,
Variable weight, Variable basis,
Variable weight_index) {
auto out = spline_weighting_fw(x, weight, basis, weight_index);
ctx->save_for_backward({x, weight, basis, weight_index});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto x = saved[0], weight = saved[1], basis = saved[2],
weight_index = saved[3];
auto grad_x = Variable();
if (torch::autograd::any_variable_requires_grad({x})) {
grad_x = spline_weighting_bw_x(grad_out, weight, basis, weight_index);
}
auto grad_weight = Variable();
if (torch::autograd::any_variable_requires_grad({weight})) {
grad_weight = spline_weighting_bw_weight(grad_out, x, basis, weight_index,
weight.size(0));
}
auto grad_basis = Variable();
if (torch::autograd::any_variable_requires_grad({basis})) {
grad_basis = spline_weighting_bw_basis(grad_out, x, weight, weight_index);
}
return {grad_x, grad_weight, grad_basis, Variable()};
}
};
torch::Tensor spline_weighting(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
return SplineWeighting::apply(x, weight, basis, weight_index);
}
static auto registry = torch::RegisterOperators().op(
"torch_spline_conv::spline_weighting", &spline_weighting);
...@@ -20,7 +20,7 @@ except OSError as e: ...@@ -20,7 +20,7 @@ except OSError as e:
raise OSError(e) raise OSError(e)
if torch.version.cuda is not None: # pragma: no cover if torch.version.cuda is not None: # pragma: no cover
cuda_version = torch.ops.torch_scatter.cuda_version() cuda_version = torch.ops.torch_spline_conv.cuda_version()
if cuda_version == -1: if cuda_version == -1:
major = minor = 0 major = minor = 0
......
...@@ -9,29 +9,3 @@ def spline_basis(pseudo: torch.Tensor, kernel_size: torch.Tensor, ...@@ -9,29 +9,3 @@ def spline_basis(pseudo: torch.Tensor, kernel_size: torch.Tensor,
degree: int) -> Tuple[torch.Tensor, torch.Tensor]: degree: int) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_spline_conv.spline_basis(pseudo, kernel_size, return torch.ops.torch_spline_conv.spline_basis(pseudo, kernel_size,
is_open_spline, degree) is_open_spline, degree)
# class SplineBasis(torch.autograd.Function):
# @staticmethod
# def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
# ctx.save_for_backward(pseudo)
# ctx.kernel_size = kernel_size
# ctx.is_open_spline = is_open_spline
# ctx.degree = degree
# op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
# basis, weight_index = op(pseudo, kernel_size, is_open_spline)
# return basis, weight_index
# @staticmethod
# def backward(ctx, grad_basis, grad_weight_index):
# pseudo, = ctx.saved_tensors
# kernel_size, is_open_spline = ctx.kernel_size, ctx.is_open_spline
# degree = ctx.degree
# grad_pseudo = None
# if ctx.needs_input_grad[0]:
# grad_pseudo = op(grad_basis, pseudo, kernel_size, is_open_spline)
# return grad_pseudo, None, None, None
...@@ -7,33 +7,3 @@ def spline_weighting(x: torch.Tensor, weight: torch.Tensor, ...@@ -7,33 +7,3 @@ def spline_weighting(x: torch.Tensor, weight: torch.Tensor,
weight_index: torch.Tensor) -> torch.Tensor: weight_index: torch.Tensor) -> torch.Tensor:
return torch.ops.spline_conv.spline_weighting(x, weight, basis, return torch.ops.spline_conv.spline_weighting(x, weight, basis,
weight_index) weight_index)
# class SplineWeighting(torch.autograd.Function):
# @staticmethod
# def forward(ctx, x, weight, basis, weight_index):
# ctx.weight_index = weight_index
# ctx.save_for_backward(x, weight, basis)
# op = get_func('weighting_fw', x)
# out = op(x, weight, basis, weight_index)
# return out
# @staticmethod
# def backward(ctx, grad_out):
# x, weight, basis = ctx.saved_tensors
# grad_x = grad_weight = grad_basis = None
# if ctx.needs_input_grad[0]:
# op = get_func('weighting_bw_x', x)
# grad_x = op(grad_out, weight, basis, ctx.weight_index)
# if ctx.needs_input_grad[1]:
# op = get_func('weighting_bw_w', x)
# grad_weight = op(grad_out, x, basis, ctx.weight_index,
# weight.size(0))
# if ctx.needs_input_grad[2]:
# op = get_func('weighting_bw_b', x)
# grad_basis = op(grad_out, x, weight, ctx.weight_index)
# return grad_x, grad_weight, grad_basis, None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment