# test_dataloader.py
import dgl
import backend as F
import numpy as np
import unittest
from torch.utils.data import DataLoader
from collections import defaultdict
7
from itertools import product
8

9
def _check_neighbor_sampling_dataloader(g, nids, dl, mode, collator):
10
11
    seeds = defaultdict(list)

12
13
    for item in dl:
        if mode == 'node':
14
            input_nodes, output_nodes, items, blocks = item
15
        elif mode == 'edge':
16
            input_nodes, pair_graph, items, blocks = item
17
18
            output_nodes = pair_graph.ndata[dgl.NID]
        elif mode == 'link':
19
            input_nodes, pair_graph, neg_graph, items, blocks = item
20
21
22
23
            output_nodes = pair_graph.ndata[dgl.NID]
            for ntype in pair_graph.ntypes:
                assert F.array_equal(pair_graph.nodes[ntype].data[dgl.NID], neg_graph.nodes[ntype].data[dgl.NID])

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
        # TODO: check if items match output nodes/edges
        if mode == 'node':
            if len(g.ntypes) > 1:
                for ntype in g.ntypes:
                    if ntype not in items:
                        assert len(output_nodes[ntype]) == 0
                    else:
                        assert F.array_equal(output_nodes[ntype], F.gather_row(collator.nids[ntype], items[ntype]))
            else:
                assert F.array_equal(output_nodes, F.gather_row(collator.nids, items))
        else:
            if len(g.etypes) > 1:
                for etype, eids in collator.eids.items():
                    if etype not in items:
                        assert pair_graph.num_edges(etype=etype) == 0
                    else:
                        assert F.array_equal(pair_graph.edges[etype].data[dgl.EID], F.gather_row(eids, items[etype]))
            else:
                assert F.array_equal(pair_graph.edata[dgl.EID], F.gather_row(collator.eids, items))

44
45
46
47
48
49
50
        if len(g.ntypes) > 1:
            for ntype in g.ntypes:
                assert F.array_equal(input_nodes[ntype], blocks[0].srcnodes[ntype].data[dgl.NID])
                assert F.array_equal(output_nodes[ntype], blocks[-1].dstnodes[ntype].data[dgl.NID])
        else:
            assert F.array_equal(input_nodes, blocks[0].srcdata[dgl.NID])
            assert F.array_equal(output_nodes, blocks[-1].dstdata[dgl.NID])
51

52
53
54
55
56
57
58
        prev_dst = {ntype: None for ntype in g.ntypes}
        for block in blocks:
            for canonical_etype in block.canonical_etypes:
                utype, etype, vtype = canonical_etype
                uu, vv = block.all_edges(order='eid', etype=canonical_etype)
                src = block.srcnodes[utype].data[dgl.NID]
                dst = block.dstnodes[vtype].data[dgl.NID]
59
60
61
62
                assert F.array_equal(
                    block.srcnodes[utype].data['feat'], g.nodes[utype].data['feat'][src])
                assert F.array_equal(
                    block.dstnodes[vtype].data['feat'], g.nodes[vtype].data['feat'][dst])
63
64
65
66
67
68
                if prev_dst[utype] is not None:
                    assert F.array_equal(src, prev_dst[utype])
                u = src[uu]
                v = dst[vv]
                assert F.asnumpy(g.has_edges_between(u, v, etype=canonical_etype)).all()
                eid = block.edges[canonical_etype].data[dgl.EID]
69
70
71
                assert F.array_equal(
                    block.edges[canonical_etype].data['feat'],
                    g.edges[canonical_etype].data['feat'][eid])
72
73
74
75
76
77
78
79
80
                ufound, vfound = g.find_edges(eid, etype=canonical_etype)
                assert F.array_equal(ufound, u)
                assert F.array_equal(vfound, v)
            for ntype in block.dsttypes:
                src = block.srcnodes[ntype].data[dgl.NID]
                dst = block.dstnodes[ntype].data[dgl.NID]
                assert F.array_equal(src[:block.number_of_dst_nodes(ntype)], dst)
                prev_dst[ntype] = dst

81
82
83
84
85
86
87
88
        if mode == 'node':
            for ntype in blocks[-1].dsttypes:
                seeds[ntype].append(blocks[-1].dstnodes[ntype].data[dgl.NID])
        elif mode == 'edge' or mode == 'link':
            for etype in pair_graph.canonical_etypes:
                seeds[etype].append(pair_graph.edges[etype].data[dgl.EID])

    # Check if all nodes/edges are iterated
89
90
    seeds = {k: F.cat(v, 0) for k, v in seeds.items()}
    for k, v in seeds.items():
91
92
93
94
95
96
97
        if k in nids:
            seed_set = set(F.asnumpy(nids[k]))
        elif isinstance(k, tuple) and k[1] in nids:
            seed_set = set(F.asnumpy(nids[k[1]]))
        else:
            continue

98
99
100
101
102
        v_set = set(F.asnumpy(v))
        assert v_set == seed_set

