Commit d9100e71 authored by rusty1s's avatar rusty1s
Browse files

bugfixes

parent ffcc4df7
import torch
def node_degree(index, out=None):
one = torch.ones(index.size(1), out)
zero = torch.zeros(index.size(1), out)
return zero.scatter_add_(0, index[0], one)
def node_degree(edge_index, n, out=None):
zero = torch.zeros(n, out=out)
one = torch.ones(edge_index.size(1), out=zero.new())
return zero.scatter_add_(0, edge_index[0], one)
import torch
# from torch.autograd import Variable as Var
from .degree import node_degree
from .utils import spline_basis, spline_weighting
def spline_conv(x,
index,
edge_index,
pseudo,
weight,
kernel_size,
......@@ -15,27 +14,28 @@ def spline_conv(x,
degree=1,
bias=None):
n, e = x.size(0), edge_index.size(1)
K, m_in, m_out = weight.size()
x = x.unsqueeze(-1) if x.dim() == 1 else x
# Get features for every target node => |E| x M_in
output = x[index[1]]
output = x[edge_index[1]]
# Get B-spline basis products and weight indices for each edge.
basis, weight_index = spline_basis(degree, pseudo, kernel_size,
is_open_spline, weight.size(0))
is_open_spline, K)
# Weight gathered features based on B-spline basis and trainable weights.
output = spline_weighting(output, weight, basis, weight_index)
# Perform the real convolution => Convert |E| x M_out to N x M_out output.
row = index[0].unsqueeze(-1).expand(-1, output.size(1))
# zero = x if torch.is_tensor(x) else x.data
zero = x.new(row.size()).fill_(0)
# row, zero = row, zero if torch.is_tensor(x) else Var(row), Var(zero)
row = edge_index[0].unsqueeze(-1).expand(e, m_out)
zero = x.new(n, m_out).fill_(0)
output = zero.scatter_add_(0, row, output)
# Normalize output by node degree.
output /= node_degree(index, out=x.new()).unsqueeze(-1).clamp_(min=1)
output /= node_degree(edge_index, n, out=x.new()).clamp_(min=1)
# Weight root node separately (if wished).
if root_weight is not None:
......
......@@ -20,9 +20,8 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
weight_index = kernel_size.new(pseudo.size(0), s)
degree = implemented_degrees.get(degree)
if degree is None:
raise NotImplementedError('Basis computation not implemented for '
'specified B-spline degree')
assert degree is not None, (
'Basis computation not implemented for specified B-spline degree')
func = get_func('basis_{}'.format(degree), pseudo)
func(basis, weight_index, pseudo, kernel_size, is_open_spline, K)
......@@ -31,7 +30,7 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
def spline_weighting_fw(x, weight, basis, weight_index):
output = x.new(x.size(0), weight.size(2))
func = get_func('spline_weighting_fw', x)
func = get_func('weighting_fw', x)
func(output, x, weight, basis, weight_index)
return output
......@@ -39,7 +38,7 @@ def spline_weighting_fw(x, weight, basis, weight_index):
def spline_weighting_bw(grad_output, x, weight, basis, weight_index):
grad_input = x.new(x.size(0), weight.size(1))
grad_weight = x.new(weight)
func = get_func('spline_weighting_bw', x)
func = get_func('weighting_bw', x)
func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
return grad_input, grad_weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment