"examples/vscode:/vscode.git/clone" did not exist on "6d2e19f7466b70574209d3da4488e16610c4fac6"
test_subgraph_sampler.py 6.23 KB
Newer Older
1
import dgl.graphbolt as gb
2
3
4
import gb_test_utils
import pytest
import torch
5
from torchdata.datapipes.iter import Mapper
6
7


8
9
def test_SubgraphSampler_invoke():
    """Check that iterating a plain ``gb.SubgraphSampler`` is not implemented.

    The datapipe is built twice from the same item sampler: once through the
    class constructor and once through the functional form
    ``sample_subgraph``. In both cases fetching the first element is expected
    to raise ``NotImplementedError``.
    """
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)

    # Invoke via class constructor.
    datapipe = gb.SubgraphSampler(item_sampler)
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))

    # Invoke via functional form.
    datapipe = item_sampler.sample_subgraph()
    with pytest.raises(NotImplementedError):
        next(iter(datapipe))


@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_invoke(labor):
    """Check both ways of constructing a neighbor-sampling datapipe.

    Parameters
    ----------
    labor : bool
        When True, ``gb.LayerNeighborSampler`` / ``sample_layer_neighbor`` is
        exercised; otherwise ``gb.NeighborSampler`` / ``sample_neighbor``.

    The test asserts that each datapipe yields 5 elements (the item set holds
    10 seed nodes and the item sampler uses ``batch_size=2``).
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]

    # Invoke via class constructor.
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(item_sampler, graph, fanouts)
    assert len(list(datapipe)) == 5

    # Invoke via functional form.
    if labor:
        datapipe = item_sampler.sample_layer_neighbor(graph, fanouts)
    else:
        datapipe = item_sampler.sample_neighbor(graph, fanouts)
    assert len(list(datapipe)) == 5


44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
@pytest.mark.parametrize("labor", [False, True])
def test_NeighborSampler_fanouts(labor):
    """Check that ``fanouts`` may be given as tensors or as plain integers.

    Parameters
    ----------
    labor : bool
        Selects ``sample_layer_neighbor`` (True) or ``sample_neighbor``
        (False) as the functional form under test.

    For each fanout representation the resulting datapipe is expected to
    yield 5 elements.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    seeds = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(seeds, batch_size=2)
    num_layer = 2

    # Pick the functional form once; both take (graph, fanouts).
    sample = (
        item_sampler.sample_layer_neighbor
        if labor
        else item_sampler.sample_neighbor
    )

    fanout_variants = (
        # `fanouts` is a list of tensors.
        [torch.LongTensor([2]) for _ in range(num_layer)],
        # `fanouts` is a list of integers.
        [2 for _ in range(num_layer)],
    )
    for fanouts in fanout_variants:
        datapipe = sample(graph, fanouts)
        assert len(list(datapipe)) == 5


68
69
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor):
    """Sample neighbors for seed-node items and count the yielded elements.

    Parameters
    ----------
    labor : bool
        When True, ``gb.LayerNeighborSampler`` is used; otherwise
        ``gb.NeighborSampler``.

    The item set holds 10 seed nodes and the item sampler uses
    ``batch_size=2``; the test asserts that the sampler yields 5 elements.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    sampler_dp = Sampler(item_sampler, graph, fanouts)
    assert len(list(sampler_dp)) == 5
78
79


80
def to_link_batch(data):
    """Wrap ``data`` in a ``gb.MiniBatch`` as its ``node_pairs`` field.

    Parameters
    ----------
    data
        Value passed unchanged as the ``node_pairs`` keyword argument.

    Returns
    -------
    gb.MiniBatch
        The newly constructed mini-batch.
    """
    block = gb.MiniBatch(node_pairs=data)
    return block
83
84


85
86
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link(labor):
    """Sample neighbors for node-pair items and count the yielded elements.

    Parameters
    ----------
    labor : bool
        When True, ``gb.LayerNeighborSampler`` is used; otherwise
        ``gb.NeighborSampler``.

    The item set is ``arange(0, 20)`` reshaped to 10 rows of 2 columns and
    named ``"node_pairs"``; with ``batch_size=2`` the test asserts that the
    sampler yields 5 elements.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    neighbor_dp = Sampler(item_sampler, graph, fanouts)
    assert len(list(neighbor_dp)) == 5
95
96


97
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_With_Negative(labor):
    """Chain negative sampling and neighbor sampling on node-pair items.

    Parameters
    ----------
    labor : bool
        When True, ``gb.LayerNeighborSampler`` is used; otherwise
        ``gb.NeighborSampler``.

    The item sampler (10 node pairs, ``batch_size=2``) feeds a
    ``gb.UniformNegativeSampler`` whose output feeds the neighbor sampler.
    The test asserts that the final datapipe yields 5 elements.
    """
    graph = gb_test_utils.rand_csc_graph(20, 0.15)
    itemset = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2), names="node_pairs")
    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    # NOTE(review): the third argument (1) is presumably the negative ratio;
    # confirm against the UniformNegativeSampler signature.
    negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    neighbor_dp = Sampler(negative_dp, graph, fanouts)
    assert len(list(neighbor_dp)) == 5
108
109


110
111
112
113
114
115
116
def get_hetero_graph():
    """Build a small fixed heterogeneous graph via ``gb.from_csc``.

    Returns
    -------
    The object returned by ``gb.from_csc`` for the hard-coded ``indptr``,
    ``indices``, ``node_type_offset``, ``type_per_edge`` and metadata below.
    """
    # COO view of the graph (row 1 expands `indptr`, row 2 is `indices`):
    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
    # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> edge type.
    # num_nodes = 5, num_n1 = 2, num_n2 = 3
    ntypes = {"n1": 0, "n2": 1}
    etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    metadata = gb.GraphMetadata(ntypes, etypes)
    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
    # Offsets [0, 2, 5] match the node counts noted above (2 and 3).
    node_type_offset = torch.LongTensor([0, 2, 5])
    return gb.from_csc(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        metadata=metadata,
    )
130
131


132
133
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero(labor):
    """Sample neighbors for per-edge-type node pairs on the hetero graph.

    Parameters
    ----------
    labor : bool
        When True, ``gb.LayerNeighborSampler`` is used; otherwise
        ``gb.NeighborSampler``.

    The ``gb.ItemSetDict`` holds 4 node pairs under ``"n1:e1:n2"`` and 6
    under ``"n2:e2:n1"``; with ``batch_size=2`` the test asserts that the
    sampler yields 5 elements.
    """
    graph = get_hetero_graph()
    itemset = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
                names="node_pairs",
            ),
            "n2:e2:n1": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
                names="node_pairs",
            ),
        }
    )

    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    neighbor_dp = Sampler(item_sampler, graph, fanouts)
    assert len(list(neighbor_dp)) == 5
154
155


156
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
    """Chain negative and neighbor sampling on the hetero graph.

    Parameters
    ----------
    labor : bool
        When True, ``gb.LayerNeighborSampler`` is used; otherwise
        ``gb.NeighborSampler``.

    The ``gb.ItemSetDict`` holds 4 node pairs under ``"n1:e1:n2"`` and 6
    under ``"n2:e2:n1"``. The item sampler (``batch_size=2``) feeds a
    ``gb.UniformNegativeSampler`` whose output feeds the neighbor sampler.
    The test asserts that the final datapipe yields 5 elements.
    """
    graph = get_hetero_graph()
    itemset = gb.ItemSetDict(
        {
            "n1:e1:n2": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T,
                names="node_pairs",
            ),
            "n2:e2:n1": gb.ItemSet(
                torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T,
                names="node_pairs",
            ),
        }
    )

    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    # NOTE(review): the third argument (1) is presumably the negative ratio;
    # confirm against the UniformNegativeSampler signature.
    negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    neighbor_dp = Sampler(negative_dp, graph, fanouts)
    assert len(list(neighbor_dp)) == 5