"vscode:/vscode.git/clone" did not exist on "fc1e5a973cb7c86f63a2a50e2785529afcac0eba"
Commit 40f5b757 authored by rusty1s's avatar rusty1s
Browse files

works with variables

parent fbd0ffaf
import torch import torch
from torch.autograd import Variable as Var
from .degree import node_degree from .degree import node_degree
from .utils import spline_basis, spline_weighting from .utils import spline_basis, spline_weighting
...@@ -31,12 +32,14 @@ def spline_conv(x, ...@@ -31,12 +32,14 @@ def spline_conv(x,
# Perform the real convolution => Convert |E| x M_out to N x M_out output. # Perform the real convolution => Convert |E| x M_out to N x M_out output.
row = edge_index[0].unsqueeze(-1).expand(e, m_out) row = edge_index[0].unsqueeze(-1).expand(e, m_out)
zero = x.new(n, m_out).fill_(0) row = row if torch.is_tensor(x) else Var(row)
output = zero.scatter_add_(0, row, output) zero = x.new(n, m_out) if torch.is_tensor(x) else Var(x.data.new(n, m_out))
output = zero.fill_(0).scatter_add_(0, row, output)
# Normalize output by node degree. # Normalize output by node degree.
degree = node_degree(edge_index, n, out=x.new()) degree = x.new() if torch.is_tensor(x) else x.data.new()
output /= degree.unsqueeze(-1).clamp_(min=1) degree = node_degree(edge_index, n, out=degree).unsqueeze(-1).clamp_(min=1)
output /= degree if torch.is_tensor(x) else Var(degree)
# Weight root node separately (if wished). # Weight root node separately (if wished).
if root_weight is not None: if root_weight is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment