Commit 8f006b2d authored by rusty1s's avatar rusty1s
Browse files

outsource spline conv

parent 5e006b95
......@@ -5,15 +5,13 @@ from .degree import node_degree
from .utils import spline_basis, spline_weighting
def spline_conv(x,
def _spline_conv(x,
edge_index,
pseudo,
weight,
kernel_size,
is_open_spline,
root_weight=None,
degree=1,
bias=None):
degree=1):
n, e = x.size(0), edge_index.size(1)
K, m_in, m_out = weight.size()
......@@ -34,7 +32,22 @@ def spline_conv(x,
row = edge_index[0].unsqueeze(-1).expand(e, m_out)
row = row if torch.is_tensor(x) else Var(row)
zero = x.new(n, m_out) if torch.is_tensor(x) else Var(x.data.new(n, m_out))
output = zero.fill_(0).scatter_add_(0, row, output)
return zero.fill_(0).scatter_add_(0, row, output)
def spline_conv(x,
edge_index,
pseudo,
weight,
kernel_size,
is_open_spline,
root_weight=None,
degree=1,
bias=None):
n = x.size(0)
output = _spline_conv(x, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree)
# Normalize output by node degree.
degree = x.new() if torch.is_tensor(x) else x.data.new()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment