"examples/pytorch/hilander/data/.gitkeep" did not exist on "bfbaaeafe5fb3e0868c9453bde3feaa2dc78f1fb"
Commit 42542bff authored by rusty1s's avatar rusty1s
Browse files

test bw function

parent 40f5b757
import torch import torch
from torch.autograd import Variable, gradcheck
from torch_spline_conv import spline_conv from torch_spline_conv import spline_conv
from torch_spline_conv.functions.utils import SplineWeighting, spline_basis
x = torch.Tensor([[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]]) x = torch.Tensor([[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]])
index = torch.LongTensor([[0, 0, 0, 0], [1, 2, 3, 4]]) index = torch.LongTensor([[0, 0, 0, 0], [1, 2, 3, 4]])
pseudo = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]] pseudo = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]]
pseudo = torch.Tensor(pseudo) pseudo = torch.Tensor(pseudo)
weight = torch.arange(0.5, 0.5 * 25, step=0.5).view(12, 2, 1) weight = torch.arange(0.5, 0.5 * 25, step=0.5).view(12, 2, 1)
# print(weight[:, 0].squeeze())
kernel_size = torch.LongTensor([3, 4]) kernel_size = torch.LongTensor([3, 4])
is_open_spline = torch.ByteTensor([1, 0]) is_open_spline = torch.ByteTensor([1, 0])
root_weight = torch.arange(12.5, 13.5, step=0.5).view(2, 1) root_weight = torch.arange(12.5, 13.5, step=0.5).view(2, 1)
...@@ -28,3 +29,23 @@ expected_output = [ ...@@ -28,3 +29,23 @@ expected_output = [
[12.5 * 5 + 13 * 6], [12.5 * 5 + 13 * 6],
[12.5 * 7 + 13 * 8], [12.5 * 7 + 13 * 8],
] ]
print(output.tolist(), expected_output)
x = Variable(x, requires_grad=True)
weight = Variable(weight, requires_grad=True)
root_weight = Variable(root_weight, requires_grad=True)
output = spline_conv(x, index, pseudo, weight, kernel_size, is_open_spline,
root_weight)
print(output.data.tolist())
x, pseudo, weight = x.data.double(), pseudo.double(), weight.data.double()
x = x[index[1]]
x = Variable(x, requires_grad=True)
weight = Variable(weight, requires_grad=True)
basis, weight_index = spline_basis(1, pseudo, kernel_size, is_open_spline,
weight.size(0))
op = SplineWeighting(basis, weight_index)
test = gradcheck(op, (x, weight), eps=1e-6, atol=1e-4)
print(test)
...@@ -15,6 +15,9 @@ def spline_conv(x, ...@@ -15,6 +15,9 @@ def spline_conv(x,
degree=1, degree=1,
bias=None): bias=None):
print('TODO: Degree of 0')
print('TODO: Kernel size of 1')
n, e = x.size(0), edge_index.size(1) n, e = x.size(0), edge_index.size(1)
K, m_in, m_out = weight.size() K, m_in, m_out = weight.size()
......
...@@ -37,7 +37,7 @@ def spline_weighting_forward(x, weight, basis, weight_index): ...@@ -37,7 +37,7 @@ def spline_weighting_forward(x, weight, basis, weight_index):
def spline_weighting_backward(grad_output, x, weight, basis, weight_index): def spline_weighting_backward(grad_output, x, weight, basis, weight_index):
grad_input = x.new(x.size(0), weight.size(1)).fill_(0) grad_input = x.new(x.size(0), weight.size(1)).fill_(0)
grad_weight = x.new(weight).fill_(0) grad_weight = x.new(weight.size()).fill_(0)
func = get_func('weighting_backward', x) func = get_func('weighting_backward', x)
func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index) func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
return grad_input, grad_weight return grad_input, grad_weight
......
...@@ -75,7 +75,7 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei ...@@ -75,7 +75,7 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset; real *weight_data = weight->storage->data + weight->storageOffset;
real *grad_weight_data = grad_weight->storage->data + grad_weight->storageOffset; real *grad_weight_data = grad_weight->storage->data + grad_weight->storageOffset;
int64_t M_out = THTensor_(size)(grad_input, 1); int64_t M_out = THTensor_(size)(grad_output, 1);
int64_t M_in = THTensor_(size)(input, 1); int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1); int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, i, w_idx; real g, b; int64_t m_out, m_in, s, i, w_idx; real g, b;
...@@ -90,7 +90,6 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH ...@@ -90,7 +90,6 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH
w_idx = i * M_in * M_out + m_in * M_out + m_out; w_idx = i * M_in * M_out + m_in * M_out + m_out;
grad_input_data[m_in] += b * g * *(weight_data + w_idx); grad_input_data[m_in] += b * g * *(weight_data + w_idx);
grad_weight_data[w_idx] += b * g * *(input_data + m_in * input_stride); grad_weight_data[w_idx] += b * g * *(input_data + m_in * input_stride);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment