Commit 9362b2d3 authored by rusty1s

added cpu weighting backward

parent 6146660b
@@ -11,16 +11,18 @@ void THTensor_(weightingForward)(THTensor *self, THTensor *src, THTensor *weight
   int64_t *weightIndexData = THLongTensor_data(weightIndex);
   ptrdiff_t e, mOut, s, mIn;
-  real v, b;
+  real v, b, tmp;
   int64_t wi;
   for (e = 0; e < THTensor_(size)(src, 0); e++) {
-    for (mOut = 0; mOut < THTensor_(size)(weight, 2); mOut++) {
+    for (mOut = 0; mOut < THTensor_(size)(self, 1); mOut++) {
       v = 0;
       for (s = 0; s < THTensor_(size)(basis, 1); s++) {
         b = basisData[e * basis->stride[0] + s * basis->stride[1]];
         wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
-        for (mIn = 0; mIn < THTensor_(size)(weight, 1); mIn++) {
-          v += b * weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]] * srcData[e * src->stride[0] + mIn * src->stride[1]];
+        for (mIn = 0; mIn < THTensor_(size)(src, 1); mIn++) {
+          tmp = weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]];
+          tmp *= b * srcData[e * src->stride[0] + mIn * src->stride[1]];
+          v += tmp;
         }
       }
       selfData[e * self->stride[0] + mOut * self->stride[1]] = v;
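For reference, the rewritten forward loop computes, for every edge e and output channel mOut, output[e, mOut] = sum over s and mIn of basis[e, s] * weight[weight_index[e, s], mIn, mOut] * src[e, mIn]. The dense PyTorch sketch below expresses the same contraction and can serve as a small-input check of the kernel; the helper name weighting_forward_reference and the einsum formulation are illustrative assumptions, not part of the package:

```python
import torch

def weighting_forward_reference(src, weight, basis, weight_index):
    # src: [E, M_in], weight: [K, M_in, M_out], basis/weight_index: [E, S]
    w = weight[weight_index]  # gather weights per (edge, basis product): [E, S, M_in, M_out]
    return torch.einsum('es,esio,ei->eo', basis, w, src)
```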
@@ -30,14 +32,88 @@ void THTensor_(weightingForward)(THTensor *self, THTensor *src, THTensor *weight
void THTensor_(weightingBackwardSrc)(THTensor *self, THTensor *gradOutput, THTensor *weight,
                                     THTensor *basis, THLongTensor *weightIndex) {
  THTensor_(fill)(self, 0);

  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *weightData = THTensor_(data)(weight);
  real *basisData = THTensor_(data)(basis);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  ptrdiff_t e, mOut, s, mIn;
  real g, b, v;
  int64_t wi;
  for (e = 0; e < THTensor_(size)(self, 0); e++) {
    for (mOut = 0; mOut < THTensor_(size)(gradOutput, 1); mOut++) {
      g = gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
      for (s = 0; s < THTensor_(size)(basis, 1); s++) {
        b = basisData[e * basis->stride[0] + s * basis->stride[1]];
        wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (mIn = 0; mIn < THTensor_(size)(self, 1); mIn++) {
          v = weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]];
          selfData[e * self->stride[0] + mIn * self->stride[1]] += g * b * v;
        }
      }
    }
  }
}
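weightingBackwardSrc is the forward contraction transposed onto src: grad_src[e, mIn] = sum over mOut and s of grad_output[e, mOut] * basis[e, s] * weight[wi, mIn, mOut]. A dense sketch under the same assumptions as above (hypothetical helper name):

```python
import torch

def weighting_backward_src_reference(grad_output, weight, basis, weight_index):
    # grad_output: [E, M_out] -> grad_src: [E, M_in]
    w = weight[weight_index]  # [E, S, M_in, M_out]
    return torch.einsum('eo,es,esio->ei', grad_output, basis, w)
```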
void THTensor_(weightingBackwardWeight)(THTensor *self, THTensor *gradOutput, THTensor *src,
                                        THTensor *basis, THLongTensor *weightIndex) {
  THTensor_(fill)(self, 0);

  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *srcData = THTensor_(data)(src);
  real *basisData = THTensor_(data)(basis);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  ptrdiff_t e, mOut, s, mIn;
  real g, b, v;
  int64_t wi;
  for (e = 0; e < THTensor_(size)(src, 0); e++) {
    for (mOut = 0; mOut < THTensor_(size)(gradOutput, 1); mOut++) {
      g = gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
      for (s = 0; s < THTensor_(size)(basis, 1); s++) {
        b = basisData[e * basis->stride[0] + s * basis->stride[1]];
        wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (mIn = 0; mIn < THTensor_(size)(src, 1); mIn++) {
          v = b * g * srcData[e * src->stride[0] + mIn * src->stride[1]];
          selfData[wi * self->stride[0] + mIn * self->stride[1] + mOut * self->stride[2]] += v;
        }
      }
    }
  }
}
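The weight gradient is a scatter rather than a plain contraction: several (e, s) pairs can select the same kernel index wi, so their contributions b * src[e, mIn] * g[e, mOut] accumulate, which is what the += into selfData implements. A dense sketch using index_add_; note that the number of kernel weights K must be passed in, since it cannot be recovered from grad_output or the indices alone (this is also why the Python wrapper further down gains a K argument). The helper name is again an illustrative assumption:

```python
import torch

def weighting_backward_weight_reference(grad_output, src, basis, weight_index, K):
    # Accumulate b * src[e, mIn] * g[e, mOut] into grad_weight[wi, mIn, mOut];
    # index_add_ sums over repeated kernel indices, mirroring the += above.
    E, S = weight_index.size()
    M_in, M_out = src.size(1), grad_output.size(1)
    contrib = torch.einsum('es,ei,eo->esio', basis, src, grad_output)  # [E, S, M_in, M_out]
    grad_weight = contrib.new_zeros(K, M_in, M_out)
    grad_weight.index_add_(0, weight_index.reshape(-1), contrib.reshape(E * S, M_in, M_out))
    return grad_weight
```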
void THTensor_(weightingBackwardBasis)(THTensor *self, THTensor *gradOutput, THTensor *src,
                                       THTensor *weight, THLongTensor *weightIndex) {
  THTensor_(fill)(self, 0);

  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *srcData = THTensor_(data)(src);
  real *weightData = THTensor_(data)(weight);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  ptrdiff_t e, mOut, s, mIn;
  real g, v, tmp;
  int64_t wi;
  for (e = 0; e < THTensor_(size)(src, 0); e++) {
    for (mOut = 0; mOut < THTensor_(size)(gradOutput, 1); mOut++) {
      g = gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
      for (s = 0; s < THLongTensor_size(weightIndex, 1); s++) {
        v = 0;
        wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (mIn = 0; mIn < THTensor_(size)(src, 1); mIn++) {
          tmp = weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]];
          tmp *= srcData[e * src->stride[0] + mIn * src->stride[1]];
          v += tmp;
        }
        selfData[e * self->stride[0] + s * self->stride[1]] += g * v;
      }
    }
  }
}
#endif // TH_GENERIC_FILE
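weightingBackwardBasis completes the set: for each (e, s) pair, the gradient is the full contraction of the selected weight matrix with src and grad_output, i.e. grad_basis[e, s] = sum over mIn and mOut of grad_output[e, mOut] * weight[wi, mIn, mOut] * src[e, mIn]. Dense sketch, same caveats as above:

```python
import torch

def weighting_backward_basis_reference(grad_output, src, weight, weight_index):
    # grad_basis: [E, S]
    w = weight[weight_index]  # [E, S, M_in, M_out]
    return torch.einsum('eo,esio,ei->es', grad_output, w, src)
```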
@@ -2,7 +2,9 @@ from itertools import product
import pytest
import torch
-from torch_spline_conv.weighting import spline_weighting
+from torch.autograd import Variable, gradcheck
+from torch_spline_conv.weighting import spline_weighting, SplineWeighting
+from torch_spline_conv.basis import spline_basis
from .tensor import tensors
@@ -19,7 +21,7 @@ tests = [{
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
-def test_spline_basis_forward_cpu(tensor, i):
+def test_spline_weighting_forward_cpu(tensor, i):
    data = tests[i]
    src = getattr(torch, tensor)(data['src'])
@@ -33,7 +35,7 @@ def test_spline_basis_forward_cpu(tensor, i):
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
-def test_spline_basis_forward_gpu(tensor, i):
+def test_spline_weighting_forward_gpu(tensor, i):
    data = tests[i]
    src = getattr(torch.cuda, tensor)(data['src'])
@@ -43,3 +45,19 @@ def test_spline_basis_forward_gpu(tensor, i):
    output = spline_weighting(src, weight, basis, weight_index)
    assert output.cpu().tolist() == data['output']
+
+
+def test_spline_basis_backward_cpu():
+    src = torch.DoubleTensor(4, 2).uniform_(0, 1)
+    weight = torch.DoubleTensor(25, 2, 4).uniform_(0, 1)
+    kernel_size = torch.LongTensor([5, 5])
+    is_open_spline = torch.ByteTensor([1, 1])
+    pseudo = torch.DoubleTensor(4, 2).uniform_(0, 1)
+    basis, weight_index = spline_basis(1, pseudo, kernel_size, is_open_spline)
+
+    src = Variable(src, requires_grad=True)
+    weight = Variable(weight, requires_grad=True)
+    basis = Variable(basis, requires_grad=True)
+
+    op = SplineWeighting(weight_index)
+    assert gradcheck(op, (src, weight, basis), eps=1e-6, atol=1e-4) is True
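The new test deliberately builds DoubleTensor inputs: torch.autograd.gradcheck compares the analytic gradients from SplineWeighting.backward against finite-difference estimates with perturbation step eps, and that comparison is only trustworthy in float64 (atol is the allowed mismatch). weight_index itself carries no gradient, which is why it is handed to the Function's constructor rather than passed as a differentiable input.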
@@ -15,7 +15,6 @@ def weighting_forward(src, weight, basis, weight_index):
def weighting_backward_src(grad_output, weight, basis, weight_index):
    grad_src = grad_output.new(grad_output.size(0), weight.size(1))
-    weight = weight.transpose(1, 2).contiguous() # Coalesced memory access.
    weighting_bw_src(grad_src, grad_output, weight, basis, weight_index)
    return grad_src
@@ -49,8 +48,9 @@ class SplineWeighting(Function):
            grad_src = weighting_backward_src(grad_output, weight, basis,
                                              self.weight_index)
        if self.needs_input_grad[1]:
+            K = weight.size(0)
            grad_weight = weighting_backward_weight(grad_output, src, basis,
-                                                    self.weight_index)
+                                                    self.weight_index, K)
        if self.needs_input_grad[2]:
            grad_basis = weighting_backward_basis(grad_output, src, weight,
                                                  self.weight_index)
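The new K argument is needed because the leading dimension of grad_weight (the number of kernel weight matrices) cannot be inferred from grad_output or from the saved indices alone; the backward method reads it off the saved weight tensor so the backward-weight kernel can allocate a gradient buffer of the right shape.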