test_spot_target.py 2.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from collections.abc import Mapping

import dgl
import numpy as np
import pytest
import torch


def _create_homogeneous():
    """Build a random bidirected homogeneous graph for the SpotTarget tests.

    1000 random (src, dst) pairs over 200 nodes are inserted twice: first as
    given (edge IDs 0..999), then reversed (edge IDs 1000..1999). Edge ``i``
    and edge ``i + 1000`` are therefore reverses of each other.

    Returns:
        g: DGL graph with 200 nodes and 2000 edges.
        reverse_eids: tensor of length 2000 mapping each edge ID to the ID of
            its reverse edge (``i -> i + 1000`` and ``i + 1000 -> i``).
        seed_edges: edge IDs 0..999, i.e. only the forward copies.
    """
    num_nodes = 200
    num_pairs = 1000
    s = torch.randint(0, num_nodes, (num_pairs,))
    d = torch.randint(0, num_nodes, (num_pairs,))
    # Both directions must be present so that every ID in ``reverse_eids``
    # refers to an edge that actually exists in ``g``; otherwise the
    # reverse-edge exclusion being tested has nothing to exclude.
    g = dgl.graph((torch.cat([s, d]), torch.cat([d, s])), num_nodes=num_nodes)
    reverse_eids = torch.cat(
        [torch.arange(num_pairs, 2 * num_pairs), torch.arange(0, num_pairs)]
    )
    seed_edges = torch.arange(0, num_pairs)
    return g, reverse_eids, seed_edges


def _find_edges_to_exclude(g, pair_eids, degree_threshold, reverse_eids=None):
    """Compute the edge IDs the test expects SpotTarget to exclude for a batch.

    A pair edge counts as "low degree" when the smaller of the in-degrees of
    its two endpoints, measured on the full graph ``g``, is below
    ``degree_threshold``. The low-degree edges are returned together with
    their reverse edges.

    Args:
        g: the full graph the dataloader samples from.
        pair_eids: edge IDs (in ``g``) of the positive pairs in the batch.
        degree_threshold: an endpoint in-degree strictly below this value
            marks the edge as low degree.
        reverse_eids: optional tensor mapping each edge ID to its reverse
            edge ID. When omitted, the reverse of edge ``i`` is taken to be
            ``i + 1000``, the same layout ``_create_homogeneous`` encodes in
            its ``reverse_eids``.

    Returns:
        1-D tensor: the low-degree pair edge IDs followed by their reverses.
    """
    src, dst = g.find_edges(pair_eids)
    # The less-connected endpoint decides whether the edge is low degree.
    min_degree = torch.min(g.in_degrees(src), g.in_degrees(dst))
    low_degree_eids = pair_eids[min_degree < degree_threshold]
    if reverse_eids is None:
        reverse_of_low = low_degree_eids + 1000
    else:
        reverse_of_low = reverse_eids[low_degree_eids]
    return torch.cat([low_degree_eids, reverse_of_low])


@pytest.mark.parametrize("degree_threshold", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("batch_size", [1, 10, 50])
def test_spot_target_excludes(degree_threshold, batch_size):
    """Low-degree seed edges and their reverses must not reach the block.

    For the first 11 mini-batches, recompute the expected exclusion set with
    ``_find_edges_to_exclude``. Then assert that none of those edge IDs
    appear in the first block produced by the full-neighbor sampler.
    """
    g, reverse_eids, seed_edges = _create_homogeneous()
    spot_target = dgl.dataloading.SpotTarget(
        g,
        exclude="reverse_id",
        degree_threshold=degree_threshold,
        reverse_eids=reverse_eids,
    )
    sampler = dgl.dataloading.as_edge_prediction_sampler(
        dgl.dataloading.MultiLayerFullNeighborSampler(1),
        exclude=spot_target,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(1),
    )
    dataloader = dgl.dataloading.DataLoader(
        g, seed_edges, sampler, batch_size=batch_size
    )

    for step, (_, pair_graph, _, blocks) in enumerate(dataloader):
        # Blocks may arrive as a per-layer list; only the first is inspected.
        subg = blocks[0] if isinstance(blocks, list) else blocks
        expected_excluded = _find_edges_to_exclude(
            g, pair_graph.edata[dgl.EID], degree_threshold
        )
        expected_excluded = dgl.utils.recursive_apply(
            expected_excluded, lambda x: x.cpu().numpy()
        )
        block_eids = dgl.utils.recursive_apply(
            subg.edata[dgl.EID], lambda x: x.cpu().numpy()
        )

        if isinstance(expected_excluded, Mapping):
            for etype, eids in expected_excluded.items():
                leaked = eids[np.isin(eids, block_eids[etype])]
                assert leaked.size == 0, (
                    f"edges {leaked.tolist()} of type {etype!r} should have "
                    "been excluded from the block"
                )
        else:
            leaked = expected_excluded[np.isin(expected_excluded, block_eids)]
            assert leaked.size == 0, (
                f"edges {leaked.tolist()} should have been excluded "
                "from the block"
            )

        # Checking the first 11 batches is enough and keeps the test fast.
        if step == 10:
            break


if __name__ == "__main__":
    # Run one representative parametrization directly, without pytest.
    test_spot_target_excludes(batch_size=10, degree_threshold=2)