Commit d9100e71 authored by rusty1s's avatar rusty1s
Browse files

bugfixes

parent ffcc4df7
import torch import torch
def node_degree(index, out=None): def node_degree(edge_index, n, out=None):
one = torch.ones(index.size(1), out) zero = torch.zeros(n, out=out)
zero = torch.zeros(index.size(1), out) one = torch.ones(edge_index.size(1), out=zero.new())
return zero.scatter_add_(0, index[0], one) return zero.scatter_add_(0, edge_index[0], one)
import torch import torch
# from torch.autograd import Variable as Var
from .degree import node_degree from .degree import node_degree
from .utils import spline_basis, spline_weighting from .utils import spline_basis, spline_weighting
def spline_conv(x, def spline_conv(x,
index, edge_index,
pseudo, pseudo,
weight, weight,
kernel_size, kernel_size,
...@@ -15,27 +14,28 @@ def spline_conv(x, ...@@ -15,27 +14,28 @@ def spline_conv(x,
degree=1, degree=1,
bias=None): bias=None):
n, e = x.size(0), edge_index.size(1)
K, m_in, m_out = weight.size()
x = x.unsqueeze(-1) if x.dim() == 1 else x x = x.unsqueeze(-1) if x.dim() == 1 else x
# Get features for every target node => |E| x M_in # Get features for every target node => |E| x M_in
output = x[index[1]] output = x[edge_index[1]]
# Get B-spline basis products and weight indices for each edge. # Get B-spline basis products and weight indices for each edge.
basis, weight_index = spline_basis(degree, pseudo, kernel_size, basis, weight_index = spline_basis(degree, pseudo, kernel_size,
is_open_spline, weight.size(0)) is_open_spline, K)
# Weight gathered features based on B-spline basis and trainable weights. # Weight gathered features based on B-spline basis and trainable weights.
output = spline_weighting(output, weight, basis, weight_index) output = spline_weighting(output, weight, basis, weight_index)
# Perform the real convolution => Convert |E| x M_out to N x M_out output. # Perform the real convolution => Convert |E| x M_out to N x M_out output.
row = index[0].unsqueeze(-1).expand(-1, output.size(1)) row = edge_index[0].unsqueeze(-1).expand(e, m_out)
# zero = x if torch.is_tensor(x) else x.data zero = x.new(n, m_out).fill_(0)
zero = x.new(row.size()).fill_(0)
# row, zero = row, zero if torch.is_tensor(x) else Var(row), Var(zero)
output = zero.scatter_add_(0, row, output) output = zero.scatter_add_(0, row, output)
# Normalize output by node degree. # Normalize output by node degree.
output /= node_degree(index, out=x.new()).unsqueeze(-1).clamp_(min=1) output /= node_degree(edge_index, n, out=x.new()).clamp_(min=1)
# Weight root node separately (if wished). # Weight root node separately (if wished).
if root_weight is not None: if root_weight is not None:
......
...@@ -20,9 +20,8 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K): ...@@ -20,9 +20,8 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
weight_index = kernel_size.new(pseudo.size(0), s) weight_index = kernel_size.new(pseudo.size(0), s)
degree = implemented_degrees.get(degree) degree = implemented_degrees.get(degree)
if degree is None: assert degree is not None, (
raise NotImplementedError('Basis computation not implemented for ' 'Basis computation not implemented for specified B-spline degree')
'specified B-spline degree')
func = get_func('basis_{}'.format(degree), pseudo) func = get_func('basis_{}'.format(degree), pseudo)
func(basis, weight_index, pseudo, kernel_size, is_open_spline, K) func(basis, weight_index, pseudo, kernel_size, is_open_spline, K)
...@@ -31,7 +30,7 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K): ...@@ -31,7 +30,7 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
def spline_weighting_fw(x, weight, basis, weight_index): def spline_weighting_fw(x, weight, basis, weight_index):
output = x.new(x.size(0), weight.size(2)) output = x.new(x.size(0), weight.size(2))
func = get_func('spline_weighting_fw', x) func = get_func('weighting_fw', x)
func(output, x, weight, basis, weight_index) func(output, x, weight, basis, weight_index)
return output return output
...@@ -39,7 +38,7 @@ def spline_weighting_fw(x, weight, basis, weight_index): ...@@ -39,7 +38,7 @@ def spline_weighting_fw(x, weight, basis, weight_index):
def spline_weighting_bw(grad_output, x, weight, basis, weight_index): def spline_weighting_bw(grad_output, x, weight, basis, weight_index):
grad_input = x.new(x.size(0), weight.size(1)) grad_input = x.new(x.size(0), weight.size(1))
grad_weight = x.new(weight) grad_weight = x.new(weight)
func = get_func('spline_weighting_bw', x) func = get_func('weighting_bw', x)
func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index) func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
return grad_input, grad_weight return grad_input, grad_weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment