Commit 8418614e authored by rusty1s's avatar rusty1s
Browse files

prepare for backward to pseudo

parent 12a47ebc
......@@ -39,7 +39,7 @@ def test_spline_conv_cpu(tensor):
assert output.tolist() == expected_output
x, weight = Variable(x), Variable(weight)
x, weight, pseudo = Variable(x), Variable(weight), Variable(pseudo)
root_weight, bias = Variable(root_weight), Variable(bias)
output = spline_conv(x, edge_index, pseudo, weight, kernel_size,
......@@ -48,16 +48,16 @@ def test_spline_conv_cpu(tensor):
def test_spline_weighting_backward_cpu():
pseudo = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]]
pseudo = torch.DoubleTensor(pseudo)
kernel_size = torch.LongTensor([5, 5])
is_open_spline = torch.ByteTensor([1, 1])
basis, index = spline_basis(1, pseudo, kernel_size, is_open_spline, 25)
op = SplineWeighting(kernel_size, is_open_spline, 1)
op = SplineWeighting(basis, index)
x = torch.DoubleTensor([[1, 2], [3, 4], [5, 6], [7, 8]])
x = Variable(x, requires_grad=True)
pseudo = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]]
# pseudo = Variable(torch.DoubleTensor(pseudo), requires_grad=True)
pseudo = Variable(torch.DoubleTensor(pseudo))
weight = torch.DoubleTensor(25, 2, 4).uniform_(-1, 1)
weight = Variable(weight, requires_grad=True)
assert gradcheck(op, (x, weight), eps=1e-6, atol=1e-4) is True
assert gradcheck(op, (x, pseudo, weight), eps=1e-6, atol=1e-4) is True
......@@ -2,39 +2,7 @@ import torch
from torch.autograd import Variable as Var
from .degree import node_degree
from .utils import spline_basis, spline_weighting
def _spline_conv(x,
edge_index,
pseudo,
weight,
kernel_size,
is_open_spline,
degree=1):
n, e = x.size(0), edge_index.size(1)
K, m_in, m_out = weight.size()
x = x.unsqueeze(-1) if x.dim() == 1 else x
# Get features for every target node => |E| x M_in
output = x[edge_index[1]]
# Get B-spline basis products and weight indices for each edge.
basis, weight_index = spline_basis(degree, pseudo, kernel_size,
is_open_spline, K)
# Weight gathered features based on B-spline basis and trainable weights.
output = spline_weighting(output, weight, basis, weight_index)
# Perform the real convolution => Convert |E| x M_out to N x M_out output.
row = edge_index[0].unsqueeze(-1).expand(e, m_out)
row = row if torch.is_tensor(x) else Var(row)
zero = x.new(n, m_out) if torch.is_tensor(x) else Var(x.data.new(n, m_out))
output = zero.fill_(0).scatter_add_(0, row, output)
return output
from .utils import spline_weighting
def spline_conv(x,
......@@ -67,3 +35,23 @@ def spline_conv(x,
output += bias
return output
def _spline_conv(x, edge_index, pseudo, weight, kernel_size, is_open_spline,
degree):
n, e, m_out = x.size(0), edge_index.size(1), weight.size(2)
x = x.unsqueeze(-1) if x.dim() == 1 else x
# Weight gathered features based on B-spline bases and trainable weights.
output = spline_weighting(x[edge_index[1]], pseudo, weight, kernel_size,
is_open_spline, degree)
# Perform the real convolution => Convert e x m_out to n x m_out features.
row = edge_index[0].unsqueeze(-1).expand(e, m_out)
row = row if torch.is_tensor(x) else Var(row)
zero = x.new(n, m_out) if torch.is_tensor(x) else Var(x.data.new(n, m_out))
output = zero.fill_(0).scatter_add_(0, row, output)
return output
......@@ -37,34 +37,41 @@ def spline_weighting_forward(x, weight, basis, weight_index):
def spline_weighting_backward(grad_output, x, weight, basis,
weight_index): # pragma: no cover
grad_input = x.new(x.size(0), weight.size(1))
# grad_weight computation via `atomic_add` => Initialize with zeros.
grad_weight = x.new(weight.size()).fill_(0)
grad_input = x.new(x.size(0), weight.size(1))
func = get_func('weighting_backward', x)
func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
return grad_input, grad_weight
class SplineWeighting(Function):
def __init__(self, basis, weight_index):
def __init__(self, kernel_size, is_open_spline, degree):
super(SplineWeighting, self).__init__()
self.basis = basis
self.weight_index = weight_index
self.kernel_size = kernel_size
self.is_open_spline = is_open_spline
self.degree = degree
def forward(self, x, weight):
def forward(self, x, pseudo, weight):
self.save_for_backward(x, weight)
basis, weight_index = self.basis, self.weight_index
K = weight.size(0)
basis, weight_index = spline_basis(
self.degree, pseudo, self.kernel_size, self.is_open_spline, K)
self.basis, self.weight_index = basis, weight_index
return spline_weighting_forward(x, weight, basis, weight_index)
def backward(self, grad_output): # pragma: no cover
x, weight = self.saved_tensors
basis, weight_index = self.basis, self.weight_index
return spline_weighting_backward(grad_output, x, weight, basis,
weight_index)
grad_input, grad_weight = spline_weighting_backward(
grad_output, x, weight, self.basis, self.weight_index)
return grad_input, None, grad_weight
def spline_weighting(x, weight, basis, weight_index):
def spline_weighting(x, pseudo, weight, kernel_size, is_open_spline, degree):
if torch.is_tensor(x):
basis, weight_index = spline_basis(degree, pseudo, kernel_size,
is_open_spline, weight.size(0))
return spline_weighting_forward(x, weight, basis, weight_index)
else:
return SplineWeighting(basis, weight_index)(x, weight)
op = SplineWeighting(kernel_size, is_open_spline, degree)
return op(x, pseudo, weight)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment