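"""Tests for ``dgl.dataloading``.

Covers neighbor-sampling collators, GraphDataLoader, ClusterGCN and
ShaDow-KHop sampling, non-uniform neighbor sampling, and the device
placement of NodeDataLoader/EdgeDataLoader outputs.
"""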
import os
import tempfile
import unittest
from collections import defaultdict
from collections.abc import Iterator
from itertools import product

import pytest
import torch
from torch.utils.data import DataLoader

import dgl

import backend as F


def _check_neighbor_sampling_dataloader(g, nids, dl, mode, collator):
    """Iterate ``dl`` and validate every minibatch against the parent graph ``g``.

    Per minibatch this checks that:

    * the loader's input/output node IDs equal the first block's source node
      IDs and the last block's destination node IDs;
    * node and edge features inside each block equal ``g``'s features at the
      parent IDs stored under ``dgl.NID`` / ``dgl.EID``;
    * every block edge exists in ``g`` and its parent edge ID maps back to the
      same endpoints;
    * destination nodes are a prefix of source nodes within a block, and each
      block's sources equal the previous block's destinations.

    After the loop, the union of seeds visited must equal the requested seed
    set for every key of ``nids`` that was seen.

    Args:
        g: Graph the collator samples from.
        nids: Dict of requested seeds, keyed by node type for ``mode='node'``
            and by edge type name for ``'edge'`` / ``'link'``.
        dl: Iterable of minibatches produced by the collator.
        mode: ``'node'``, ``'edge'`` or ``'link'``; selects how each minibatch
            tuple is unpacked.
        collator: The collator backing ``dl``. Unused; kept for call
            compatibility.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode not in ('node', 'edge', 'link'):
        raise ValueError('Unknown mode: {!r}'.format(mode))

    seeds = defaultdict(list)

    for item in dl:
        if mode == 'node':
            input_nodes, output_nodes, blocks = item
        elif mode == 'edge':
            input_nodes, pair_graph, blocks = item
            output_nodes = pair_graph.ndata[dgl.NID]
        else:  # mode == 'link'
            input_nodes, pair_graph, neg_graph, blocks = item
            output_nodes = pair_graph.ndata[dgl.NID]
            # Positive and negative pair graphs must be defined over the same nodes.
            for ntype in pair_graph.ntypes:
                assert F.array_equal(pair_graph.nodes[ntype].data[dgl.NID], neg_graph.nodes[ntype].data[dgl.NID])

        # Heterogeneous graphs report input/output nodes as per-type dicts.
        if len(g.ntypes) > 1:
            for ntype in g.ntypes:
                assert F.array_equal(input_nodes[ntype], blocks[0].srcnodes[ntype].data[dgl.NID])
                assert F.array_equal(output_nodes[ntype], blocks[-1].dstnodes[ntype].data[dgl.NID])
        else:
            assert F.array_equal(input_nodes, blocks[0].srcdata[dgl.NID])
            assert F.array_equal(output_nodes, blocks[-1].dstdata[dgl.NID])

        # Destination node IDs of the previous block, per node type (reset per minibatch).
        prev_dst = {ntype: None for ntype in g.ntypes}
        for block in blocks:
            for canonical_etype in block.canonical_etypes:
                utype, _, vtype = canonical_etype
                uu, vv = block.all_edges(order='eid', etype=canonical_etype)
                src = block.srcnodes[utype].data[dgl.NID]
                dst = block.dstnodes[vtype].data[dgl.NID]

                # Node features copied into the block must match the parent graph's.
                assert F.array_equal(
                    block.srcnodes[utype].data['feat'], g.nodes[utype].data['feat'][src])
                assert F.array_equal(
                    block.dstnodes[vtype].data['feat'], g.nodes[vtype].data['feat'][dst])

                # Blocks chain: this block's sources are the previous block's destinations.
                if prev_dst[utype] is not None:
                    assert F.array_equal(src, prev_dst[utype])

                # Block edges, mapped to parent node IDs, must exist in the parent graph.
                u = src[uu]
                v = dst[vv]
                assert F.asnumpy(g.has_edges_between(u, v, etype=canonical_etype)).all()

                eid = block.edges[canonical_etype].data[dgl.EID]
                assert F.array_equal(
                    block.edges[canonical_etype].data['feat'],
                    g.edges[canonical_etype].data['feat'][eid])

                # The recorded parent edge IDs must resolve to the same endpoints.
                ufound, vfound = g.find_edges(eid, etype=canonical_etype)
                assert F.array_equal(ufound, u)
                assert F.array_equal(vfound, v)

            for ntype in block.dsttypes:
                src = block.srcnodes[ntype].data[dgl.NID]
                dst = block.dstnodes[ntype].data[dgl.NID]
                # Destination nodes are a prefix of the source nodes.
                assert F.array_equal(src[:block.number_of_dst_nodes(ntype)], dst)
                prev_dst[ntype] = dst

        # Record which seeds this minibatch covered.
        if mode == 'node':
            for ntype in blocks[-1].dsttypes:
                seeds[ntype].append(blocks[-1].dstnodes[ntype].data[dgl.NID])
        else:
            for etype in pair_graph.canonical_etypes:
                seeds[etype].append(pair_graph.edges[etype].data[dgl.EID])

    # Check that every requested node/edge was iterated over.
    seeds = {k: F.cat(v, 0) for k, v in seeds.items()}
    for k, v in seeds.items():
        if k in nids:
            seed_set = set(F.asnumpy(nids[k]))
        elif isinstance(k, tuple) and k[1] in nids:
            # Edge seeds are keyed by canonical etype here but by etype name in ``nids``.
            seed_set = set(F.asnumpy(nids[k[1]]))
        else:
            continue
        assert set(F.asnumpy(v)) == seed_set

def test_neighbor_sampler_dataloader():
    """Validate NodeCollator/EdgeCollator minibatches on a homogeneous and a
    heterogeneous graph.

    Covers fanout-limited and full-neighbor samplers, several seed sets, edge
    exclusion ('self', 'reverse_id', 'reverse_types') and uniform negative
    sampling. Each configuration is checked by
    ``_check_neighbor_sampling_dataloader``.
    """
    g = dgl.heterograph({('user', 'follow', 'user'): ([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])},
                        {'user': 6}).long()
    g = dgl.to_bidirected(g).to(F.ctx())
    g.ndata['feat'] = F.randn((6, 8))
    g.edata['feat'] = F.randn((10, 4))
    # NOTE(review): this mapping assumes to_bidirected keeps the 5 original edges
    # first and appends their reverses in the same order (edge i <-> edge (i + 5) % 10).
    # Confirm against the actual edge order of ``g``.
    reverse_eids = F.tensor([5, 6, 7, 8, 9, 0, 1, 2, 3, 4], dtype=F.int64)
    g_sampler1 = dgl.dataloading.MultiLayerNeighborSampler([2, 2], return_eids=True)
    g_sampler2 = dgl.dataloading.MultiLayerFullNeighborSampler(2, return_eids=True)

    hg = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    }).long().to(F.ctx())
    for ntype in hg.ntypes:
        hg.nodes[ntype].data['feat'] = F.randn((hg.number_of_nodes(ntype), 8))
    for etype in hg.canonical_etypes:
        hg.edges[etype].data['feat'] = F.randn((hg.number_of_edges(etype), 4))

    hg_sampler1 = dgl.dataloading.MultiLayerNeighborSampler(
        [{'play': 1, 'played-by': 1, 'follow': 2, 'followed-by': 1}] * 2, return_eids=True)
    hg_sampler2 = dgl.dataloading.MultiLayerFullNeighborSampler(2, return_eids=True)
    reverse_etypes = {'follow': 'followed-by', 'followed-by': 'follow', 'play': 'played-by', 'played-by': 'play'}

    # One (graph, requested seeds keyed by type name, collator, mode) tuple per case.
    # Kept as a single list so the four fields cannot drift out of sync.
    cases = []

    # Homogeneous graph: node seeds, then edge/link seeds under each exclusion setting.
    g_exclusions = [
        {},
        {'exclude': 'self'},
        {'exclude': 'reverse_id', 'reverse_eids': reverse_eids},
    ]
    for seeds, sampler in product(
            [F.tensor([0, 1, 2, 3, 5], dtype=F.int64), F.tensor([4, 5], dtype=F.int64)],
            [g_sampler1, g_sampler2]):
        cases.append((g, {'user': seeds}, dgl.dataloading.NodeCollator(g, seeds, sampler), 'node'))
        for exclusion in g_exclusions:
            cases.append((g, {'follow': seeds},
                          dgl.dataloading.EdgeCollator(g, seeds, sampler, **exclusion), 'edge'))
        for exclusion in g_exclusions:
            cases.append((g, {'follow': seeds},
                          dgl.dataloading.EdgeCollator(
                              g, seeds, sampler,
                              negative_sampler=dgl.dataloading.negative_sampler.Uniform(2), **exclusion),
                          'link'))

    # Heterogeneous graph: node seeds for several node types.
    for seeds, sampler in product(
            [{'user': F.tensor([0, 1, 3, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)},
             {'user': F.tensor([4, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)}],
            [hg_sampler1, hg_sampler2]):
        cases.append((hg, seeds, dgl.dataloading.NodeCollator(hg, seeds, sampler), 'node'))

    # Heterogeneous graph: edge/link seeds, with and without reverse-type exclusion.
    hg_exclusions = [
        {},
        {'exclude': 'reverse_types', 'reverse_etypes': reverse_etypes},
    ]
    for seeds, sampler in product(
            [{'follow': F.tensor([0, 1, 3, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)},
             {'follow': F.tensor([4, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)}],
            [hg_sampler1, hg_sampler2]):
        for exclusion in hg_exclusions:
            cases.append((hg, seeds,
                          dgl.dataloading.EdgeCollator(hg, seeds, sampler, **exclusion), 'edge'))
        for exclusion in hg_exclusions:
            cases.append((hg, seeds,
                          dgl.dataloading.EdgeCollator(
                              hg, seeds, sampler,
                              negative_sampler=dgl.dataloading.negative_sampler.Uniform(2), **exclusion),
                          'link'))

    for graph, nids, collator, mode in cases:
        dl = DataLoader(
            collator.dataset, collate_fn=collator.collate, batch_size=2, shuffle=True, drop_last=False)
        assert isinstance(iter(dl), Iterator)
        _check_neighbor_sampling_dataloader(graph, nids, dl, mode, collator)


def test_graph_dataloader():
    """GraphDataLoader yields full batches of (batched DGLGraph, labels)."""
    batch_size = 16
    num_batches = 2
    # The dataset holds exactly num_batches * batch_size graphs, so every batch is full.
    minigc_dataset = dgl.data.MiniGCDataset(batch_size * num_batches, 10, 20)
    data_loader = dgl.dataloading.GraphDataLoader(minigc_dataset, batch_size=batch_size, shuffle=True)
    assert isinstance(iter(data_loader), Iterator)

    num_seen = 0
    for graph, label in data_loader:
        assert isinstance(graph, dgl.DGLGraph)
        assert F.asnumpy(label).shape[0] == batch_size
        num_seen += 1
    # Without this, an empty loader would pass the loop's checks vacuously.
    assert num_seen == num_batches


@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize('num_workers', [0, 4])
def test_cluster_gcn(num_workers):
    """ClusterGCN subgraph iteration.

    The partition cache is first built (``refresh=True``) and then reused
    (``refresh=False``); both passes must yield batches of 4 subgraphs.
    """
    dataset = dgl.data.CoraFullDataset()
    g = dataset[0]
    # A private cache directory keeps the test from writing into the CWD and
    # stops concurrent parametrized runs from clobbering each other's cache.
    with tempfile.TemporaryDirectory() as cache_dir:
        sgiter = dgl.dataloading.ClusterGCNSubgraphIterator(g, 100, cache_dir, refresh=True)
        dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=num_workers)
        for sg in dataloader:
            assert sg.batch_size == 4

        # refresh=False: reuse the cache written above.
        sgiter = dgl.dataloading.ClusterGCNSubgraphIterator(g, 100, cache_dir, refresh=False)
        dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=num_workers)
        for sg in dataloader:
            assert sg.batch_size == 4

@pytest.mark.parametrize('num_workers', [0, 4])
def test_shadow(num_workers):
    """ShaDow-KHop minibatches.

    Each minibatch carries exactly one subgraph, whose nodes are the input
    nodes, with the seeds first. Its label and feature data must be sliced
    from the parent graph. Only the first six minibatches are inspected.
    """
    graph = dgl.data.CoraFullDataset()[0]
    shadow_sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
    loader = dgl.dataloading.NodeDataLoader(
        graph,
        torch.arange(graph.num_nodes()),
        shadow_sampler,
        batch_size=5,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers,
    )
    for step, (input_nodes, output_nodes, subgraphs) in enumerate(loader):
        (subgraph,) = subgraphs
        num_seeds = output_nodes.shape[0]
        assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
        assert torch.equal(input_nodes[:num_seeds], output_nodes)
        assert torch.equal(subgraph.ndata['label'], graph.ndata['label'][input_nodes])
        assert torch.equal(subgraph.ndata['feat'], graph.ndata['feat'][input_nodes])
        if step == 5:
            break


@pytest.mark.parametrize('num_workers', [0, 4])
def test_neighbor_nonuniform(num_workers):
    """Neighbor sampling with edge probabilities (``prob='p'``) never picks
    zero-probability edges, on a homogeneous and a heterogeneous graph.

    Every seed has exactly two nonzero-probability in-edges per edge type and
    the fanout is 2, so the sampled neighbor sets are deterministic.
    """
    src = [1, 2, 3, 4, 5, 6, 7, 8]
    dst = [0, 0, 0, 0, 1, 1, 1, 1]

    # Homogeneous graph.
    g = dgl.graph((src, dst))
    g.edata['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob='p')
    # num_workers is forwarded so the parametrization exercises worker processes.
    dataloader = dgl.dataloading.NodeDataLoader(
        g, [0, 1], sampler, batch_size=1, device=F.ctx(), num_workers=num_workers)
    expected = {0: {1, 2}, 1: {5, 6}}
    for input_nodes, output_nodes, _ in dataloader:
        seed = output_nodes.item()
        # The seed is the first input node; the remaining input nodes are its sampled neighbors.
        neighbors = set(input_nodes[1:].cpu().numpy())
        if seed in expected:
            assert neighbors == expected[seed]

    # Heterogeneous graph: two relations into 'A' with complementary probabilities.
    g = dgl.heterograph({
        ('B', 'BA', 'A'): (src, dst),
        ('C', 'CA', 'A'): (src, dst),
    })
    g.edges['BA'].data['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edges['CA'].data['p'] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])
    sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob='p')
    dataloader = dgl.dataloading.NodeDataLoader(
        g, {'A': [0, 1]}, sampler, batch_size=1, device=F.ctx(), num_workers=num_workers)
    expected_b = {0: {1, 2}, 1: {5, 6}}
    expected_c = {0: {3, 4}, 1: {7, 8}}
    for input_nodes, output_nodes, _ in dataloader:
        seed = output_nodes['A'].item()
        # Seed and neighbors are of different node types, so no slicing is needed here.
        neighbors_b = set(input_nodes['B'].cpu().numpy())
        neighbors_c = set(input_nodes['C'].cpu().numpy())
        if seed in expected_b:
            assert neighbors_b == expected_b[seed]
        if seed in expected_c:
            assert neighbors_c == expected_c[seed]


def _check_device(data):
    """Assert that ``data`` is on the test device ``F.ctx()``.

    ``data`` may be a single object exposing ``.device`` (e.g. a tensor,
    graph or block), a dict of such objects (every value is checked), or a
    list/tuple of them (every element is checked).
    """
    if isinstance(data, dict):
        items = data.values()
    elif isinstance(data, (list, tuple)):
        items = data
    else:
        items = (data,)
    for item in items:
        assert item.device == F.ctx()


@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2', 'shadow'])
def test_node_dataloader(sampler_name):
    """NodeDataLoader places input nodes, output nodes and blocks on ``F.ctx()``.

    The check runs on a homogeneous and a heterogeneous graph whose node
    features live on the CPU. Each graph is loaded in a single batch.
    """
    # Homogeneous graph.
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
    g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
        'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
        'shadow': dgl.dataloading.ShaDowKHopSampler([3, 3])}[sampler_name]

    dataloader = dgl.dataloading.NodeDataLoader(
        g1, g1.nodes(), sampler, device=F.ctx(), batch_size=g1.num_nodes())
    for input_nodes, output_nodes, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(output_nodes)
        _check_device(blocks)

    # Heterogeneous graph. 'neighbor' uses per-etype fanouts; 'neighbor2' uses
    # plain integer fanouts.
    g2 = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    })
    for ntype in g2.ntypes:
        g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
    # Large enough that every node type fits in one batch.
    batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([{etype: 3 for etype in g2.etypes}] * 2),
        'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
        'shadow': dgl.dataloading.ShaDowKHopSampler([{etype: 3 for etype in g2.etypes}] * 2)}[sampler_name]

    dataloader = dgl.dataloading.NodeDataLoader(
        g2, {nty: g2.nodes(nty) for nty in g2.ntypes},
        sampler, device=F.ctx(), batch_size=batch_size)
    assert isinstance(iter(dataloader), Iterator)
    for input_nodes, output_nodes, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(output_nodes)
        _check_device(blocks)


@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'shadow'])
def test_edge_dataloader(sampler_name):
    """EdgeDataLoader places input nodes, pair graphs and blocks on ``F.ctx()``.

    Both the plain and the negative-sampling variants are checked, on a
    homogeneous and a heterogeneous graph with CPU node features. All edges
    are loaded in a single batch.
    """
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(2)

    # Homogeneous graph.
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
    g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
        'shadow': dgl.dataloading.ShaDowKHopSampler([3, 3])}[sampler_name]

    # no negative sampler
    dataloader = dgl.dataloading.EdgeDataLoader(
        g1, g1.edges(form='eid'), sampler, device=F.ctx(), batch_size=g1.num_edges())
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    dataloader = dgl.dataloading.EdgeDataLoader(
        g1, g1.edges(form='eid'), sampler, device=F.ctx(),
        negative_sampler=neg_sampler, batch_size=g1.num_edges())
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    # Heterogeneous graph.
    g2 = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    })
    for ntype in g2.ntypes:
        g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
    # Large enough that every edge type fits in one batch.
    batch_size = max(g2.num_edges(ety) for ety in g2.canonical_etypes)
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([{etype: 3 for etype in g2.etypes}] * 2),
        'shadow': dgl.dataloading.ShaDowKHopSampler([{etype: 3 for etype in g2.etypes}] * 2)}[sampler_name]

    # no negative sampler
    dataloader = dgl.dataloading.EdgeDataLoader(
        g2, {ety: g2.edges(form='eid', etype=ety) for ety in g2.canonical_etypes},
        sampler, device=F.ctx(), batch_size=batch_size)
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    dataloader = dgl.dataloading.EdgeDataLoader(
        g2, {ety: g2.edges(form='eid', etype=ety) for ety in g2.canonical_etypes},
        sampler, device=F.ctx(), negative_sampler=neg_sampler,
        batch_size=batch_size)
    assert isinstance(iter(dataloader), Iterator)
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)


if __name__ == '__main__':
    # Run every test once with a single configuration, mirroring the pytest
    # parametrizations (worker count 0, every sampler name each test accepts).
    test_neighbor_sampler_dataloader()
    test_graph_dataloader()
    test_cluster_gcn(0)
    test_shadow(0)
    test_neighbor_nonuniform(0)
    for sampler in ['full', 'neighbor', 'neighbor2', 'shadow']:
        test_node_dataloader(sampler)
    for sampler in ['full', 'neighbor', 'shadow']:
        test_edge_dataloader(sampler)