"vscode:/vscode.git/clone" did not exist on "1ad65879a1af25941abc13a43566ebdf92073e6c"
Commit 69d73030 authored by rusty1s's avatar rusty1s
Browse files

basis backward boilerplate

parent d7a83c01
...@@ -17,6 +17,7 @@ f.close() ...@@ -17,6 +17,7 @@ f.close()
def test_spline_basis_cpu(tensor, i): def test_spline_basis_cpu(tensor, i):
degree = data[i].get('degree') degree = data[i].get('degree')
pseudo = Tensor(tensor, data[i]['pseudo']) pseudo = Tensor(tensor, data[i]['pseudo'])
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
kernel_size = torch.LongTensor(data[i]['kernel_size']) kernel_size = torch.LongTensor(data[i]['kernel_size'])
is_open_spline = torch.ByteTensor(data[i]['is_open_spline']) is_open_spline = torch.ByteTensor(data[i]['is_open_spline'])
K = kernel_size.prod() K = kernel_size.prod()
......
...@@ -3,6 +3,13 @@ from .._ext import ffi as ext ...@@ -3,6 +3,13 @@ from .._ext import ffi as ext
implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'} implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}
def get_degree_str(degree):
degree = implemented_degrees.get(degree)
assert degree is not None, (
'No implementation found for specified B-spline degree')
return degree
def get_func(name, tensor): def get_func(name, tensor):
typename = type(tensor).__name__.replace('Tensor', '') typename = type(tensor).__name__.replace('Tensor', '')
cuda = 'cuda_' if tensor.is_cuda else '' cuda = 'cuda_' if tensor.is_cuda else ''
...@@ -12,19 +19,22 @@ def get_func(name, tensor): ...@@ -12,19 +19,22 @@ def get_func(name, tensor):
def spline_basis_forward(degree, pseudo, kernel_size, is_open_spline, K): def spline_basis_forward(degree, pseudo, kernel_size, is_open_spline, K):
s = (degree + 1)**kernel_size.size(0) s = (degree + 1)**kernel_size.size(0)
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
basis = pseudo.new(pseudo.size(0), s) basis = pseudo.new(pseudo.size(0), s)
weight_index = kernel_size.new(pseudo.size(0), s) weight_index = kernel_size.new(pseudo.size(0), s)
func = get_func('{}_basis_forward'.format(get_degree_str(degree)), pseudo)
degree = implemented_degrees.get(degree)
assert degree is not None, (
'Basis computation not implemented for specified B-spline degree')
func = get_func('{}_basis_forward'.format(degree), pseudo)
func(basis, weight_index, pseudo, kernel_size, is_open_spline, K) func(basis, weight_index, pseudo, kernel_size, is_open_spline, K)
return basis, weight_index return basis, weight_index
# pragma: no cover
def spline_basis_backward(degree, grad_basis, pseudo, kernel_size,
is_open_spline):
grad_pseudo = pseudo.new(pseudo.size())
func = get_func('{}_basis_backward'.format(get_degree_str(degree)), pseudo)
func(grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline)
return grad_pseudo
def spline_weighting_forward(x, weight, basis, weight_index): def spline_weighting_forward(x, weight, basis, weight_index):
output = x.new(x.size(0), weight.size(2)) output = x.new(x.size(0), weight.size(2))
func = get_func('weighting_forward', x) func = get_func('weighting_forward', x)
......
...@@ -42,6 +42,7 @@ def basic_spline_conv(x, edge_index, pseudo, weight, kernel_size, ...@@ -42,6 +42,7 @@ def basic_spline_conv(x, edge_index, pseudo, weight, kernel_size,
n, e, m_out = x.size(0), edge_index.size(1), weight.size(2) n, e, m_out = x.size(0), edge_index.size(1), weight.size(2)
x = x.unsqueeze(-1) if x.dim() == 1 else x x = x.unsqueeze(-1) if x.dim() == 1 else x
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
# Weight gathered features based on B-spline bases and trainable weights. # Weight gathered features based on B-spline bases and trainable weights.
output = spline_weighting(x[edge_index[1]], pseudo, weight, kernel_size, output = spline_weighting(x[edge_index[1]], pseudo, weight, kernel_size,
......
import torch import torch
from torch.autograd import Function from torch.autograd import Function
from .ffi import (spline_basis_forward, spline_weighting_forward, from .ffi import (
spline_basis_forward,
spline_basis_backward,
spline_weighting_forward,
spline_weighting_backward_input, spline_weighting_backward_input,
spline_weighting_backward_basis, spline_weighting_backward_basis,
spline_weighting_backward_weight) spline_weighting_backward_weight,
)
class SplineWeighting(Function): class SplineWeighting(Function):
...@@ -20,13 +24,13 @@ class SplineWeighting(Function): ...@@ -20,13 +24,13 @@ class SplineWeighting(Function):
self.degree, pseudo, self.kernel_size, self.is_open_spline, K) self.degree, pseudo, self.kernel_size, self.is_open_spline, K)
output = spline_weighting_forward(x, weight, basis, weight_index) output = spline_weighting_forward(x, weight, basis, weight_index)
self.save_for_backward(x, weight) self.save_for_backward(x, pseudo, weight)
self.basis, self.weight_index = basis, weight_index self.basis, self.weight_index = basis, weight_index
return output return output
def backward(self, grad_output): # pragma: no cover def backward(self, grad_output): # pragma: no cover
x, weight = self.saved_tensors x, pseudo, weight = self.saved_tensors
basis, weight_index = self.basis, self.weight_index basis, weight_index = self.basis, self.weight_index
grad_input, grad_pseudo, grad_weight = None, None, None grad_input, grad_pseudo, grad_weight = None, None, None
...@@ -37,7 +41,9 @@ class SplineWeighting(Function): ...@@ -37,7 +41,9 @@ class SplineWeighting(Function):
if self.needs_input_grad[1]: if self.needs_input_grad[1]:
grad_basis = spline_weighting_backward_basis( grad_basis = spline_weighting_backward_basis(
grad_output, x, weight, weight_index) grad_output, x, weight, weight_index)
print('pseudo needs grad') grad_pseudo = spline_basis_backward(self.degree, grad_basis,
pseudo, self.kernel_size,
self.is_open_spline)
if self.needs_input_grad[2]: if self.needs_input_grad[2]:
K = weight.size(0) K = weight.size(0)
......
void spline_linear_basis_forward_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_linear_basis_forward_Float( THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_linear_basis_forward_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_linear_basis_forward_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_quadratic_basis_forward_Float( THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_quadratic_basis_forward_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_quadratic_basis_forward_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_quadratic_basis_forward_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_cubic_basis_forward_Float( THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_cubic_basis_forward_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_cubic_basis_forward_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_cubic_basis_forward_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_weighting_forward_Float(THFloatTensor *output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index); void spline_linear_basis_backward_Float( THFloatTensor *grad_pseudo, THLongTensor *grad_basis, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_linear_basis_backward_Double(THDoubleTensor *grad_pseudo, THLongTensor *grad_basis, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_quadratic_basis_backward_Float( THFloatTensor *grad_pseudo, THLongTensor *grad_basis, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_quadratic_basis_backward_Double(THDoubleTensor *grad_pseudo, THLongTensor *grad_basis, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_cubic_basis_backward_Float( THFloatTensor *grad_pseudo, THLongTensor *grad_basis, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_cubic_basis_backward_Double(THDoubleTensor *grad_pseudo, THLongTensor *grad_basis, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline);
void spline_weighting_forward_Float( THFloatTensor *output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index); void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_input_Float(THFloatTensor *grad_input, THFloatTensor *grad_output, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index); void spline_weighting_backward_input_Float( THFloatTensor *grad_input, THFloatTensor *grad_output, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_input_Double(THDoubleTensor *grad_input, THDoubleTensor *grad_output, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index); void spline_weighting_backward_input_Double(THDoubleTensor *grad_input, THDoubleTensor *grad_output, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_basis_Float(THFloatTensor *grad_basis, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *weight, THLongTensor *weight_index); void spline_weighting_backward_basis_Float( THFloatTensor *grad_basis, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *weight, THLongTensor *weight_index);
void spline_weighting_backward_basis_Double(THDoubleTensor *grad_basis, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *weight, THLongTensor *weight_index); void spline_weighting_backward_basis_Double(THDoubleTensor *grad_basis, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *weight, THLongTensor *weight_index);
void spline_weighting_backward_weight_Float(THFloatTensor *grad_weight, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *basis, THLongTensor *weight_index); void spline_weighting_backward_weight_Float( THFloatTensor *grad_weight, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_weight_Double(THDoubleTensor *grad_weight, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *basis, THLongTensor *weight_index); void spline_weighting_backward_weight_Double(THDoubleTensor *grad_weight, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *basis, THLongTensor *weight_index);
...@@ -25,6 +25,15 @@ void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, T ...@@ -25,6 +25,15 @@ void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, T
) )
} }
void spline_(linear_basis_backward)(THTensor *grad_pseudo, THLongTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
}
void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THLongTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
}
void spline_(cubic_basis_backward)(THTensor *grad_pseudo, THLongTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
}
void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset; real *weight_data = weight->storage->data + weight->storageOffset;
int64_t M_out = THTensor_(size)(output, 1); int64_t M_out = THTensor_(size)(output, 1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment