Commit 08a517fe authored by rusty1s's avatar rusty1s
Browse files

better degree impl

parent af754235
......@@ -63,7 +63,7 @@ class SplineConv(object):
output = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, output)
# Normalize output by node degree.
deg = node_degree(row, n, out=src.new_empty(()))
deg = node_degree(row, n, output.dtype, output.device)
output /= deg.unsqueeze(-1).clamp(min=1)
# Weight root node separately (if wished).
......
import torch
def degree(index, num_nodes=None, out=None):
def degree(index, num_nodes=None, dtype=None, device=None):
num_nodes = index.max() + 1 if num_nodes is None else num_nodes
out = index.new_empty((), dtype=torch.float) if out is None else out
out.resize_(num_nodes).fill_(0)
out = torch.zeros((num_nodes), dtype=dtype, device=device)
return out.scatter_add_(0, index, out.new_ones((index.size(0))))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment