Commit 6c6be201 authored by rusty1s's avatar rusty1s
Browse files

final tests

parent f7d3df7b
void spline_linear_basis_forward_cuda_Float ( THCudaTensor *basis, THCudaLongTensor *weight_index, THCudaTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
void spline_linear_basis_forward_cuda_Double(THCudaDoubleTensor *basis, THCudaLongTensor *weight_index, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
void spline_quadratic_basis_forward_cuda_Float ( THCudaTensor *basis, THCudaLongTensor *weight_index, THCudaTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
void spline_quadratic_basis_forward_cuda_Double(THCudaDoubleTensor *basis, THCudaLongTensor *weight_index, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
void spline_cubic_basis_forward_cuda_Float ( THCudaTensor *basis, THCudaLongTensor *weight_index, THCudaTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
void spline_cubic_basis_forward_cuda_Double(THCudaDoubleTensor *basis, THCudaLongTensor *weight_index, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K);
void spline_weighting_forward_cuda_Float ( THCudaTensor *output, THCudaTensor *input, THCudaTensor *weight, THCudaTensor *basis, THCudaLongTensor *weight_index);
void spline_weighting_forward_cuda_Double(THCudaDoubleTensor *output, THCudaDoubleTensor *input, THCudaDoubleTensor *weight, THCudaDoubleTensor *basis, THCudaLongTensor *weight_index);
void spline_weighting_backward_input_cuda_Float ( THCudaTensor *grad_input, THCudaTensor *grad_output, THCudaTensor *weight, THCudaTensor *basis, THCudaLongTensor *weight_index);
void spline_weighting_backward_input_cuda_Double(THCudaDoubleTensor *grad_input, THCudaDoubleTensor *grad_output, THCudaDoubleTensor *weight, THCudaDoubleTensor *basis, THCudaLongTensor *weight_index);
void spline_weighting_backward_weight_cuda_Float ( THCudaTensor *grad_weight, THCudaTensor *grad_output, THCudaTensor *input, THCudaTensor *basis, THCudaLongTensor *weight_index);
void spline_weighting_backward_weight_cuda_Double(THCudaDoubleTensor *grad_weight, THCudaDoubleTensor *grad_output, THCudaDoubleTensor *input, THCudaDoubleTensor *basis, THCudaLongTensor *weight_index);
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/cpu.c"
#else
void spline_(linear_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS_FORWARD(1, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
value = 1 - value - k_mod + 2 * value * k_mod;
)
}
void spline_(quadratic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS_FORWARD(2, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
if (k_mod == 0) value = 0.5 * value * value - value + 0.5;
else if (k_mod == 1) value = -value * value + value + 0.5;
else value = 0.5 * value * value;
)
}
void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS_FORWARD(3, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
if (k_mod == 0) { value = (1 - value); value = value * value * value / 6.0; }
else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6;
else if (k_mod == 2) value = (-3 * value * value * value + 3 * value * value + 3 * value + 1) / 6;
else value = value * value * value / 6;
)
}
void spline_(linear_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
SPLINE_BASIS_BACKWARD(1, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline,
value = 1 - value - k_mod + 2 * value * k_mod;
,
value = -1 + k_mod + k_mod;
)
}
void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
SPLINE_BASIS_BACKWARD(2, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline,
if (k_mod == 0) value = 0.5 * value * value - value + 0.5;
else if (k_mod == 1) value = -value * value + value + 0.5;
else value = 0.5 * value * value;
,
if (k_mod == 0) value = value - 1;
else if (k_mod == 1) value = -2 * value + 1;
else value = value;
)
}
void spline_(cubic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
SPLINE_BASIS_BACKWARD(3, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline,
if (k_mod == 0) { value = (1 - value); value = value * value * value / 6.0; }
else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6;
else if (k_mod == 2) value = (-3 * value * value * value + 3 * value * value + 3 * value + 1) / 6;
else value = value * value * value / 6;
,
if (k_mod == 0) value = (-value * value + 2 * value - 1) / 2;
else if (k_mod == 1) value = (3 * value * value - 4 * value) / 2;
else if (k_mod == 2) value = (-3 * value * value + 2 * value + 1) / 2;
else value = value * value / 2;
)
}
void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset; real b;
SPLINE_WEIGHTING(output, input, basis, weight_index, THTensor_(size)(weight, 1), THTensor_(size)(weight, 2), THLongTensor_size(weight_index, 1),
for (m_out = 0; m_out < M_out; m_out++) {
value = 0;
for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride);
w_idx = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) {
value += b * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out) * *(input_data + m_in * input_stride);
}
}
output_data[m_out * output_stride] = value;
}
)
}
void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_output, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset; real b;
SPLINE_WEIGHTING(grad_input, grad_output, basis, weight_index, THTensor_(size)(weight, 2), THTensor_(size)(weight, 1), THLongTensor_size(weight_index, 1),
for (m_in = 0; m_in < M_in; m_in++) {
value = 0;
for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride);
w_idx = *(weight_index_data + s * weight_index_stride);
for (m_out = 0; m_out < M_out; m_out++) {
value += b * *(grad_output_data + m_out * grad_output_stride) * *(weight_data + w_idx * M_in * M_out + m_out * M_in + m_in);
}
}
grad_input_data[m_in * grad_input_stride] = value;
}
)
}
void spline_(weighting_backward_basis)(THTensor *grad_basis, THTensor *grad_output, THTensor *input, THTensor *weight, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset;
SPLINE_WEIGHTING(grad_basis, grad_output, input, weight_index, THTensor_(size)(weight, 1), THTensor_(size)(weight, 2), THLongTensor_size(weight_index, 1),
for (m_out = 0; m_out < M_out; m_out++) {
for (s = 0; s < S; s++) {
w_idx = *(weight_index_data + s * weight_index_stride); value = 0;
for (m_in = 0; m_in < M_in; m_in++) {
value += *(input_data + m_in * input_stride) * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out);
}
grad_basis_data[s * grad_basis_stride] += value * *(grad_output_data + m_out * grad_output_stride);
}
}
)
}
void spline_(weighting_backward_weight)(THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *basis, THLongTensor *weight_index) {
real *grad_weight_data = grad_weight->storage->data + grad_weight->storageOffset; real b;
SPLINE_WEIGHTING(grad_output, input, basis, weight_index, THTensor_(size)(input, 1), THTensor_(size)(grad_output, 1), THLongTensor_size(weight_index, 1),
for (m_out = 0; m_out < M_out; m_out++) {
value = *(grad_output_data + m_out * grad_output_stride);
for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride);
w_idx = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) {
grad_weight_data[w_idx * M_in * M_out + m_in * M_out + m_out] += b * value * *(input_data + m_in * input_stride);
}
}
}
)
}
#endif
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/cuda.c"
#else
void spline_(linear_basis_forward)(THCTensor *basis, THCudaLongTensor *weight_index, THCTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K) {
spline_kernel_(linear_basis_forward)(state, basis, weight_index, pseudo, kernel_size, is_open_spline, K);
}
void spline_(quadratic_basis_forward)(THCTensor *basis, THCudaLongTensor *weight_index, THCTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K) {
spline_kernel_(quadratic_basis_forward)(state, basis, weight_index, pseudo, kernel_size, is_open_spline, K);
}
void spline_(cubic_basis_forward)(THCTensor *basis, THCudaLongTensor *weight_index, THCTensor *pseudo, THCudaLongTensor *kernel_size, THCudaByteTensor *is_open_spline, int K) {
spline_kernel_(cubic_basis_forward)(state, basis, weight_index, pseudo, kernel_size, is_open_spline, K);
}
void spline_(weighting_forward)(THCTensor *output, THCTensor *input, THCTensor *weight, THCTensor *basis, THCudaLongTensor *weight_index) {
spline_kernel_(weighting_forward)(state, output, input, weight, basis, weight_index);
}
void spline_(weighting_backward_input)(THCTensor *grad_input, THCTensor *grad_output, THCTensor *weight, THCTensor *basis, THCudaLongTensor *weight_index) {
spline_kernel_(weighting_backward_input)(state, grad_input, grad_output, weight, basis, weight_index);
}
void spline_(weighting_backward_weight)(THCTensor *grad_weight, THCTensor *grad_output, THCTensor *input, THCTensor *basis, THCudaLongTensor *weight_index) {
spline_kernel_(weighting_backward_weight)(state, grad_weight, grad_output, input, basis, weight_index);
}
#endif
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from .new import new from .new import new
def node_degree(index, num_nodes, out=None): def node_degree(index, n, out=None):
zero = torch.zeros(num_nodes, out=out) zero = torch.zeros(n) if out is None else out.resize_(n).fill_(0)
one = torch.ones(index, out=new(zero)) one = new(zero, index.size(0)).fill_(1)
return zero.scatter_add_(0, index, one) return zero.scatter_add_(0, index, one)
...@@ -2,5 +2,5 @@ import torch ...@@ -2,5 +2,5 @@ import torch
from torch.autograd import Variable from torch.autograd import Variable
def new(x, *sizes): def new(x, *size):
return x.new(sizes) if torch.is_tensor(x) else Variable(x.data.new(sizes)) return x.new(*size) if torch.is_tensor(x) else Variable(x.data.new(*size))
...@@ -32,34 +32,31 @@ def weighting_backward_basis(grad_output, src, weight, weight_index): ...@@ -32,34 +32,31 @@ def weighting_backward_basis(grad_output, src, weight, weight_index):
class SplineWeighting(Function): class SplineWeighting(Function):
def __init__(self, weight_index): def forward(self, src, weight, basis, weight_index):
super(SplineWeighting, self).__init__() self.save_for_backward(src, weight, basis, weight_index)
self.weight_index = weight_index return weighting_forward(src, weight, basis, weight_index)
def forward(self, src, weight, basis):
self.save_for_backward(src, weight, basis)
return weighting_forward(src, weight, basis, self.weight_index)
def backward(self, grad_output): def backward(self, grad_output):
grad_src = grad_weight = grad_basis = None grad_src = grad_weight = grad_basis = None
src, weight, basis = self.saved_tensors src, weight, basis, weight_index = self.saved_tensors
if self.needs_input_grad[0]: if self.needs_input_grad[0]:
grad_src = weighting_backward_src(grad_output, weight, basis, grad_src = weighting_backward_src(grad_output, weight, basis,
self.weight_index) weight_index)
if self.needs_input_grad[1]: if self.needs_input_grad[1]:
K = weight.size(0) K = weight.size(0)
grad_weight = weighting_backward_weight(grad_output, src, basis, grad_weight = weighting_backward_weight(grad_output, src, basis,
self.weight_index, K) weight_index, K)
if self.needs_input_grad[2]: if self.needs_input_grad[2]:
grad_basis = weighting_backward_basis(grad_output, src, weight, grad_basis = weighting_backward_basis(grad_output, src, weight,
self.weight_index) weight_index)
return grad_src, grad_weight, grad_basis return grad_src, grad_weight, grad_basis, None
def spline_weighting(src, weight, basis, weight_index): def spline_weighting(src, weight, basis, weight_index):
if torch.is_tensor(src): if torch.is_tensor(src):
return weighting_forward(src, weight, basis, weight_index) return weighting_forward(src, weight, basis, weight_index)
else: else:
return SplineWeighting(weight_index)(src, weight, basis) return SplineWeighting()(src, weight, basis, weight_index)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment