# test_graphbolt_utils.py: tests for dgl.graphbolt's unique_and_compact utilities.
import dgl.graphbolt as gb
import pytest
import torch


def test_unique_and_compact_hetero():
    """Check ``gb.unique_and_compact`` on a dict of per-type node chunks.

    Each node type maps to a tuple of tensor chunks. The returned unique
    nodes per type must equal ``torch.unique`` of all IDs of that type
    (compared after sorting). Indexing the unique nodes with each compacted
    chunk must reproduce the original chunk.
    """
    N1 = torch.randint(0, 50, (30,))
    N2 = torch.randint(0, 50, (20,))
    N3 = torch.randint(0, 50, (10,))
    expected_unique = {
        "n1": torch.unique(N1),
        "n2": torch.unique(N2),
        "n3": torch.unique(N3),
    }
    nodes_dict = {
        "n1": N1.split(5),
        "n2": N2.split(4),
        "n3": N3.split(2),
    }

    unique, compacted = gb.unique_and_compact(nodes_dict)

    # Iterating only over the returned dict would silently accept a
    # missing node type, so pin the key set first.
    assert unique.keys() == expected_unique.keys()
    for ntype, nodes in unique.items():
        expected_nodes = expected_unique[ntype]
        assert torch.equal(torch.sort(nodes)[0], expected_nodes)

    assert compacted.keys() == nodes_dict.keys()
    for ntype, nodes in compacted.items():
        expected_nodes = nodes_dict[ntype]
        assert isinstance(nodes, list)
        # zip() stops at the shorter input, so pin the chunk count.
        assert len(nodes) == len(expected_nodes)
        for expected_node, node in zip(expected_nodes, nodes):
            # Map compacted indices back to the original node IDs.
            node = unique[ntype][node]
            assert torch.equal(expected_node, node)


def test_unique_and_compact_homo():
    """Check ``gb.unique_and_compact`` on a homogeneous list of node chunks.

    The returned unique nodes must equal ``torch.unique`` of all IDs
    (compared after sorting). Indexing them with each compacted chunk must
    reproduce the corresponding input chunk.
    """
    N = torch.randint(0, 50, (200,))
    expected_unique_N = torch.unique(N)
    nodes_list = N.split(5)

    unique, compacted = gb.unique_and_compact(nodes_list)

    assert torch.equal(torch.sort(unique)[0], expected_unique_N)

    assert isinstance(compacted, list)
    # zip() stops at the shorter input, so pin the chunk count.
    assert len(compacted) == len(nodes_list)
    for expected_node, node in zip(nodes_list, compacted):
        # Map compacted indices back to the original node IDs.
        node = unique[node]
        assert torch.equal(expected_node, node)


def test_unique_and_compact_node_pairs_hetero():
    """Check ``gb.unique_and_compact_node_pairs`` on heterogeneous edges.

    Node pairs are keyed by ``"src_type:etype:dst_type"`` strings. The
    returned per-type unique nodes must equal ``torch.unique`` of every ID
    seen for that type (compared after sorting). Indexing the unique nodes
    with the compacted pairs must reproduce the original pairs.
    """
    N1 = torch.randint(0, 50, (30,))
    N2 = torch.randint(0, 50, (20,))
    N3 = torch.randint(0, 50, (10,))
    # Every ID of each type appears in at least one edge type below
    # (N1[:20] + N1[20:30] covers N1; N2 and N3 are used whole), so the
    # expected unique set is torch.unique of the full tensor.
    expected_unique_nodes = {
        "n1": torch.unique(N1),
        "n2": torch.unique(N2),
        "n3": torch.unique(N3),
    }
    node_pairs = {
        "n1:e1:n2": (N1[:20], N2),
        "n1:e2:n3": (N1[20:30], N3),
        "n2:e3:n3": (N2[10:], N3),
    }

    unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
        node_pairs
    )

    # Iterating only over the returned dicts would silently accept missing
    # node or edge types, so pin the key sets first.
    assert unique_nodes.keys() == expected_unique_nodes.keys()
    for ntype, nodes in unique_nodes.items():
        expected_nodes = expected_unique_nodes[ntype]
        assert torch.equal(torch.sort(nodes)[0], expected_nodes)

    assert compacted_node_pairs.keys() == node_pairs.keys()
    for etype, (u, v) in compacted_node_pairs.items():
        u_type, _, v_type = gb.etype_str_to_tuple(etype)
        # Map compacted indices back to the original node IDs.
        u, v = unique_nodes[u_type][u], unique_nodes[v_type][v]
        expected_u, expected_v = node_pairs[etype]
        assert torch.equal(u, expected_u)
        assert torch.equal(v, expected_v)


def test_unique_and_compact_node_pairs_homo():
    """Check ``gb.unique_and_compact_node_pairs`` on a homogeneous pair.

    The returned unique nodes must equal ``torch.unique`` of all IDs
    (compared after sorting). Indexing them with the compacted pair must
    reproduce the original pair. The unique destination nodes must occupy
    the leading positions of the unique nodes.
    """
    N = torch.randint(0, 50, (200,))
    expected_unique_N = torch.unique(N)
    # Split into (src, dst), 100 IDs each.
    node_pairs = tuple(N.split(100))

    unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
        node_pairs
    )

    assert torch.equal(torch.sort(unique_nodes)[0], expected_unique_N)

    u, v = compacted_node_pairs
    # Map compacted indices back to the original node IDs.
    u, v = unique_nodes[u], unique_nodes[v]
    expected_u, expected_v = node_pairs
    assert torch.equal(u, expected_u)
    assert torch.equal(v, expected_v)

    # The unique destination nodes (in torch.unique order) must come first.
    unique_v = torch.unique(expected_v)
    assert torch.equal(unique_nodes[: unique_v.size(0)], unique_v)


def test_incomplete_unique_dst_nodes_():
    """Passing unique destination nodes that miss the pairs' destinations
    must make ``gb.unique_and_compact_node_pairs`` raise ``IndexError``.

    Destination IDs are drawn from [100, 150). The supplied unique
    destination nodes are 150..199, which is disjoint from that range.
    """
    src_ids = torch.randint(0, 50, (50,))
    dst_ids = torch.randint(100, 150, (50,))
    incomplete_dst = torch.arange(150, 200)
    with pytest.raises(IndexError):
        gb.unique_and_compact_node_pairs((src_ids, dst_ids), incomplete_dst)