Commit 5e006b95 authored by rusty1s

added bias test

parent f8ca386a
@@ -17,9 +17,10 @@ def test_spline_conv_cpu(tensor):
     kernel_size = torch.LongTensor([3, 4])
     is_open_spline = torch.ByteTensor([1, 0])
     root_weight = torch.arange(12.5, 13.5, step=0.5, out=x.new()).view(2, 1)
+    bias = Tensor(tensor, [1])
     output = spline_conv(x, edge_index, pseudo, weight, kernel_size,
-                         is_open_spline, root_weight)
+                         is_open_spline, root_weight, 1, bias)
     edgewise_output = [
         1 * 0.25 * (0.5 + 1.5 + 4.5 + 5.5) + 2 * 0.25 * (1 + 2 + 5 + 6),
@@ -29,21 +30,20 @@ def test_spline_conv_cpu(tensor):
     ]
     expected_output = [
-        [12.5 * 9 + 13 * 10 + sum(edgewise_output) / 4],
-        [12.5 * 1 + 13 * 2],
-        [12.5 * 3 + 13 * 4],
-        [12.5 * 5 + 13 * 6],
-        [12.5 * 7 + 13 * 8],
+        [1 + 12.5 * 9 + 13 * 10 + sum(edgewise_output) / 4],
+        [1 + 12.5 * 1 + 13 * 2],
+        [1 + 12.5 * 3 + 13 * 4],
+        [1 + 12.5 * 5 + 13 * 6],
+        [1 + 12.5 * 7 + 13 * 8],
     ]
     assert output.tolist() == expected_output

-    x = Variable(x, requires_grad=True)
-    weight = Variable(weight, requires_grad=True)
-    root_weight = Variable(root_weight, requires_grad=True)
+    x, weight = Variable(x), Variable(weight)
+    root_weight, bias = Variable(root_weight), Variable(bias)
     output = spline_conv(x, edge_index, pseudo, weight, kernel_size,
-                         is_open_spline, root_weight)
+                         is_open_spline, root_weight, 1, bias)
     assert output.data.tolist() == expected_output
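Note (not part of the commit): the bias tensor above holds a single value of 1 for the lone output channel, so each row of the new expected_output is simply the old, unbiased row shifted by that constant. A minimal plain-Python sketch of that relationship, reusing the toy numbers from the test:

    # Sketch only: biased expectation = unbiased expectation + the scalar bias of 1.
    unbiased_rows = [
        12.5 * 1 + 13 * 2,
        12.5 * 3 + 13 * 4,
        12.5 * 5 + 13 * 6,
        12.5 * 7 + 13 * 8,
    ]
    bias_value = 1
    biased_rows = [row + bias_value for row in unbiased_rows]
    assert biased_rows == [1 + 12.5 * 1 + 13 * 2,
                           1 + 12.5 * 3 + 13 * 4,
                           1 + 12.5 * 5 + 13 * 6,
                           1 + 12.5 * 7 + 13 * 8]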
@@ -15,9 +15,6 @@ def spline_conv(x,
                 degree=1,
                 bias=None):
-    # TODO: degree of 0
-    # TODO: kernel size of 1
     n, e = x.size(0), edge_index.size(1)
     K, m_in, m_out = weight.size()
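Aside (an assumption, not shown in this hunk): given the new trailing bias argument and the test expectations above, the bias appears to be added element-wise to every node's aggregated output. A hedged, stand-alone sketch of that broadcast add, with hypothetical sizes n and m_out:

    import torch

    # Assumed behaviour only: a bias of shape [m_out] is broadcast-added to an
    # aggregated output of shape [n, m_out], shifting every node row.
    n, m_out = 5, 1
    output = torch.zeros(n, m_out)   # stand-in for the aggregated convolution output
    bias = torch.ones(m_out)         # bias of 1, as in the test above

    if bias is not None:             # mirrors the new `bias=None` default
        output = output + bias       # every row gains the per-channel bias

    assert output.tolist() == [[1.0]] * n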
@@ -35,7 +35,8 @@ def spline_weighting_forward(x, weight, basis, weight_index):
     return output


-def spline_weighting_backward(grad_output, x, weight, basis, weight_index):
+def spline_weighting_backward(grad_output, x, weight, basis,
+                              weight_index):  # pragma: no cover
     grad_input = x.new(x.size(0), weight.size(1))
     # grad_weight computation via `atomic_add` => Initialize with zeros.
     grad_weight = x.new(weight.size()).fill_(0)
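The zero-initialization noted in the comment above matters because several edges can scatter their gradient contributions into the same weight row; the adds must accumulate on top of zeros rather than uninitialized memory. A small stand-alone analogy using index_add_ in place of the kernel's atomic_add (an illustration, not the library's implementation):

    import torch

    K, m_in = 4, 2
    grad_weight = torch.zeros(K, m_in)             # zero-initialized buffer, as in the diff
    weight_index = torch.LongTensor([0, 0, 2, 3])  # two contributions target row 0
    contrib = torch.ones(4, m_in)

    grad_weight.index_add_(0, weight_index, contrib)  # accumulate like atomic_add
    assert grad_weight[0].tolist() == [2.0, 2.0]      # both row-0 contributions summed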
@@ -55,7 +56,7 @@ class SplineWeighting(Function):
         basis, weight_index = self.basis, self.weight_index
         return spline_weighting_forward(x, weight, basis, weight_index)

-    def backward(self, grad_output):
+    def backward(self, grad_output):  # pragma: no cover
         x, weight = self.saved_tensors
         basis, weight_index = self.basis, self.weight_index
         return spline_weighting_backward(grad_output, x, weight, basis,