# Tests for dgl.graphbolt node-pair utilities (test_graphbolt_utils.py).
import dgl.graphbolt as gb
2
import pytest
3
4
5
6
7
8
9
import torch


def test_unique_and_compact_node_pairs_hetero():
    """Check ``gb.unique_and_compact_node_pairs`` on heterogeneous input.

    Three node types are built from random IDs and wired into three edge
    types. The test verifies two things. First, each node type's unique
    node set matches ``torch.unique`` of its source tensor. Second,
    indexing the unique nodes with the compacted pairs reproduces the
    original node pairs exactly.
    """
    N1 = torch.randint(0, 50, (30,))
    N2 = torch.randint(0, 50, (20,))
    N3 = torch.randint(0, 50, (10,))

    # Every ID of each tensor is used by at least one edge type below
    # (N1 is fully covered by e1 + e2, N2 by e1, N3 by e2), so the expected
    # unique set per node type is just torch.unique of that tensor.
    expected_unique_nodes = {
        "n1": torch.unique(N1),
        "n2": torch.unique(N2),
        "n3": torch.unique(N3),
    }
    node_pairs = {
        ("n1", "e1", "n2"): (
            N1[:20],
            N2,
        ),
        ("n1", "e2", "n3"): (
            N1[20:30],
            N3,
        ),
        ("n2", "e3", "n3"): (
            N2[10:],
            N3,
        ),
    }

    unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
        node_pairs
    )

    # Check the key sets explicitly. Iterating only over the returned dict
    # would let a missing node type pass silently.
    assert unique_nodes.keys() == expected_unique_nodes.keys()
    for ntype, nodes in unique_nodes.items():
        # Order within a type is implementation-defined; compare sorted.
        assert torch.equal(torch.sort(nodes)[0], expected_unique_nodes[ntype])

    # Every input edge type must come back compacted.
    assert compacted_node_pairs.keys() == node_pairs.keys()
    for etype, (u, v) in compacted_node_pairs.items():
        u_type, _, v_type = etype
        expected_u, expected_v = node_pairs[etype]
        # Map compacted indices back to original IDs via the per-type
        # unique nodes; this must reproduce the input pairs exactly.
        assert torch.equal(unique_nodes[u_type][u], expected_u)
        assert torch.equal(unique_nodes[v_type][v], expected_v)


def test_unique_and_compact_node_pairs_homo():
    """Check ``gb.unique_and_compact_node_pairs`` on a homogeneous pair.

    It verifies three things: the unique node set is correct, the
    compacted indices map back to the original (src, dst) tensors, and
    the sorted unique destination IDs occupy the leading slots of the
    returned unique nodes.
    """
    # 200 random IDs, split into equal-length src / dst halves.
    all_nodes = torch.randint(0, 50, (200,))
    reference_unique = torch.unique(all_nodes)
    node_pairs = tuple(all_nodes.split(100))
    src_nodes, dst_nodes = node_pairs

    unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
        node_pairs
    )

    # Each ID must appear exactly once; order is checked separately below.
    sorted_unique, _ = torch.sort(unique_nodes)
    assert torch.equal(sorted_unique, reference_unique)

    # Indexing with the compacted IDs must recover the inputs.
    compacted_src, compacted_dst = compacted_node_pairs
    recovered_src = unique_nodes[compacted_src]
    recovered_dst = unique_nodes[compacted_dst]
    assert torch.equal(recovered_src, src_nodes)
    assert torch.equal(recovered_dst, dst_nodes)

    # The leading entries of unique_nodes are the sorted unique dst IDs.
    dst_unique = torch.unique(dst_nodes)
    leading = unique_nodes[: dst_unique.size(0)]
    assert torch.equal(leading, dst_unique)


def test_incomplete_unique_dst_nodes_():
    """Passing unique_dst_nodes that miss the real dst IDs raises IndexError."""
    src = torch.randint(0, 50, (50,))
    # Destination IDs lie in [100, 150) ...
    dst = torch.randint(100, 150, (50,))
    # ... but the supplied unique destination set only spans [150, 200).
    unique_dst_nodes = torch.arange(150, 200)
    with pytest.raises(IndexError):
        gb.unique_and_compact_node_pairs((src, dst), unique_dst_nodes)