# test_degree.py (478 bytes) -- last commit: "tests" by rusty1s
import torch
from torch_cluster.functions.utils.degree import node_degree


def test_node_degree():
    """Check the counts and the tensor type returned by ``node_degree``.

    The index tensor mentions node 0 four times, node 1 twice, node 3 once
    and node 2 never, so for 4 nodes the expected result is ``[4, 2, 0, 1]``.
    Two call styles are exercised:

    * without ``out`` -- the test expects a ``torch.FloatTensor`` result;
    * with ``out=torch.LongTensor()`` -- the test expects the result to
      have the type of the supplied tensor.
    """
    num_nodes = 4
    index = torch.LongTensor([0, 1, 1, 0, 0, 3, 0])
    counts = [4, 2, 0, 1]

    # Each case pairs the extra keyword arguments for the call with the
    # tensor class whose type string the result must report.
    cases = (
        ({}, torch.FloatTensor),
        ({'out': torch.LongTensor()}, torch.LongTensor),
    )

    for extra_kwargs, tensor_cls in cases:
        result = node_degree(index, num_nodes, **extra_kwargs)
        assert result.type() == tensor_cls().type()
        # Integer and float entries compare equal, so one expected list
        # serves both cases.
        assert result.tolist() == counts