"tests/python/vscode:/vscode.git/clone" did not exist on "51ff82552150812a4135de88a16302e15baa1f89"
Commit c4b33b49 authored by rusty1s's avatar rusty1s
Browse files

added conv impl

parent 0440e1f4
# import torch import torch
from .basis import spline_basis
from .weighting import spline_weighting
def spline_conv(x, from .utils.new import new
from .utils.degree import node_degree
def spline_conv(src,
edge_index, edge_index,
pseudo, pseudo,
weight, weight,
...@@ -10,4 +16,33 @@ def spline_conv(x, ...@@ -10,4 +16,33 @@ def spline_conv(x,
degree=1, degree=1,
root_weight=None, root_weight=None,
bias=None): bias=None):
src = src.unsqueeze(-1) if src.dim() == 1 else src
row, col = edge_index
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
n, e, m_out = src.size(0), row.size(0), weight.size(2)
# Weight each node.
basis, weight_index = spline_basis(degree, pseudo, kernel_size,
is_open_spline)
output = spline_weighting(src[col], weight, basis, weight_index)
# Perform the real convolution => Convert e x m_out to n x m_out features.
zero = new(src, n, m_out).fill_(0)
row_expand = row.unsqueeze(-1).expand(e, m_out)
output = zero.scatter_add_(0, row_expand, output)
# Normalize output by node degree.
degree = node_degree(row, n, out=new(src))
output /= degree.unsqueeze(-1).clamp_(min=1)
# Weight root node separately (if wished).
if root_weight is not None:
output += torch.mm(src, root_weight)
# Add bias (if wished).
if bias is not None:
output += bias
return output
import torch
from .new import new
def node_degree(index, num_nodes, out=None):
zero = torch.zeros(num_nodes, out=out)
one = torch.ones(index, out=new(zero))
return zero.scatter_add_(0, index, one)
import torch
from torch.autograd import Variable
def new(x, *sizes):
return x.new(sizes) if torch.is_tensor(x) else Variable(x.data.new(sizes))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment