# test_subgraph_sampler.py (4.75 KB)
1
import dgl.graphbolt as gb
2
3
4
5
import gb_test_utils
import pytest
import torch
import torchdata.datapipes as dp
6
from torchdata.datapipes.iter import Mapper
7
8


9
10
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor):
    """Check the number of mini-batches produced for node seeds.

    Ten seed nodes (``torch.arange(10)``) are batched two at a time and fed
    through either ``gb.NeighborSampler`` or, when ``labor`` is true,
    ``gb.LayerNeighborSampler``. Iterating the resulting datapipe must yield
    exactly five items.
    """
    # NOTE(review): presumably a random graph with 20 nodes and sparsity
    # 0.15 -- confirm against gb_test_utils.rand_csc_graph.
    csc_graph = gb_test_utils.rand_csc_graph(20, 0.15)
    seeds = gb.ItemSet(torch.arange(10), names="seed_nodes")
    batches = gb.ItemSampler(seeds, batch_size=2)
    # Two layers, each given a one-element fanout tensor holding 2.
    layer_fanouts = [torch.LongTensor([2]) for _ in range(2)]
    if labor:
        sampler_cls = gb.LayerNeighborSampler
    else:
        sampler_cls = gb.NeighborSampler
    sampled = list(sampler_cls(batches, csc_graph, layer_fanouts))
    assert len(sampled) == 5
19
20


21
def to_link_batch(data):
    """Return a ``gb.MiniBatch`` whose ``node_pairs`` field is ``data``."""
    return gb.MiniBatch(node_pairs=data)
24
25


26
27
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link(labor):
    """Check the number of mini-batches produced for node-pair seeds.

    The integers 0..19 are reshaped into ten rows of two and registered as
    ``"node_pairs"``. With a batch size of two, the sampler datapipe
    (``gb.LayerNeighborSampler`` when ``labor`` is true, otherwise
    ``gb.NeighborSampler``) must yield exactly five items.
    """
    # NOTE(review): presumably a random graph with 20 nodes and sparsity
    # 0.15 -- confirm against gb_test_utils.rand_csc_graph.
    csc_graph = gb_test_utils.rand_csc_graph(20, 0.15)
    pairs = torch.arange(0, 20).reshape(-1, 2)
    batches = gb.ItemSampler(
        gb.ItemSet(pairs, names="node_pairs"), batch_size=2
    )
    # Two layers, each given a one-element fanout tensor holding 2.
    layer_fanouts = [torch.LongTensor([2]) for _ in range(2)]
    if labor:
        sampler_cls = gb.LayerNeighborSampler
    else:
        sampler_cls = gb.NeighborSampler
    sampled = list(sampler_cls(batches, csc_graph, layer_fanouts))
    assert len(sampled) == 5
36
37


38
39
40
41
42
43
44
45
46
@pytest.mark.parametrize(
    "format",
    [
        gb.LinkPredictionEdgeFormat.INDEPENDENT,
        gb.LinkPredictionEdgeFormat.CONDITIONED,
        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
    ],
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(format, labor):
    """Check the mini-batch count when negative sampling precedes sampling.

    Ten node pairs are batched two at a time, passed through
    ``gb.UniformNegativeSampler`` with the parametrized edge ``format``, and
    then through the neighbor sampler selected by ``labor``. The final
    datapipe must yield exactly five items for every combination.
    """
    # NOTE(review): presumably a random graph with 20 nodes and sparsity
    # 0.15 -- confirm against gb_test_utils.rand_csc_graph.
    csc_graph = gb_test_utils.rand_csc_graph(20, 0.15)
    pairs = torch.arange(0, 20).reshape(-1, 2)
    batches = gb.ItemSampler(
        gb.ItemSet(pairs, names="node_pairs"), batch_size=2
    )
    # Two layers, each given a one-element fanout tensor holding 2.
    layer_fanouts = [torch.LongTensor([2]) for _ in range(2)]
    # NOTE(review): the positional ``1`` is presumably the negative ratio --
    # confirm against the UniformNegativeSampler signature.
    with_negatives = gb.UniformNegativeSampler(batches, 1, format, csc_graph)
    if labor:
        sampler_cls = gb.LayerNeighborSampler
    else:
        sampler_cls = gb.NeighborSampler
    sampled = list(sampler_cls(with_negatives, csc_graph, layer_fanouts))
    assert len(sampled) == 5
58
59


60
61
62
63
64
65
66
def get_hetero_graph():
    """Build the fixed heterogeneous graph shared by the hetero link tests.

    The graph has 5 nodes and 10 edges. Written as COO rows it is:

        [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
        [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]  -> edge type

    Node type offsets ``[0, 2, 5]`` split the nodes into 2 of type "n1"
    followed by 3 of type "n2".

    Returns:
        The graph object produced by ``gb.from_csc``.
    """
    node_types = {"n1": 0, "n2": 1}
    edge_types = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    metadata = gb.GraphMetadata(node_types, edge_types)
    return gb.from_csc(
        # Every indptr slot spans two entries of ``indices``.
        torch.LongTensor([0, 2, 4, 6, 8, 10]),
        torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]),
        node_type_offset=torch.LongTensor([0, 2, 5]),
        type_per_edge=torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]),
        metadata=metadata,
    )
80
81


82
83
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero(labor):
    """Check the mini-batch count for typed node pairs on the hetero graph.

    Four "n1:e1:n2" pairs and six "n2:e2:n1" pairs are wrapped in a
    ``gb.ItemSetDict`` and batched two at a time. The sampler selected by
    ``labor`` must then yield exactly five items.
    """
    hetero_graph = get_hetero_graph()
    # Each 2 x N tensor is transposed so that every row is one node pair.
    e1_pairs = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
    e2_pairs = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
    typed_pairs = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(e1_pairs, names="node_pairs"),
            "n2:e2:n1": gb.ItemSet(e2_pairs, names="node_pairs"),
        }
    )
    batches = gb.ItemSampler(typed_pairs, batch_size=2)
    # Two layers, each given a one-element fanout tensor holding 2.
    layer_fanouts = [torch.LongTensor([2]) for _ in range(2)]
    if labor:
        sampler_cls = gb.LayerNeighborSampler
    else:
        sampler_cls = gb.NeighborSampler
    sampled = list(sampler_cls(batches, hetero_graph, layer_fanouts))
    assert len(sampled) == 5
104
105
106


@pytest.mark.parametrize(
    "format",
    [
        gb.LinkPredictionEdgeFormat.INDEPENDENT,
        gb.LinkPredictionEdgeFormat.CONDITIONED,
        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
    ],
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
    """Check the hetero mini-batch count with negative sampling in the pipe.

    Four "n1:e1:n2" pairs and six "n2:e2:n1" pairs are batched two at a
    time, passed through ``gb.UniformNegativeSampler`` with the parametrized
    edge ``format``, and then through the neighbor sampler selected by
    ``labor``. The final datapipe must yield exactly five items.
    """
    hetero_graph = get_hetero_graph()
    # Each 2 x N tensor is transposed so that every row is one node pair.
    e1_pairs = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
    e2_pairs = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
    typed_pairs = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(e1_pairs, names="node_pairs"),
            "n2:e2:n1": gb.ItemSet(e2_pairs, names="node_pairs"),
        }
    )
    batches = gb.ItemSampler(typed_pairs, batch_size=2)
    # Two layers, each given a one-element fanout tensor holding 2.
    layer_fanouts = [torch.LongTensor([2]) for _ in range(2)]
    # NOTE(review): the positional ``1`` is presumably the negative ratio --
    # confirm against the UniformNegativeSampler signature.
    with_negatives = gb.UniformNegativeSampler(
        batches, 1, format, hetero_graph
    )
    if labor:
        sampler_cls = gb.LayerNeighborSampler
    else:
        sampler_cls = gb.NeighborSampler
    sampled = list(sampler_cls(with_negatives, hetero_graph, layer_fanouts))
    assert len(sampled) == 5