Commit 42542bff authored by rusty1s's avatar rusty1s
Browse files

test bw function

parent 40f5b757
import torch
from torch.autograd import Variable, gradcheck
from torch_spline_conv import spline_conv
from torch_spline_conv.functions.utils import SplineWeighting, spline_basis
x = torch.Tensor([[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]])
index = torch.LongTensor([[0, 0, 0, 0], [1, 2, 3, 4]])
pseudo = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]]
pseudo = torch.Tensor(pseudo)
weight = torch.arange(0.5, 0.5 * 25, step=0.5).view(12, 2, 1)
# print(weight[:, 0].squeeze())
kernel_size = torch.LongTensor([3, 4])
is_open_spline = torch.ByteTensor([1, 0])
root_weight = torch.arange(12.5, 13.5, step=0.5).view(2, 1)
......@@ -28,3 +29,23 @@ expected_output = [
[12.5 * 5 + 13 * 6],
[12.5 * 7 + 13 * 8],
]
print(output.tolist(), expected_output)
x = Variable(x, requires_grad=True)
weight = Variable(weight, requires_grad=True)
root_weight = Variable(root_weight, requires_grad=True)
output = spline_conv(x, index, pseudo, weight, kernel_size, is_open_spline,
root_weight)
print(output.data.tolist())
x, pseudo, weight = x.data.double(), pseudo.double(), weight.data.double()
x = x[index[1]]
x = Variable(x, requires_grad=True)
weight = Variable(weight, requires_grad=True)
basis, weight_index = spline_basis(1, pseudo, kernel_size, is_open_spline,
weight.size(0))
op = SplineWeighting(basis, weight_index)
test = gradcheck(op, (x, weight), eps=1e-6, atol=1e-4)
print(test)
......@@ -15,6 +15,9 @@ def spline_conv(x,
degree=1,
bias=None):
print('TODO: Degree of 0')
print('TODO: Kernel size of 1')
n, e = x.size(0), edge_index.size(1)
K, m_in, m_out = weight.size()
......
......@@ -37,7 +37,7 @@ def spline_weighting_forward(x, weight, basis, weight_index):
def spline_weighting_backward(grad_output, x, weight, basis, weight_index):
grad_input = x.new(x.size(0), weight.size(1)).fill_(0)
grad_weight = x.new(weight).fill_(0)
grad_weight = x.new(weight.size()).fill_(0)
func = get_func('weighting_backward', x)
func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
return grad_input, grad_weight
......
......@@ -75,7 +75,7 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset;
real *grad_weight_data = grad_weight->storage->data + grad_weight->storageOffset;
int64_t M_out = THTensor_(size)(grad_input, 1);
int64_t M_out = THTensor_(size)(grad_output, 1);
int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, i, w_idx; real g, b;
......@@ -90,7 +90,6 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH
w_idx = i * M_in * M_out + m_in * M_out + m_out;
grad_input_data[m_in] += b * g * *(weight_data + w_idx);
grad_weight_data[w_idx] += b * g * *(input_data + m_in * input_stride);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment