Commit 235d586b authored by rusty1s's avatar rusty1s
Browse files

better degree impl

parent efaa27c1
...@@ -4,27 +4,25 @@ from torch_cluster.functions.utils.degree import node_degree ...@@ -4,27 +4,25 @@ from torch_cluster.functions.utils.degree import node_degree
def test_node_degree_cpu(): def test_node_degree_cpu():
row = torch.LongTensor([0, 1, 1, 0, 0, 3, 0]) target = torch.LongTensor([0, 1, 1, 0, 0, 3, 0])
degree = node_degree(target, 4)
expected_degree = [4, 2, 0, 1] expected_degree = [4, 2, 0, 1]
degree = node_degree(row, 4)
assert degree.type() == torch.LongTensor().type() assert degree.type() == torch.LongTensor().type()
assert degree.tolist() == expected_degree assert degree.tolist() == expected_degree
degree = node_degree(row, 4, out=torch.FloatTensor()) degree = node_degree(target, 4, out=torch.FloatTensor())
assert degree.type() == torch.FloatTensor().type() assert degree.type() == torch.FloatTensor().type()
assert degree.tolist() == expected_degree assert degree.tolist() == expected_degree
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_node_degree_gpu(): # pragma: no cover def test_node_degree_gpu(): # pragma: no cover
row = torch.cuda.LongTensor([0, 1, 1, 0, 0, 3, 0]) target = torch.cuda.LongTensor([0, 1, 1, 0, 0, 3, 0])
degree = node_degree(target, 4)
expected_degree = [4, 2, 0, 1] expected_degree = [4, 2, 0, 1]
degree = node_degree(row, 4)
assert degree.type() == torch.cuda.LongTensor().type() assert degree.type() == torch.cuda.LongTensor().type()
assert degree.cpu().tolist() == expected_degree assert degree.cpu().tolist() == expected_degree
degree = node_degree(row, 4, out=torch.cuda.FloatTensor()) degree = node_degree(target, 4, out=torch.cuda.FloatTensor())
assert degree.type() == torch.cuda.FloatTensor().type() assert degree.type() == torch.cuda.FloatTensor().type()
assert degree.cpu().tolist() == expected_degree assert degree.cpu().tolist() == expected_degree
import torch import torch
def node_degree(row, num_nodes, out=None): def node_degree(target, num_nodes, out=None):
out = row.new() if out is None else out out = target.new(num_nodes) if out is None else out
zero = torch.zeros(num_nodes, out=out) zero = torch.zeros(num_nodes, out=out)
one = torch.ones(row.size(0), out=zero.new()) one = torch.ones(target.size(0), out=zero.new(target.size(0)))
return zero.scatter_add_(0, row, one) return zero.scatter_add_(0, target, one)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment