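"""Multiprocessing tests for DGL's distributed samplers and dataloaders.

The tests partition the "cora" citation graph, start DistGraphServer
processes over the partitions, and check that the blocks produced by
DistDataLoader / dgl.dataloading.NodeDataLoader only contain edges that
exist in the original graph.
"""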
import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
from dgl.distributed import sample_neighbors
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
import numpy as np
import time
from utils import get_local_usable_addr
from pathlib import Path
from dgl.distributed import DistGraphServer, DistGraph, DistDataLoader
import pytest
import backend as F

class NeighborSampler(object):
    def __init__(self, g, fanouts, sample_neighbors):
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors

    def sample_blocks(self, seeds):
        import torch as th
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(
                self.g, seeds, fanout, replace=True)
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Obtain the seed nodes for next layer.
            seeds = block.srcdata[dgl.NID]

            blocks.insert(0, block)
        return blocks


def start_server(rank, tmpdir, disable_shared_mem, num_clients):
    import dgl
    print('server: #clients=' + str(num_clients))
    g = DistGraphServer(rank, "mp_ip_config.txt", 1, num_clients,
                        tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem,
                        graph_format=['csc', 'coo'])
    g.start()


def start_dist_dataloader(rank, tmpdir, num_server, drop_last, orig_nid, orig_eid):
    import dgl
    import torch as th
    dgl.distributed.initialize("mp_ip_config.txt")
    gpb = None
    disable_shared_mem = num_server > 0
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
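    # 202 is deliberately not a multiple of the batch size (202 = 6 * 32 + 10),
    # so the last batch is partial and the drop_last checks below are exercised.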
    num_nodes_to_sample = 202
    batch_size = 32
    train_nid = th.arange(num_nodes_to_sample)
    dist_graph = DistGraph("test_mp", gpb=gpb, part_config=tmpdir / 'test_sampling.json')

    # Load each partition once; the loaded parts are not used further.
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)

    # Create sampler
    sampler = NeighborSampler(dist_graph, [5, 10],
                              dgl.distributed.sample_neighbors)

    # We need to test creating DistDataLoader multiple times.
    for i in range(2):
        # Create DataLoader for constructing blocks
        dataloader = DistDataLoader(
            dataset=train_nid.numpy(),
            batch_size=batch_size,
            collate_fn=sampler.sample_blocks,
            shuffle=False,
            drop_last=drop_last)

        groundtruth_g = CitationGraphDataset("cora")[0]
        max_nid = []

        for epoch in range(2):
            for idx, blocks in zip(range(0, num_nodes_to_sample, batch_size), dataloader):
                block = blocks[-1]
                o_src, o_dst = block.edges()
                src_nodes_id = block.srcdata[dgl.NID][o_src]
                dst_nodes_id = block.dstdata[dgl.NID][o_dst]
                max_nid.append(np.max(F.asnumpy(dst_nodes_id)))

                src_nodes_id = orig_nid[src_nodes_id]
                dst_nodes_id = orig_nid[dst_nodes_id]
                has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id)
                assert np.all(F.asnumpy(has_edges))
                # assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
            if drop_last:
                assert np.max(max_nid) == num_nodes_to_sample - 1 - num_nodes_to_sample % batch_size
            else:
                assert np.max(max_nid) == num_nodes_to_sample - 1
    del dataloader
    dgl.distributed.exit_client() # this is needed since there are two tests here in one process

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
def test_standalone(tmpdir):
    ip_config = open("mp_ip_config.txt", "w")
    for _ in range(1):
        ip_config.write('{}\n'.format(get_local_usable_addr()))
    ip_config.close()

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    num_parts = 1
    num_hops = 1

    orig_nid, orig_eid = partition_graph(g, 'test_sampling', num_parts, tmpdir,
                                         num_hops=num_hops, part_method='metis', reshuffle=True,
                                         return_mapping=True)

    os.environ['DGL_DIST_MODE'] = 'standalone'
    try:
        start_dist_dataloader(0, tmpdir, 1, True, orig_nid, orig_eid)
    except Exception as e:
        print(e)
        raise
    finally:
        dgl.distributed.exit_client() # this is needed since there are two tests here in one process


@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("drop_last", [True, False])
@pytest.mark.parametrize("reshuffle", [True, False])
def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
    ip_config = open("mp_ip_config.txt", "w")
    for _ in range(num_server):
        ip_config.write('{}\n'.format(get_local_usable_addr()))
    ip_config.close()

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    num_parts = num_server
    num_hops = 1

    orig_nid, orig_eid = partition_graph(g, 'test_sampling', num_parts, tmpdir,
                                         num_hops=num_hops, part_method='metis',
                                         reshuffle=reshuffle, return_mapping=True)

    # Launch one DistGraphServer process per partition, then give the servers
    # a few seconds to start before launching the trainer.
    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(
            i, tmpdir, num_server > 1, num_workers+1))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
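    # Configure the trainer process: fully distributed mode with
    # `num_workers` sampler processes.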
    os.environ['DGL_DIST_MODE'] = 'distributed'
    os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
    ptrainer = ctx.Process(target=start_dist_dataloader, args=(
        0, tmpdir, num_server, drop_last, orig_nid, orig_eid))
    ptrainer.start()
    time.sleep(1)

    for p in pserver_list:
        p.join()
    ptrainer.join()

def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_eid):
    import dgl
    import torch as th
    dgl.distributed.initialize("mp_ip_config.txt")
    gpb = None
    disable_shared_mem = num_server > 1
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    num_nodes_to_sample = 202
    batch_size = 32
    train_nid = th.arange(num_nodes_to_sample)
    dist_graph = DistGraph("test_mp", gpb=gpb, part_config=tmpdir / 'test_sampling.json')

    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)

    # Create sampler
    sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])

    # We need to test creating DistDataLoader multiple times.
    for i in range(2):
        # Create DataLoader for constructing blocks
        dataloader = dgl.dataloading.NodeDataLoader(
            dist_graph,
            train_nid,
            sampler,
            batch_size=batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=num_workers)

        groundtruth_g = CitationGraphDataset("cora")[0]
        max_nid = []

        for epoch in range(2):
            for idx, (_, _, blocks) in zip(range(0, num_nodes_to_sample, batch_size), dataloader):
                block = blocks[-1]
                o_src, o_dst = block.edges()
                src_nodes_id = block.srcdata[dgl.NID][o_src]
                dst_nodes_id = block.dstdata[dgl.NID][o_dst]
                src_nodes_id = orig_nid[src_nodes_id]
                dst_nodes_id = orig_nid[dst_nodes_id]
                has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id)
                assert np.all(F.asnumpy(has_edges))
                max_nid.append(np.max(F.asnumpy(dst_nodes_id)))
                # assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
    del dataloader
    dgl.distributed.exit_client() # this is needed since there are two tests here in one process


@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("dataloader_type", ["node"])
def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
    ip_config = open("mp_ip_config.txt", "w")
    for _ in range(num_server):
        ip_config.write('{}\n'.format(get_local_usable_addr()))
    ip_config.close()

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    num_parts = num_server
    num_hops = 1

    orig_nid, orig_eid = partition_graph(g, 'test_sampling', num_parts, tmpdir,
                                         num_hops=num_hops, part_method='metis',
                                         reshuffle=True, return_mapping=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(
            i, tmpdir, num_server > 1, num_workers+1))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    os.environ['DGL_DIST_MODE'] = 'distributed'
    os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
    ptrainer_list = []
    if dataloader_type == 'node':
        p = ctx.Process(target=start_node_dataloader, args=(
            0, tmpdir, num_server, num_workers, orig_nid, orig_eid))
        p.start()
        time.sleep(1)
        ptrainer_list.append(p)
    for p in pserver_list:
        p.join()
    for p in ptrainer_list:
        p.join()

if __name__ == "__main__":
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        test_standalone(Path(tmpdirname))
        test_dist_dataloader(Path(tmpdirname), 3, 0, True, True)
        test_dist_dataloader(Path(tmpdirname), 3, 4, True, True)
        test_dist_dataloader(Path(tmpdirname), 3, 0, True, False)
        test_dist_dataloader(Path(tmpdirname), 3, 4, True, False)
        test_dataloader(Path(tmpdirname), 3, 4, 'node')