"docs/source/vscode:/vscode.git/clone" did not exist on "06e9ebebd51c3db779dedec5556251c8ecc3a00a"
test_mp_dataloader.py 10.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
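"""Multiprocess tests for DGL's distributed data loaders.

Each test partitions the "cora" citation graph, launches one
DistGraphServer process per partition, then runs a trainer process that
draws mini-batches through DistDataLoader or NodeDataLoader and verifies
every sampled edge against the original graph. The file can be run
directly (see ``__main__`` below) or through pytest.
"""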
import multiprocessing as mp
import os
import time
import unittest
from pathlib import Path

import numpy as np
import pytest

import backend as F
import dgl
from dgl.data import CitationGraphDataset
from dgl.distributed import (DistDataLoader, DistGraph, DistGraphServer,
                             load_partition, partition_graph)
from utils import get_local_usable_addr

class NeighborSampler(object):
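    """Multi-layer neighbor sampler producing a list of blocks.

    ``fanouts`` gives how many neighbors to sample per seed node at each
    layer; ``sample_neighbors`` is the sampling function to use (here
    ``dgl.distributed.sample_neighbors``). Blocks are returned outermost
    layer first.
    """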
    def __init__(self, g, fanouts, sample_neighbors):
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors

    def sample_blocks(self, seeds):
        import torch as th
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(
                self.g, seeds, fanout, replace=True)
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Obtain the seed nodes for next layer.
            seeds = block.srcdata[dgl.NID]

            blocks.insert(0, block)
        return blocks


def start_server(rank, tmpdir, disable_shared_mem, num_clients):
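    """Start a DistGraphServer that serves one partition of the test graph."""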
    import dgl
    print('server: #clients=' + str(num_clients))
    g = DistGraphServer(rank, "mp_ip_config.txt", 1, num_clients,
                        tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem)
    g.start()


def start_dist_dataloader(rank, tmpdir, num_server, num_workers, drop_last):
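    """Trainer process: iterate a DistDataLoader over a DistGraph, checking
    that every sampled edge exists in the original "cora" graph and that
    ``drop_last`` trims the final incomplete batch as expected.
    """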
    import dgl
    import torch as th
    dgl.distributed.initialize("mp_ip_config.txt", 1, num_workers=num_workers)
    gpb = None
    disable_shared_mem = num_server > 0
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    num_nodes_to_sample = 202
    batch_size = 32
    train_nid = th.arange(num_nodes_to_sample)
    dist_graph = DistGraph("test_mp", gpb=gpb, part_config=tmpdir / 'test_sampling.json')

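    # Map partition-local node/edge IDs back to the IDs of the original
    # graph; partitions store 'orig_id' when IDs were reshuffled.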
    orig_nid = F.arange(0, dist_graph.number_of_nodes())
    orig_eid = F.arange(0, dist_graph.number_of_edges())
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        if 'orig_id' in part.ndata:
            orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        if 'orig_id' in part.edata:
            orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    # Create sampler
    sampler = NeighborSampler(dist_graph, [5, 10],
                              dgl.distributed.sample_neighbors)

    # We need to test creating DistDataLoader multiple times.
    for i in range(2):
        # Create DataLoader for constructing blocks
        dataloader = DistDataLoader(
            dataset=train_nid.numpy(),
            batch_size=batch_size,
            collate_fn=sampler.sample_blocks,
            shuffle=False,
            drop_last=drop_last)

        groundtruth_g = CitationGraphDataset("cora")[0]
        max_nid = []

        for epoch in range(2):
            for idx, blocks in zip(range(0, num_nodes_to_sample, batch_size), dataloader):
                block = blocks[-1]
                o_src, o_dst = block.edges()
                src_nodes_id = block.srcdata[dgl.NID][o_src]
                dst_nodes_id = block.dstdata[dgl.NID][o_dst]
                max_nid.append(np.max(F.asnumpy(dst_nodes_id)))

                src_nodes_id = orig_nid[src_nodes_id]
                dst_nodes_id = orig_nid[dst_nodes_id]
                has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id)
                assert np.all(F.asnumpy(has_edges))
                # assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
            if drop_last:
                assert np.max(max_nid) == num_nodes_to_sample - 1 - num_nodes_to_sample % batch_size
            else:
                assert np.max(max_nid) == num_nodes_to_sample - 1
    del dataloader
    dgl.distributed.exit_client() # needed since there are two tests in one process

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
def test_standalone(tmpdir):
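    """Run the DistDataLoader trainer in standalone mode on a single partition."""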
    with open("mp_ip_config.txt", "w") as ip_config:
        for _ in range(1):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    num_parts = 1
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    os.environ['DGL_DIST_MODE'] = 'standalone'
    try:
        start_dist_dataloader(0, tmpdir, 1, 2, True)
    except Exception as e:
        # Exceptions are printed (not re-raised) so exit_client() below still runs.
        print(e)
    dgl.distributed.exit_client() # needed since there are two tests in one process


@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("drop_last", [True, False])
@pytest.mark.parametrize("reshuffle", [True, False])
def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
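    """End-to-end test: partition "cora", spawn one server per partition,
    and run the DistDataLoader trainer in a separate process.
    """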
    with open("mp_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=reshuffle)

    # Launch one server process per partition and give each time to start
    # before the trainer connects.
    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(
            i, tmpdir, num_server > 1, num_workers+1))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ptrainer = ctx.Process(target=start_dist_dataloader, args=(
        0, tmpdir, num_server, num_workers, drop_last))
    ptrainer.start()
    time.sleep(1)

    for p in pserver_list:
        p.join()
    ptrainer.join()

def start_node_dataloader(rank, tmpdir, num_server, num_workers):
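    """Trainer process: iterate dgl.dataloading.NodeDataLoader over a
    DistGraph and validate the sampled edges against the original graph.
    """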
    import dgl
    import torch as th
    dgl.distributed.initialize("mp_ip_config.txt", 1, num_workers=num_workers)
    gpb = None
    disable_shared_mem = num_server > 1
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    num_nodes_to_sample = 202
    batch_size = 32
    train_nid = th.arange(num_nodes_to_sample)
    dist_graph = DistGraph("test_mp", gpb=gpb, part_config=tmpdir / 'test_sampling.json')

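    # Map partition-local node/edge IDs back to the IDs of the original
    # graph ('orig_id' is present because the graph is partitioned with
    # reshuffle=True in this test).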
    orig_nid = F.zeros((dist_graph.number_of_nodes(),), dtype=F.int64)
    orig_eid = F.zeros((dist_graph.number_of_edges(),), dtype=F.int64)
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    # Create sampler
    sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])

    # We need to test creating DistDataLoader multiple times.
    for i in range(2):
        # Create DataLoader for constructing blocks
        dataloader = dgl.dataloading.NodeDataLoader(
            dist_graph,
            train_nid,
            sampler,
            batch_size=batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=num_workers)

        groundtruth_g = CitationGraphDataset("cora")[0]
        max_nid = []

        for epoch in range(2):
            for idx, (_, _, blocks) in zip(range(0, num_nodes_to_sample, batch_size), dataloader):
                block = blocks[-1]
                o_src, o_dst = block.edges()
                src_nodes_id = block.srcdata[dgl.NID][o_src]
                dst_nodes_id = block.dstdata[dgl.NID][o_dst]
                src_nodes_id = orig_nid[src_nodes_id]
                dst_nodes_id = orig_nid[dst_nodes_id]
                has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id)
                assert np.all(F.asnumpy(has_edges))
                max_nid.append(np.max(F.asnumpy(dst_nodes_id)))
                # assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
    del dataloader
    dgl.distributed.exit_client() # needed since there are two tests in one process


@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("dataloader_type", ["node"])
def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
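    """End-to-end test of NodeDataLoader on a partitioned, served graph."""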
    with open("mp_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(
            i, tmpdir, num_server > 1, num_workers+1))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    os.environ['DGL_DIST_MODE'] = 'distributed'
    ptrainer_list = []
    if dataloader_type == 'node':
        p = ctx.Process(target=start_node_dataloader, args=(
            0, tmpdir, num_server, num_workers))
        p.start()
        time.sleep(1)
        ptrainer_list.append(p)
    for p in pserver_list:
        p.join()
    for p in ptrainer_list:
        p.join()

if __name__ == "__main__":
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        test_standalone(Path(tmpdirname))
        test_dist_dataloader(Path(tmpdirname), 3, 0, True, True)
        test_dist_dataloader(Path(tmpdirname), 3, 4, True, True)
        test_dist_dataloader(Path(tmpdirname), 3, 0, True, False)
        test_dist_dataloader(Path(tmpdirname), 3, 4, True, False)
        test_dataloader(Path(tmpdirname), 3, 4, 'node')