@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_neighbor_sampler_dataloader():
    """Run node, edge and link-prediction collators through a DataLoader.

    Builds one single-type graph and one multi-type graph, pairs each with a
    fan-out-limited and a full neighbor sampler, and checks every resulting
    dataloader with ``_check_neighbor_sampling_dataloader``.
    """
    g = dgl.heterograph({('user', 'follow', 'user'): ([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])},
                        {'user': 6}).long()
    g = dgl.to_bidirected(g)
    g.ndata['feat'] = F.randn((6, 8))
    g.edata['feat'] = F.randn((10, 4))
    # Pairs edge i with edge i + 5 (and back) as each other's reverse.
    reverse_eids = F.tensor([5, 6, 7, 8, 9, 0, 1, 2, 3, 4], dtype=F.int64)
    g_sampler1 = dgl.dataloading.MultiLayerNeighborSampler([2, 2], return_eids=True)
    g_sampler2 = dgl.dataloading.MultiLayerFullNeighborSampler(2, return_eids=True)

    hg = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    }).long()
    for ntype in hg.ntypes:
        hg.nodes[ntype].data['feat'] = F.randn((hg.number_of_nodes(ntype), 8))
    for etype in hg.canonical_etypes:
        hg.edges[etype].data['feat'] = F.randn((hg.number_of_edges(etype), 4))
    hg_sampler1 = dgl.dataloading.MultiLayerNeighborSampler(
        [{'play': 1, 'played-by': 1, 'follow': 2, 'followed-by': 1}] * 2, return_eids=True)
    hg_sampler2 = dgl.dataloading.MultiLayerFullNeighborSampler(2, return_eids=True)
    reverse_etypes = {'follow': 'followed-by', 'followed-by': 'follow', 'play': 'played-by', 'played-by': 'play'}

    # Each case is (graph, expected seed IDs by type, collator, mode).
    cases = []

    def add_node_case(graph, seeds, sampler, nid):
        # One node-classification collator over the given seed nodes.
        collator = dgl.dataloading.NodeCollator(graph, seeds, sampler, return_indices=True)
        cases.append((graph, nid, collator, 'node'))

    def add_edge_cases(graph, seeds, sampler, nid, exclude_kwargs):
        # Four edge collators per (seeds, sampler): with and without reverse
        # edge exclusion, each with and without a negative sampler.
        variants = (('edge', False), ('edge', True), ('link', False), ('link', True))
        for mode, with_exclude in variants:
            kwargs = dict(exclude_kwargs) if with_exclude else {}
            if mode == 'link':
                kwargs['negative_sampler'] = dgl.dataloading.negative_sampler.Uniform(2)
            collator = dgl.dataloading.EdgeCollator(
                graph, seeds, sampler, return_indices=True, **kwargs)
            cases.append((graph, nid, collator, mode))

    # Single-type graph: the same ID tensors serve as node seeds and edge seeds.
    for seeds, sampler in product(
            [F.tensor([0, 1, 2, 3, 5], dtype=F.int64), F.tensor([4, 5], dtype=F.int64)],
            [g_sampler1, g_sampler2]):
        add_node_case(g, seeds, sampler, {'user': seeds})
        add_edge_cases(
            g, seeds, sampler, {'follow': seeds},
            {'exclude': 'reverse_id', 'reverse_eids': reverse_eids})

    # Multi-type graph, node seeds keyed by node type.
    for seeds, sampler in product(
            [{'user': F.tensor([0, 1, 3, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)},
             {'user': F.tensor([4, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)}],
            [hg_sampler1, hg_sampler2]):
        add_node_case(hg, seeds, sampler, seeds)

    # Multi-type graph, edge seeds keyed by edge type.
    for seeds, sampler in product(
            [{'follow': F.tensor([0, 1, 3, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)},
             {'follow': F.tensor([4, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)}],
            [hg_sampler1, hg_sampler2]):
        add_edge_cases(
            hg, seeds, sampler, seeds,
            {'exclude': 'reverse_types', 'reverse_etypes': reverse_etypes})

    for _g, nid, collator, mode in cases:
        dl = DataLoader(
            collator.dataset, collate_fn=collator.collate, batch_size=2, shuffle=True, drop_last=False)
        _check_neighbor_sampling_dataloader(_g, nid, dl, mode, collator)
210

211
212
213
214
215
216
217
218
def test_graph_dataloader():
    """Check that GraphDataLoader yields full batches of graphs with labels.

    The dataset holds ``batch_size * num_batches`` graphs, so the loader is
    expected to yield exactly ``num_batches`` batches of ``batch_size``
    labels each.
    """
    batch_size = 16
    num_batches = 2
    minigc_dataset = dgl.data.MiniGCDataset(batch_size * num_batches, 10, 20)
    data_loader = dgl.dataloading.GraphDataLoader(minigc_dataset, batch_size=batch_size, shuffle=True)
    batches_seen = 0
    for graph, label in data_loader:
        assert isinstance(graph, dgl.DGLGraph)
        assert F.asnumpy(label).shape[0] == batch_size
        batches_seen += 1
    # Without this the test would pass vacuously if the loader yielded nothing.
    assert batches_seen == num_batches
219
220
221

# Allow running the tests directly as a script, without a test runner.
if __name__ == '__main__':
    test_neighbor_sampler_dataloader()
    test_graph_dataloader()