Commit 40f5b757 authored by rusty1s's avatar rusty1s
Browse files

works with variables

parent fbd0ffaf
import torch import torch
from torch.autograd import Variable as Var
from .degree import node_degree from .degree import node_degree
from .utils import spline_basis, spline_weighting from .utils import spline_basis, spline_weighting
...@@ -31,12 +32,14 @@ def spline_conv(x, ...@@ -31,12 +32,14 @@ def spline_conv(x,
# Perform the real convolution => Convert |E| x M_out to N x M_out output. # Perform the real convolution => Convert |E| x M_out to N x M_out output.
row = edge_index[0].unsqueeze(-1).expand(e, m_out) row = edge_index[0].unsqueeze(-1).expand(e, m_out)
zero = x.new(n, m_out).fill_(0) row = row if torch.is_tensor(x) else Var(row)
output = zero.scatter_add_(0, row, output) zero = x.new(n, m_out) if torch.is_tensor(x) else Var(x.data.new(n, m_out))
output = zero.fill_(0).scatter_add_(0, row, output)
# Normalize output by node degree. # Normalize output by node degree.
degree = node_degree(edge_index, n, out=x.new()) degree = x.new() if torch.is_tensor(x) else x.data.new()
output /= degree.unsqueeze(-1).clamp_(min=1) degree = node_degree(edge_index, n, out=degree).unsqueeze(-1).clamp_(min=1)
output /= degree if torch.is_tensor(x) else Var(degree)
# Weight root node separately (if wished). # Weight root node separately (if wished).
if root_weight is not None: if root_weight is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment