# test_dataloader.py: tests for DGL's graph, node and edge data loaders.
import os
import unittest
from collections import defaultdict
from collections.abc import Iterator
from itertools import product

import dgl
import dgl.ops as OPS
import pytest
import torch
from torch.utils.data import DataLoader

import backend as F


def test_graph_dataloader():
    """GraphDataLoader batches a graph-classification dataset.

    The dataset holds exactly ``batch_size * num_batches`` graphs. The test
    checks that the loader is iterable and yields ``num_batches`` batches.
    It also checks that every batch is a ``DGLGraph`` paired with
    ``batch_size`` labels.
    """
    batch_size = 16
    num_batches = 2
    minigc_dataset = dgl.data.MiniGCDataset(batch_size * num_batches, 10, 20)
    data_loader = dgl.dataloading.GraphDataLoader(
        minigc_dataset, batch_size=batch_size, shuffle=True)
    assert isinstance(iter(data_loader), Iterator)

    seen_batches = 0
    for graph, label in data_loader:
        assert isinstance(graph, dgl.DGLGraph)
        assert F.asnumpy(label).shape[0] == batch_size
        seen_batches += 1
    # The dataset size is an exact multiple of batch_size, so there is no
    # trailing partial batch.
    assert seen_batches == num_batches


@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize('num_workers', [0, 4])
def test_cluster_gcn(num_workers):
    """ClusterGCN sampling over CoraFull yields the advertised number of batches.

    The loader iterates over IDs 0..99, 4 per batch, so ``len()`` must be 25.
    Iteration must actually produce that many subgraphs, with and without
    worker processes.
    """
    dataset = dgl.data.CoraFullDataset()
    g = dataset[0]
    # The second argument is presumably the number of partitions; the seed
    # IDs below are assumed to be partition IDs 0..99. TODO confirm.
    sampler = dgl.dataloading.ClusterGCNSampler(g, 100)
    dataloader = dgl.dataloading.DataLoader(
        g, torch.arange(100), sampler, batch_size=4, num_workers=num_workers)
    assert len(dataloader) == 25

    num_batches = 0
    for sg in dataloader:
        assert isinstance(sg, dgl.DGLGraph)
        num_batches += 1
    assert num_batches == len(dataloader)


@pytest.mark.parametrize('num_workers', [0, 4])
def test_shadow(num_workers):
    """ShaDowKHopSampler yields one induced subgraph per seed batch.

    For each checked batch, the test verifies four properties:
    - the subgraph's original node IDs equal ``input_nodes``;
    - the seeds occupy the leading positions of ``input_nodes``;
    - the ``'label'`` node data was copied from the parent graph;
    - the ``'feat'`` node data was copied from the parent graph.
    """
    g = dgl.data.CoraFullDataset()[0]
    sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
    dataloader = dgl.dataloading.NodeDataLoader(
        g, torch.arange(g.num_nodes()), sampler,
        batch_size=5, shuffle=True, drop_last=False, num_workers=num_workers)
    for i, (input_nodes, output_nodes, subgraph) in enumerate(dataloader):
        assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
        # Seed nodes must come first in the subgraph's node ordering.
        assert torch.equal(input_nodes[:output_nodes.shape[0]], output_nodes)
        assert torch.equal(subgraph.ndata['label'], g.ndata['label'][input_nodes])
        assert torch.equal(subgraph.ndata['feat'], g.ndata['feat'][input_nodes])
        # Only the first six batches are checked, to keep the test short.
        if i == 5:
            break


@pytest.mark.parametrize('num_workers', [0, 4])
def test_neighbor_nonuniform(num_workers):
    """Probability-weighted neighbor sampling never picks zero-probability edges.

    Each seed has four in-edges: two with ``p == 1`` and two with ``p == 0``.
    With a fanout of 2, the sampled neighbors must be exactly the two weighted
    ones. The test covers a homogeneous graph and a heterograph with two edge
    types. ``num_workers`` is forwarded to both loaders so the two
    parametrizations actually differ.
    """
    g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]))
    g.edata['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob='p')
    dataloader = dgl.dataloading.NodeDataLoader(
        g, [0, 1], sampler, batch_size=1, device=F.ctx(),
        num_workers=num_workers)
    seen_seeds = set()
    for input_nodes, output_nodes, blocks in dataloader:
        seed = output_nodes.item()
        seen_seeds.add(seed)
        # Seed and neighbors share a node type; drop the leading seed entry.
        neighbors = set(input_nodes[1:].cpu().numpy())
        if seed == 1:
            assert neighbors == {5, 6}
        elif seed == 0:
            assert neighbors == {1, 2}
    assert seen_seeds == {0, 1}

    g = dgl.heterograph({
        ('B', 'BA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
        ('C', 'CA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
    })
    g.edges['BA'].data['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edges['CA'].data['p'] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])
    sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob='p')
    dataloader = dgl.dataloading.NodeDataLoader(
        g, {'A': [0, 1]}, sampler, batch_size=1, device=F.ctx(),
        num_workers=num_workers)
    seen_seeds = set()
    for input_nodes, output_nodes, blocks in dataloader:
        seed = output_nodes['A'].item()
        seen_seeds.add(seed)
        # Seed and neighbors are of different node types so slicing is not necessary here.
        b_neighbors = set(input_nodes['B'].cpu().numpy())
        c_neighbors = set(input_nodes['C'].cpu().numpy())
        if seed == 1:
            assert b_neighbors == {5, 6}
            assert c_neighbors == {7, 8}
        elif seed == 0:
            assert b_neighbors == {1, 2}
            assert c_neighbors == {3, 4}
    assert seen_seeds == {0, 1}


def _check_device(data):
    """Assert that ``data`` lives on the test context device ``F.ctx()``.

    ``data`` may take one of three forms:
    - a single object exposing ``.device`` (tensor, graph or block);
    - a dict of such objects, e.g. node IDs keyed by node type;
    - a list or tuple of them, e.g. the per-layer blocks.
    """
    if isinstance(data, dict):
        items = data.values()
    elif isinstance(data, (list, tuple)):
        items = data
    else:
        items = (data,)
    for item in items:
        assert item.device == F.ctx()


@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2'])
def test_node_dataloader(sampler_name):
    """NodeDataLoader places input nodes, seeds and blocks on ``F.ctx()``.

    The homogeneous graph is loaded with 0, 1 and 2 worker processes. The
    heterograph is loaded once, without passing ``num_workers``. On the
    heterograph, ``'neighbor'`` uses per-edge-type fanouts and
    ``'neighbor2'`` uses plain integer fanouts.
    """
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
    g1.ndata['feat'] = F.copy_to(F.randn((g1.num_nodes(), 8)), F.cpu())
    g1.ndata['label'] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())

    for num_workers in [0, 1, 2]:
        sampler = {
            'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
            'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
            'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]
        # A single batch covering every node.
        dataloader = dgl.dataloading.NodeDataLoader(
            g1, g1.nodes(), sampler, device=F.ctx(),
            batch_size=g1.num_nodes(),
            num_workers=num_workers)
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)

    g2 = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    })
    for ntype in g2.ntypes:
        g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
    batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([{etype: 3 for etype in g2.etypes}] * 2),
        'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]

    dataloader = dgl.dataloading.NodeDataLoader(
        g2, {nty: g2.nodes(nty) for nty in g2.ntypes},
        sampler, device=F.ctx(), batch_size=batch_size)
    assert isinstance(iter(dataloader), Iterator)
    for input_nodes, output_nodes, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(output_nodes)
        _check_device(blocks)


@pytest.mark.parametrize('sampler_name', ['full', 'neighbor'])
@pytest.mark.parametrize('neg_sampler', [
    dgl.dataloading.negative_sampler.Uniform(2),
    dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3),
    dgl.dataloading.negative_sampler.GlobalUniform(15, True, 3)])
def test_edge_dataloader(sampler_name, neg_sampler):
    """EdgeDataLoader places every output on ``F.ctx()``.

    The test covers a homogeneous graph and a heterograph, each loaded both
    without and with ``neg_sampler``. The homogeneous loaders take all edges
    in one batch. The heterograph loaders use the size of the largest edge
    type as the batch size.
    """
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
    g1.ndata['feat'] = F.copy_to(F.randn((g1.num_nodes(), 8)), F.cpu())
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]

    # no negative sampler
    dataloader = dgl.dataloading.EdgeDataLoader(
        g1, g1.edges(form='eid'), sampler, device=F.ctx(), batch_size=g1.num_edges())
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    dataloader = dgl.dataloading.EdgeDataLoader(
        g1, g1.edges(form='eid'), sampler, device=F.ctx(),
        negative_sampler=neg_sampler, batch_size=g1.num_edges())
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    g2 = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    })
    for ntype in g2.ntypes:
        g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
    batch_size = max(g2.num_edges(ety) for ety in g2.canonical_etypes)
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([{etype: 3 for etype in g2.etypes}] * 2),
        }[sampler_name]

    # no negative sampler
    dataloader = dgl.dataloading.EdgeDataLoader(
        g2, {ety: g2.edges(form='eid', etype=ety) for ety in g2.canonical_etypes},
        sampler, device=F.ctx(), batch_size=batch_size)
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    dataloader = dgl.dataloading.EdgeDataLoader(
        g2, {ety: g2.edges(form='eid', etype=ety) for ety in g2.canonical_etypes},
        sampler, device=F.ctx(), negative_sampler=neg_sampler,
        batch_size=batch_size)
    assert isinstance(iter(dataloader), Iterator)
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)


if __name__ == '__main__':
    # Run one configuration of every test without pytest.
    test_graph_dataloader()
    test_cluster_gcn(0)
    test_shadow(0)
    test_neighbor_nonuniform(0)
    for sampler in ['full', 'neighbor', 'neighbor2']:
        test_node_dataloader(sampler)
    # test_edge_dataloader only defines the 'full' and 'neighbor' samplers.
    for sampler in ['full', 'neighbor']:
        for neg_sampler in [
                dgl.dataloading.negative_sampler.Uniform(2),
                dgl.dataloading.negative_sampler.GlobalUniform(2, False),
                dgl.dataloading.negative_sampler.GlobalUniform(2, True)]:
            test_edge_dataloader(sampler, neg_sampler)