"tests/vscode:/vscode.git/clone" did not exist on "e602ab1b56889c8f999f07aeddb55d641fba1014"
test_subgraph_sampler.py 5.53 KB
Newer Older
1
import dgl.graphbolt as gb
2
3
4
5
import gb_test_utils
import pytest
import torch
import torchdata.datapipes as dp
6
from torchdata.datapipes.iter import Mapper
7
8


9
10
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor):
    """Check the number of minibatches produced by node-seeded sampling.

    Builds the pipeline ItemSampler -> node collator -> subgraph sampler and
    expects five outputs from ten seed items batched two at a time.

    Args:
        labor: When true, ``gb.LayerNeighborSampler`` is exercised; otherwise
            ``gb.NeighborSampler`` is used.
    """
    # NOTE(review): arguments are presumably (num_nodes, sparsity) -- confirm
    # against gb_test_utils.rand_csc_graph.
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(torch.arange(10))
    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)

    # One fanout tensor per sampling layer.
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    minibatch_converter = Mapper(
        item_sampler_dp, gb_test_utils.minibatch_node_collator
    )
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(minibatch_converter, graph, fanouts)

    # 10 seed items with batch_size=2 are expected to give 5 minibatches.
    assert len(list(sampler_dp)) == 5
22
23


24
def to_link_batch(data):
    """Wrap ``data`` in a ``gb.MiniBatch`` as its ``node_pairs`` field.

    Args:
        data: Value stored unchanged under ``node_pairs``.

    Returns:
        The newly constructed ``gb.MiniBatch``.
    """
    return gb.MiniBatch(node_pairs=data)
27
28


29
30
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link(labor):
    """Check the number of minibatches produced by link-seeded sampling.

    Builds the pipeline ItemSampler -> link collator -> subgraph sampler and
    expects five outputs from ten node pairs batched two at a time.

    Args:
        labor: When true, ``gb.LayerNeighborSampler`` is exercised; otherwise
            ``gb.NeighborSampler`` is used.
    """
    # NOTE(review): arguments are presumably (num_nodes, sparsity) -- confirm
    # against gb_test_utils.rand_csc_graph.
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    # Ten seed pairs given as two parallel tensors: (0..9) and (10..19).
    itemset = gb.ItemSet(
        (
            torch.arange(0, 10),
            torch.arange(10, 20),
        )
    )
    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)

    # One fanout tensor per sampling layer.
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    minibatch_converter = Mapper(
        item_sampler_dp, gb_test_utils.minibatch_link_collator
    )
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    neighbor_dp = Sampler(minibatch_converter, graph, fanouts)

    # 10 seed pairs with batch_size=2 are expected to give 5 minibatches.
    assert len(list(neighbor_dp)) == 5
47
48


49
50
51
52
53
54
55
56
57
@pytest.mark.parametrize(
    "format",
    [
        gb.LinkPredictionEdgeFormat.INDEPENDENT,
        gb.LinkPredictionEdgeFormat.CONDITIONED,
        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
    ],
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(format, labor):
    """Check link-seeded sampling with a negative sampler in the pipeline.

    Builds the pipeline ItemSampler -> link collator -> uniform negative
    sampler -> subgraph sampler and expects five outputs from ten node pairs
    batched two at a time, for every edge format.

    Args:
        format: ``gb.LinkPredictionEdgeFormat`` member handed to
            ``gb.UniformNegativeSampler``.
        labor: When true, ``gb.LayerNeighborSampler`` is exercised; otherwise
            ``gb.NeighborSampler`` is used.
    """
    # NOTE(review): arguments are presumably (num_nodes, sparsity) -- confirm
    # against gb_test_utils.rand_csc_graph.
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    # Ten seed pairs given as two parallel tensors: (0..9) and (10..19).
    itemset = gb.ItemSet(
        (
            torch.arange(0, 10),
            torch.arange(10, 20),
        )
    )
    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)

    # One fanout tensor per sampling layer.
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    minibatch_converter = Mapper(
        item_sampler_dp, gb_test_utils.minibatch_link_collator
    )
    # NOTE(review): the literal 1 is presumably the negative ratio -- confirm
    # against the UniformNegativeSampler signature.
    negative_dp = gb.UniformNegativeSampler(
        minibatch_converter, 1, format, graph
    )
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    neighbor_dp = Sampler(negative_dp, graph, fanouts)

    # 10 seed pairs with batch_size=2 are expected to give 5 minibatches.
    assert len(list(neighbor_dp)) == 5
79
80


81
82
83
84
85
86
87
def get_hetero_graph():
    """Build a small fixed heterogeneous graph via ``gb.from_csc``.

    The graph has two node types (``n1``, ``n2``) and two edge types
    (``n1:e1:n2``, ``n2:e2:n1``), five nodes and ten edges.

    Returns:
        The object returned by ``gb.from_csc`` for the tensors below.
    """
    # Per-edge view of the CSC arrays (10 edges):
    #   indptr expanded: [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
    #   indices:         [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
    #   edge type:       [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
    # num_nodes = 5, num_n1 = 2, num_n2 = 3
    ntypes = {"n1": 0, "n2": 1}
    etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    metadata = gb.GraphMetadata(ntypes, etypes)
    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
    # Offsets [0, 2, 5]: nodes 0-1 are type n1, nodes 2-4 are type n2.
    node_type_offset = torch.LongTensor([0, 2, 5])
    return gb.from_csc(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        metadata=metadata,
    )
101
102


103
104
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero(labor):
    """Check link-seeded sampling on the heterogeneous test graph.

    Builds the pipeline ItemSampler -> link collator -> subgraph sampler over
    an ``ItemSetDict`` keyed by edge type and expects five outputs.

    Args:
        labor: When true, ``gb.LayerNeighborSampler`` is exercised; otherwise
            ``gb.NeighborSampler`` is used.
    """
    graph = get_hetero_graph()
    # 4 pairs for "n1:e1:n2" plus 6 pairs for "n2:e2:n1": 10 seed pairs total.
    itemset = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(
                (
                    torch.LongTensor([0, 0, 1, 1]),
                    torch.LongTensor([0, 2, 0, 1]),
                )
            ),
            "n2:e2:n1": gb.ItemSet(
                (
                    torch.LongTensor([0, 0, 1, 1, 2, 2]),
                    torch.LongTensor([0, 1, 1, 0, 0, 1]),
                )
            ),
        }
    )

    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)

    # One fanout tensor per sampling layer.
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    minibatch_converter = Mapper(
        item_sampler_dp, gb_test_utils.minibatch_link_collator
    )
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    neighbor_dp = Sampler(minibatch_converter, graph, fanouts)

    # 10 seed pairs with batch_size=2 are expected to give 5 minibatches.
    assert len(list(neighbor_dp)) == 5
132
133
134


@pytest.mark.parametrize(
    "format",
    [
        gb.LinkPredictionEdgeFormat.INDEPENDENT,
        gb.LinkPredictionEdgeFormat.CONDITIONED,
        gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
        gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
    ],
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
    """Check heterogeneous link sampling with a negative sampler included.

    Builds the pipeline ItemSampler -> link collator -> uniform negative
    sampler -> subgraph sampler over an ``ItemSetDict`` keyed by edge type
    and expects five outputs for every edge format.

    Args:
        format: ``gb.LinkPredictionEdgeFormat`` member handed to
            ``gb.UniformNegativeSampler``.
        labor: When true, ``gb.LayerNeighborSampler`` is exercised; otherwise
            ``gb.NeighborSampler`` is used.
    """
    graph = get_hetero_graph()
    # 4 pairs for "n1:e1:n2" plus 6 pairs for "n2:e2:n1": 10 seed pairs total.
    itemset = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(
                (
                    torch.LongTensor([0, 0, 1, 1]),
                    torch.LongTensor([0, 2, 0, 1]),
                )
            ),
            "n2:e2:n1": gb.ItemSet(
                (
                    torch.LongTensor([0, 0, 1, 1, 2, 2]),
                    torch.LongTensor([0, 1, 1, 0, 0, 1]),
                )
            ),
        }
    )

    item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)

    # One fanout tensor per sampling layer.
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    minibatch_converter = Mapper(
        item_sampler_dp, gb_test_utils.minibatch_link_collator
    )
    # NOTE(review): the literal 1 is presumably the negative ratio -- confirm
    # against the UniformNegativeSampler signature.
    negative_dp = gb.UniformNegativeSampler(
        minibatch_converter, 1, format, graph
    )
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    neighbor_dp = Sampler(negative_dp, graph, fanouts)

    # 10 seed pairs with batch_size=2 are expected to give 5 minibatches.
    assert len(list(neighbor_dp)) == 5