Commit d9d943e8 authored by rusty1s's avatar rusty1s
Browse files

better degree impl

parent ad330355
...@@ -5,7 +5,7 @@ from .basis import spline_basis ...@@ -5,7 +5,7 @@ from .basis import spline_basis
from .weighting import spline_weighting from .weighting import spline_weighting
from .utils.new import new from .utils.new import new
from .utils.degree import node_degree from .utils.degree import degree as node_degree
def spline_conv(src, def spline_conv(src,
...@@ -49,7 +49,7 @@ def spline_conv(src, ...@@ -49,7 +49,7 @@ def spline_conv(src,
row, col = edge_index row, col = edge_index
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
n, e, m_out = src.size(0), row.size(0), weight.size(2) n, m_out = src.size(0), weight.size(2)
# Weight each node. # Weight each node.
basis, weight_index = spline_basis(degree, pseudo, kernel_size, basis, weight_index = spline_basis(degree, pseudo, kernel_size,
...@@ -58,14 +58,12 @@ def spline_conv(src, ...@@ -58,14 +58,12 @@ def spline_conv(src,
# Perform the real convolution => Convert e x m_out to n x m_out features. # Perform the real convolution => Convert e x m_out to n x m_out features.
zero = new(src, n, m_out).fill_(0) zero = new(src, n, m_out).fill_(0)
row_expand = row.unsqueeze(-1).expand(e, m_out) row_expand = row.unsqueeze(-1).expand_as(output)
row_expand = row_expand if torch.is_tensor(src) else Variable(row_expand) row_expand = row_expand if torch.is_tensor(src) else Variable(row_expand)
output = zero.scatter_add_(0, row_expand, output) output = zero.scatter_add_(0, row_expand, output)
# Normalize output by node degree. # Normalize output by node degree.
index = row if torch.is_tensor(src) else Variable(row) output /= node_degree(row, n, out=new(src)).unsqueeze(-1).clamp(min=1)
degree = node_degree(index, n, out=new(src))
output /= degree.unsqueeze(-1).clamp(min=1)
# Weight root node separately (if wished). # Weight root node separately (if wished).
if root_weight is not None: if root_weight is not None:
......
import torch import torch
from torch.autograd import Variable
from .new import new from .new import new
def node_degree(index, n, out=None): def degree(index, num_nodes=None, out=None):
if out is None: # pragma: no cover num_nodes = index.max() + 1 if num_nodes is None else num_nodes
zero = torch.zeros(n) out = index.new().float() if out is None else out
index = index if torch.is_tensor(out) else Variable(index)
if torch.is_tensor(out):
out.resize_(num_nodes)
else: else:
out.resize_(n) if torch.is_tensor(out) else out.data.resize_(n) out.data.resize_(num_nodes)
zero = out.fill_(0)
one = new(zero, index.size(0)).fill_(1) one = new(out, index.size(0)).fill_(1)
return zero.scatter_add_(0, index, one) return out.fill_(0).scatter_add_(0, index, one)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment