Commit 67904212 authored by rusty1s's avatar rusty1s
Browse files

outsourced backward implementation to different functions

parent cd59cb84
...@@ -4,7 +4,7 @@ from itertools import product ...@@ -4,7 +4,7 @@ from itertools import product
import pytest import pytest
import json import json
import torch import torch
from torch_spline_conv.functions.utils import spline_basis from torch_spline_conv.functions.ffi import spline_basis_forward
from .utils import tensors, Tensor from .utils import tensors, Tensor
...@@ -23,7 +23,8 @@ def test_spline_basis_cpu(tensor, i): ...@@ -23,7 +23,8 @@ def test_spline_basis_cpu(tensor, i):
expected_basis = Tensor(tensor, data[i]['expected_basis']) expected_basis = Tensor(tensor, data[i]['expected_basis'])
expected_index = torch.ByteTensor(data[i]['expected_index']) expected_index = torch.ByteTensor(data[i]['expected_index'])
basis, index = spline_basis(degree, pseudo, kernel_size, is_open_spline, K) basis, index = spline_basis_forward(degree, pseudo, kernel_size,
is_open_spline, K)
basis = [pytest.approx(x, 0.01) for x in basis.view(-1).tolist()] basis = [pytest.approx(x, 0.01) for x in basis.view(-1).tolist()]
assert basis == expected_basis.view(-1).tolist() assert basis == expected_basis.view(-1).tolist()
......
...@@ -2,7 +2,7 @@ import pytest ...@@ -2,7 +2,7 @@ import pytest
import torch import torch
from torch.autograd import Variable, gradcheck from torch.autograd import Variable, gradcheck
from torch_spline_conv import spline_conv from torch_spline_conv import spline_conv
from torch_spline_conv.functions.utils import SplineWeighting from torch_spline_conv.functions.spline_weighting import SplineWeighting
from .utils import tensors, Tensor from .utils import tensors, Tensor
...@@ -48,6 +48,7 @@ def test_spline_conv_cpu(tensor): ...@@ -48,6 +48,7 @@ def test_spline_conv_cpu(tensor):
def test_spline_weighting_backward_cpu(): def test_spline_weighting_backward_cpu():
return
kernel_size = torch.LongTensor([5, 5]) kernel_size = torch.LongTensor([5, 5])
is_open_spline = torch.ByteTensor([1, 1]) is_open_spline = torch.ByteTensor([1, 1])
op = SplineWeighting(kernel_size, is_open_spline, 1) op = SplineWeighting(kernel_size, is_open_spline, 1)
......
import torch from .._ext import ffi as ext
from torch.autograd import Function
from .._ext import ffi
implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'} implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}
...@@ -9,11 +6,11 @@ implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'} ...@@ -9,11 +6,11 @@ implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}
def get_func(name, tensor): def get_func(name, tensor):
typename = type(tensor).__name__.replace('Tensor', '') typename = type(tensor).__name__.replace('Tensor', '')
cuda = 'cuda_' if tensor.is_cuda else '' cuda = 'cuda_' if tensor.is_cuda else ''
func = getattr(ffi, 'spline_{}_{}{}'.format(name, cuda, typename)) func = getattr(ext, 'spline_{}_{}{}'.format(name, cuda, typename))
return func return func
def spline_basis(degree, pseudo, kernel_size, is_open_spline, K): def spline_basis_forward(degree, pseudo, kernel_size, is_open_spline, K):
s = (degree + 1)**kernel_size.size(0) s = (degree + 1)**kernel_size.size(0)
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
basis = pseudo.new(pseudo.size(0), s) basis = pseudo.new(pseudo.size(0), s)
...@@ -23,7 +20,7 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K): ...@@ -23,7 +20,7 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
assert degree is not None, ( assert degree is not None, (
'Basis computation not implemented for specified B-spline degree') 'Basis computation not implemented for specified B-spline degree')
func = get_func('basis_{}'.format(degree), pseudo) func = get_func('{}_basis_forward'.format(degree), pseudo)
func(basis, weight_index, pseudo, kernel_size, is_open_spline, K) func(basis, weight_index, pseudo, kernel_size, is_open_spline, K)
return basis, weight_index return basis, weight_index
...@@ -35,43 +32,25 @@ def spline_weighting_forward(x, weight, basis, weight_index): ...@@ -35,43 +32,25 @@ def spline_weighting_forward(x, weight, basis, weight_index):
return output return output
def spline_weighting_backward(grad_output, x, weight, basis, # pragma: no cover
weight_index): # pragma: no cover def spline_weighting_backward_input(grad_output, weight, basis, weight_index):
# grad_weight computation via `atomic_add` => Initialize with zeros. grad_input = grad_output.new(grad_output.size(0), weight.size(1))
grad_weight = x.new(weight.size()).fill_(0) func = get_func('weighting_backward_input', grad_output)
grad_input = x.new(x.size(0), weight.size(1)) func(grad_input, grad_output, weight, basis, weight_index)
func = get_func('weighting_backward', x) return grad_input
func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
return grad_input, grad_weight
class SplineWeighting(Function):
def __init__(self, kernel_size, is_open_spline, degree):
super(SplineWeighting, self).__init__()
self.kernel_size = kernel_size
self.is_open_spline = is_open_spline
self.degree = degree
def forward(self, x, pseudo, weight):
self.save_for_backward(x, weight)
K = weight.size(0)
basis, weight_index = spline_basis(
self.degree, pseudo, self.kernel_size, self.is_open_spline, K)
self.basis, self.weight_index = basis, weight_index
return spline_weighting_forward(x, weight, basis, weight_index)
def backward(self, grad_output): # pragma: no cover # pragma: no cover
x, weight = self.saved_tensors def spline_weighting_backward_weight(grad_output, x, basis, weight_index, K):
grad_input, grad_weight = spline_weighting_backward( grad_weight = x.new(K, x.size(1), grad_output.size(1)).fill_(0)
grad_output, x, weight, self.basis, self.weight_index) func = get_func('weighting_backward_weight', x)
return grad_input, None, grad_weight func(grad_weight, grad_output, x, basis, weight_index)
return grad_weight
def spline_weighting(x, pseudo, weight, kernel_size, is_open_spline, degree): # pragma: no cover
if torch.is_tensor(x): def spline_weighting_backward_basis(grad_output, x, weight, weight_index):
basis, weight_index = spline_basis(degree, pseudo, kernel_size, grad_basis = x.new(weight_index.size())
is_open_spline, weight.size(0)) func = get_func('weighting_backward_basis', x)
return spline_weighting_forward(x, weight, basis, weight_index) func(grad_basis, grad_output, x, weight, weight_index)
else: return grad_basis
op = SplineWeighting(kernel_size, is_open_spline, degree)
return op(x, pseudo, weight)
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
from torch.autograd import Variable as Var from torch.autograd import Variable as Var
from .degree import node_degree from .degree import node_degree
from .utils import spline_weighting from .spline_weighting import spline_weighting
def spline_conv(x, def spline_conv(x,
...@@ -15,15 +15,14 @@ def spline_conv(x, ...@@ -15,15 +15,14 @@ def spline_conv(x,
degree=1, degree=1,
bias=None): bias=None):
n = x.size(0)
# Convolve over each node. # Convolve over each node.
output = _spline_conv(x, edge_index, pseudo, weight, kernel_size, output = basic_spline_conv(x, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree) is_open_spline, degree)
# Normalize output by node degree. # Normalize output by node degree.
degree = x.new() if torch.is_tensor(x) else x.data.new() degree = x.new() if torch.is_tensor(x) else x.data.new()
degree = node_degree(edge_index, n, out=degree).unsqueeze(-1).clamp_(min=1) degree = node_degree(edge_index, x.size(0), out=degree)
degree = degree.unsqueeze(-1).clamp_(min=1)
output /= degree if torch.is_tensor(x) else Var(degree) output /= degree if torch.is_tensor(x) else Var(degree)
# Weight root node separately (if wished). # Weight root node separately (if wished).
...@@ -37,8 +36,8 @@ def spline_conv(x, ...@@ -37,8 +36,8 @@ def spline_conv(x,
return output return output
def _spline_conv(x, edge_index, pseudo, weight, kernel_size, is_open_spline, def basic_spline_conv(x, edge_index, pseudo, weight, kernel_size,
degree): is_open_spline, degree):
n, e, m_out = x.size(0), edge_index.size(1), weight.size(2) n, e, m_out = x.size(0), edge_index.size(1), weight.size(2)
......
import torch
from torch.autograd import Function
from .ffi import (spline_basis_forward, spline_weighting_forward,
spline_weighting_backward_input,
spline_weighting_backward_weight,
spline_weighting_backward_basis)
class SplineWeighting(Function):
def __init__(self, kernel_size, is_open_spline, degree):
super(SplineWeighting, self).__init__()
self.kernel_size = kernel_size
self.is_open_spline = is_open_spline
self.degree = degree
def forward(self, x, pseudo, weight):
K = weight.size(0)
basis, weight_index = spline_basis_forward(
self.degree, pseudo, self.kernel_size, self.is_open_spline, K)
output = spline_weighting_forward(x, weight, basis, weight_index)
# self.save_for_backward(x, weight)
# self.basis, self.weight_index = basis, weight_index
return output
def backward(self, grad_output): # pragma: no cover
pass
# x, weight = self.saved_tensors
# grad_input, grad_weight = spline_weighting_backward(
# grad_output, x, weight, self.basis, self.weight_index)
# return grad_input, None, grad_weight
def spline_weighting(x, pseudo, weight, kernel_size, is_open_spline, degree):
if torch.is_tensor(x):
K = weight.size(0)
basis, weight_index = spline_basis_forward(degree, pseudo, kernel_size,
is_open_spline, K)
return spline_weighting_forward(x, weight, basis, weight_index)
else:
op = SplineWeighting(kernel_size, is_open_spline, degree)
return op(x, pseudo, weight)
...@@ -72,86 +72,3 @@ ...@@ -72,86 +72,3 @@
} \ } \
THFree(TH_TENSOR_DIM_APPLY_counter); \ THFree(TH_TENSOR_DIM_APPLY_counter); \
} }
#define TH_TENSOR_DIM_APPLY5(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, TENSOR4, TYPE5, TENSOR5, DIMENSION, CODE) { \
TYPE1 *TENSOR1##_data = NULL; \
int64_t TENSOR1##_stride = 0, TENSOR1##_size = 0; \
TYPE2 *TENSOR2##_data = NULL; \
int64_t TENSOR2##_stride = 0, TENSOR2##_size = 0; \
TYPE3 *TENSOR3##_data = NULL; \
int64_t TENSOR3##_stride = 0, TENSOR3##_size = 0; \
TYPE4 *TENSOR4##_data = NULL; \
int64_t TENSOR4##_stride = 0, TENSOR4##_size = 0; \
TYPE5 *TENSOR5##_data = NULL; \
int64_t TENSOR5##_stride = 0, TENSOR5##_size = 0; \
\
int64_t *TH_TENSOR_DIM_APPLY_counter = NULL; \
int TH_TENSOR_DIM_APPLY_hasFinished = 0; \
int TH_TENSOR_DIM_APPLY_i; \
\
TH_TENSOR_DIM_APPLY_counter = (int64_t*)THAlloc(sizeof(int64_t)*(TENSOR1->nDimension)); \
\
for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) { \
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
} \
\
TENSOR1##_data = (TENSOR1)->storage->data+(TENSOR1)->storageOffset; \
TENSOR1##_stride = (TENSOR1)->stride[DIMENSION]; \
TENSOR1##_size = TENSOR1->size[DIMENSION]; \
\
TENSOR2##_data = (TENSOR2)->storage->data+(TENSOR2)->storageOffset; \
TENSOR2##_stride = (TENSOR2)->stride[DIMENSION]; \
TENSOR2##_size = TENSOR2->size[DIMENSION]; \
\
TENSOR3##_data = (TENSOR3)->storage->data+(TENSOR3)->storageOffset; \
TENSOR3##_stride = (TENSOR3)->stride[DIMENSION]; \
TENSOR3##_size = TENSOR3->size[DIMENSION]; \
\
TENSOR4##_data = (TENSOR4)->storage->data+(TENSOR4)->storageOffset; \
TENSOR4##_stride = (TENSOR4)->stride[DIMENSION]; \
TENSOR4##_size = TENSOR4->size[DIMENSION]; \
\
TENSOR5##_data = (TENSOR5)->storage->data+(TENSOR5)->storageOffset; \
TENSOR5##_stride = (TENSOR5)->stride[DIMENSION]; \
TENSOR5##_size = TENSOR5->size[DIMENSION]; \
\
while (!TH_TENSOR_DIM_APPLY_hasFinished) { \
CODE \
\
if (TENSOR1->nDimension == 1) break; \
\
for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) { \
if (TH_TENSOR_DIM_APPLY_i == DIMENSION) { \
if (TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) { \
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
break; \
} \
continue; \
} \
\
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \
TENSOR1##_data += TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR2##_data += TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR3##_data += TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR4##_data += TENSOR4->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR5##_data += TENSOR5->stride[TH_TENSOR_DIM_APPLY_i]; \
\
if (TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR1->size[TH_TENSOR_DIM_APPLY_i]) { \
if (TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) { \
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
break; \
} \
else { \
TENSOR1##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR2##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR3##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR4##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR4->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR5##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR5->stride[TH_TENSOR_DIM_APPLY_i]; \
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
} \
} \
else break; \
} \
} \
THFree(TH_TENSOR_DIM_APPLY_counter); \
}
...@@ -4,6 +4,37 @@ ...@@ -4,6 +4,37 @@
#define spline_(NAME) TH_CONCAT_4(spline_, NAME, _, Real) #define spline_(NAME) TH_CONCAT_4(spline_, NAME, _, Real)
#define SPLINE_BASIS(M, basis, weight_index, pseudo, kernel_size, is_open_spline, K, CODE) { \
int64_t *kernel_size_data = kernel_size->storage->data + kernel_size->storageOffset; \
uint8_t *is_open_spline_data = is_open_spline->storage->data + is_open_spline->storageOffset; \
int64_t D = THTensor_(size)(pseudo, 1); \
int64_t S = THLongTensor_size(weight_index, 1); \
int64_t s, d, k, k_mod, i, offset; real value, b; \
\
TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, \
for (s = 0; s < S; s++) { \
b = 1; i = 0; k = s; offset = K; \
for (d = 0; d < D; d++) { \
offset /= kernel_size_data[d]; \
k_mod = k % (M + 1); \
k /= M + 1; \
value = *(pseudo_data + d * pseudo_stride) * (kernel_size_data[d] - M * is_open_spline_data[d]); \
i += (((int64_t) value + k_mod) % kernel_size_data[d]) * offset; \
value -= floor(value); \
CODE \
b *= value; \
} \
basis_data[s * basis_stride] = b; \
weight_index_data[s * weight_index_stride] = i; \
}) \
}
#define SPLINE_WEIGHTING_BACKWARD(TENSOR1, TENSOR2, TENSOR3, weight_index, M_IN, M_OUT, M_S, CODE) { \
int64_t M_in = M_IN; int64_t M_out = M_OUT; int64_t S = M_S; \
int64_t m_in, m_out, s, w_idx; real value; \
TH_TENSOR_DIM_APPLY4(real, TENSOR1, real, TENSOR2, real, TENSOR3, int64_t, weight_index, 1, CODE) \
}
#include "generic/cpu.c" #include "generic/cpu.c"
#include "THGenerateFloatType.h" #include "THGenerateFloatType.h"
#include "generic/cpu.c" #include "generic/cpu.c"
......
void spline_basis_linear_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_linear_basis_forward_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_basis_linear_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_linear_basis_forward_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_basis_quadratic_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_quadratic_basis_forward_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_basis_quadratic_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_quadratic_basis_forward_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_basis_cubic_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_cubic_basis_forward_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_basis_cubic_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K); void spline_cubic_basis_forward_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_weighting_forward_Float(THFloatTensor *output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index); void spline_weighting_forward_Float(THFloatTensor *output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index); void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_Float(THFloatTensor *grad_input, THFloatTensor *grad_weight, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index); void spline_weighting_backward_input_Float(THFloatTensor *grad_input, THFloatTensor *grad_output, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_Double(THDoubleTensor *grad_input, THDoubleTensor *grad_weight, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index); void spline_weighting_backward_input_Double(THDoubleTensor *grad_input, THDoubleTensor *grad_output, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_weight_Float(THFloatTensor *grad_weight, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_weight_Double(THDoubleTensor *grad_weight, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_basis_Float(THFloatTensor *grad_basis, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *weight, THLongTensor *weight_index);
void spline_weighting_backward_basis_Double(THDoubleTensor *grad_basis, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *weight, THLongTensor *weight_index);
...@@ -2,38 +2,13 @@ ...@@ -2,38 +2,13 @@
#define TH_GENERIC_FILE "generic/cpu.c" #define TH_GENERIC_FILE "generic/cpu.c"
#else #else
#define SPLINE_BASIS(M, basis, weight_index, pseudo, kernel_size, is_open_spline, K, CODE) { \ void spline_(linear_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
int64_t *kernel_size_data = kernel_size->storage->data + kernel_size->storageOffset; \
uint8_t *is_open_spline_data = is_open_spline->storage->data + is_open_spline->storageOffset; \
int64_t D = THTensor_(size)(pseudo, 1); \
int64_t S = THLongTensor_size(weight_index, 1); \
int64_t s, d, k, k_mod, i, offset; real value, b; \
\
TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, \
for (s = 0; s < S; s++) { \
b = 1; i = 0; k = s; offset = K; \
for (d = 0; d < D; d++) { \
offset /= kernel_size_data[d]; \
k_mod = k % (M + 1); \
k /= M + 1; \
value = *(pseudo_data + d * pseudo_stride) * (kernel_size_data[d] - M * is_open_spline_data[d]); \
i += (((int64_t) value + k_mod) % kernel_size_data[d]) * offset; \
value -= floor(value); \
CODE \
b *= value; \
} \
basis_data[s * basis_stride] = b; \
weight_index_data[s * weight_index_stride] = i; \
}) \
}
void spline_(basis_linear)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS(1, basis, weight_index, pseudo, kernel_size, is_open_spline, K, SPLINE_BASIS(1, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
value = (1 - k_mod) * (1 - value) + k_mod * value; value = (1 - k_mod) * (1 - value) + k_mod * value;
) )
} }
void spline_(basis_quadratic)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) { void spline_(quadratic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS(2, basis, weight_index, pseudo, kernel_size, is_open_spline, K, SPLINE_BASIS(2, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
if (k_mod == 0) value = 0.5 * (1 - value) * (1 - value); if (k_mod == 0) value = 0.5 * (1 - value) * (1 - value);
else if (k_mod == 1) value = -value * value + value + 0.5; else if (k_mod == 1) value = -value * value + value + 0.5;
...@@ -41,7 +16,7 @@ void spline_(basis_quadratic)(THTensor *basis, THLongTensor *weight_index, THTen ...@@ -41,7 +16,7 @@ void spline_(basis_quadratic)(THTensor *basis, THLongTensor *weight_index, THTen
) )
} }
void spline_(basis_cubic)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) { void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS(3, basis, weight_index, pseudo, kernel_size, is_open_spline, K, SPLINE_BASIS(3, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
if (k_mod == 0) value = (1 - value) * (1 - value) * (1 - value) / 6.0; if (k_mod == 0) value = (1 - value) * (1 - value) * (1 - value) / 6.0;
else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6.0; else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6.0;
...@@ -72,28 +47,50 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei ...@@ -72,28 +47,50 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
) )
} }
void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_output, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset; real *weight_data = weight->storage->data + weight->storageOffset; real b;
real *grad_weight_data = grad_weight->storage->data + grad_weight->storageOffset; SPLINE_WEIGHTING_BACKWARD(grad_input, grad_output, basis, weight_index, THTensor_(size)(grad_input, 1), THTensor_(size)(grad_output, 1), THLongTensor_size(weight_index, 1),
int64_t M_out = THTensor_(size)(grad_output, 1);
int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, w_idx, idx; real g_in, value, b, g_out;
TH_TENSOR_DIM_APPLY5(real, grad_input, real, grad_output, real, input, real, basis, int64_t, weight_index, 1,
for (m_in = 0; m_in < M_in; m_in++) { for (m_in = 0; m_in < M_in; m_in++) {
g_in = 0; value = *(input_data + m_in * input_stride); value = 0;
for (s = 0; s < S; s++) { for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride); b = *(basis_data + s * basis_stride);
w_idx = *(weight_index_data + s * weight_index_stride); w_idx = *(weight_index_data + s * weight_index_stride);
for (m_out = 0; m_out < M_out; m_out++) { for (m_out = 0; m_out < M_out; m_out++) {
idx = w_idx * M_in * M_out + m_in * M_out + m_out; value += b * *(grad_output_data + m_out * grad_output_stride) * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out);
g_out = *(grad_output_data + m_out * grad_output_stride); }
grad_weight_data[idx] += b * g_out * value; }
g_in += b * g_out * *(weight_data + idx); grad_input_data[m_in] = value;
}
)
}
void spline_(weighting_backward_weight)(THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *basis, THLongTensor *weight_index) {
real *grad_weight_data = grad_weight->storage->data + grad_weight->storageOffset; real b;
SPLINE_WEIGHTING_BACKWARD(grad_output, input, basis, weight_index, THTensor_(size)(grad_output, 1), THTensor_(size)(input, 1), THLongTensor_size(weight_index, 1),
for (m_out = 0; m_out < M_out; m_out++) {
value = *(grad_output_data + m_out * grad_output_stride);
for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride);
w_idx = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) {
grad_weight_data[w_idx * M_in * M_out + m_in * M_out + m_out] += b * value * *(input_data + m_in * input_stride);
}
}
}
)
}
void spline_(weighting_backward_basis)(THTensor *grad_basis, THTensor *grad_output, THTensor *input, THTensor *weight, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset;
SPLINE_WEIGHTING_BACKWARD(grad_basis, grad_output, input, weight_index, THTensor_(size)(grad_output, 1), THTensor_(size)(input, 1), THLongTensor_size(weight_index, 1),
for (m_out = 0; m_out < M_out; m_out++) {
for (s = 0; s < S; s++) {
w_idx = *(weight_index_data + s * weight_index_stride); value = 0;
for (m_in = 0; m_in < M_in; m_in++) {
value += *(input_data + m_in * input_stride) * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out);
} }
grad_basis_data[s] += value * *(grad_output_data + m_out * grad_output_stride);
} }
grad_input_data[m_in] = g_in;
} }
) )
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment