# run_dist_objects.py — end-to-end tests for DGL distributed objects,
# driven by DIST_DGL_TEST_* environment variables (one process per role).
1
import json
2
import os
3
from itertools import product
4
5

import dgl
6
import dgl.backend as F
7
8
9

import numpy as np
from dgl.distributed import edge_split, load_partition_book, node_split
10

11
12
13
14
15
16
17
18
19
20
21
# Runtime configuration, injected by the test launcher through environment
# variables (one process per machine/role).
mode = os.environ.get("DIST_DGL_TEST_MODE", "")
graph_name = os.environ.get("DIST_DGL_TEST_GRAPH_NAME", "random_test_graph")
# The next four are required: int(None) raises if the launcher forgot them.
num_part = int(os.environ.get("DIST_DGL_TEST_NUM_PART"))
num_servers_per_machine = int(os.environ.get("DIST_DGL_TEST_NUM_SERVER"))
num_client_per_machine = int(os.environ.get("DIST_DGL_TEST_NUM_CLIENT"))
shared_workspace = os.environ.get("DIST_DGL_TEST_WORKSPACE")
graph_path = os.environ.get("DIST_DGL_TEST_GRAPH_PATH")
part_id = int(os.environ.get("DIST_DGL_TEST_PART_ID"))
ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt")

# Tell DGL to run in fully distributed (multi-machine) mode.
os.environ["DGL_DIST_MODE"] = "distributed"
22

23

24
def batched_assert_zero(tensor, size):
    """Assert that every entry of ``tensor[0:size]`` is zero.

    Reads the (possibly distributed) tensor in fixed-size chunks so we never
    pull the whole tensor into memory at once.
    """
    chunk = 2**16
    for start in range(0, size, chunk):
        stop = min(start + chunk, size)
        assert F.sum(tensor[F.arange(start, stop)], 0) == 0
31

32

33
34
35
def zeros_init(shape, dtype):
    """Deterministic DistTensor initializer: an all-zero CPU tensor."""
    zero = F.zeros(shape, dtype=dtype, ctx=F.cpu())
    return zero

36

37
38
def rand_init(shape, dtype):
    """Random 0/1 initializer: each entry is 1 with probability 0.69."""
    mask = np.random.randint(0, 100, size=shape) > 30
    return F.tensor(mask, dtype=dtype)
39

40

41
42
43
44
45
46
47
48
def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
    keep_alive=False,
):
    """Launch a DistGraphServer for this machine's partition and serve.

    ``server_count`` is the number of servers per machine; ``num_clients``
    is the total client count across all machines.
    """
    part_config = graph_path + "/{}.json".format(graph_name)
    server = dgl.distributed.DistGraphServer(
        server_id,
        ip_config,
        server_count,
        num_clients,
        part_config,
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
        keep_alive=keep_alive,
    )
    print("start server", server_id)
    server.start()

63

64
65
66
67
##########################################
############### DistGraph ###############
##########################################

68
69

def node_split_test(g, force_even, ntype="_N"):
    """Check node_split against a random node mask.

    Every trainer zeroes the mask entries it was assigned; if the split
    collectively covers every selected node, the whole mask ends up zero,
    which rank 0 verifies.
    """
    book = g.get_partition_book()
    n_nodes = g.num_nodes(ntype)
    mask = dgl.distributed.DistTensor([n_nodes], F.uint8, init_func=rand_init)

    local_nodes = node_split(mask, book, ntype=ntype, force_even=force_even)
    g.barrier()

    # Clear the entries owned by this trainer.
    mask[local_nodes] = F.astype(F.zeros_like(local_nodes), mask.dtype)
    g.barrier()

    if g.rank() == 0:
        batched_assert_zero(mask, n_nodes)

    g.barrier()

91
92

def edge_split_test(g, force_even, etype="_E"):
    """Check edge_split against a random edge mask.

    Mirrors node_split_test: each trainer zeroes its assigned entries and
    rank 0 verifies the whole mask was cleared.
    """
    book = g.get_partition_book()
    n_edges = g.num_edges(etype)
    mask = dgl.distributed.DistTensor([n_edges], F.uint8, init_func=rand_init)

    local_edges = edge_split(mask, book, etype=etype, force_even=force_even)
    g.barrier()

    # Clear the entries owned by this trainer.
    mask[local_edges] = F.astype(F.zeros_like(local_edges), mask.dtype)
    g.barrier()

    if g.rank() == 0:
        batched_assert_zero(mask, n_edges)

    g.barrier()

114

115
def test_dist_graph(g):
    """Validate DistGraph counts against the partition metadata and run
    node/edge split tests for every node and edge type."""
    part_config = graph_path + "/{}.json".format(graph_name)
    with open(part_config) as conf_f:
        part_metadata = json.load(conf_f)

    assert "num_nodes" in part_metadata
    assert "num_edges" in part_metadata
    assert g.num_nodes() == part_metadata["num_nodes"]
    assert g.num_edges() == part_metadata["num_edges"]

    for ntype in g.ntypes:
        expected_nodes = g.num_nodes(ntype)
        assert g.num_nodes(ntype) == expected_nodes
        node_split_test(g, force_even=False, ntype=ntype)
        node_split_test(g, force_even=True, ntype=ntype)

    for etype in g.etypes:
        expected_edges = g.num_edges(etype)
        assert g.num_edges(etype) == expected_edges
        edge_split_test(g, force_even=False, etype=etype)
        edge_split_test(g, force_even=True, etype=etype)

140
141
142
143
144

##########################################
########### DistGraphServices ###########
##########################################

145

146
147
148
149
150
def find_edges_test(g, orig_nid_map):
    """Look up 100 random edge IDs per edge type via find_edges and check
    the endpoints against the stored original endpoint arrays.

    Returns ``{etype: (eids, concat(u, v))}`` for reuse by
    edge_subgraph_test.
    """
    result = dict()
    for src_type, etype, dst_type in g.canonical_etypes:
        expected_u = g.edges[etype].data["edge_u"]
        expected_v = g.edges[etype].data["edge_v"]
        eids = F.tensor(np.random.randint(g.num_edges(etype), size=100))

        u, v = g.find_edges(eids, etype=etype)
        assert F.allclose(orig_nid_map[src_type][u], expected_u[eids])
        assert F.allclose(orig_nid_map[dst_type][v], expected_v[eids])
        result[etype] = (eids, F.cat([u, v], dim=0))
    return result

160

161
162
163
164
165
166
167
168
def edge_subgraph_test(g, etype_eids_uv_map):
    """Extract an edge_subgraph over previously sampled edge IDs and verify
    its edge counts, induced EIDs, and endpoint membership."""
    canonical = g.canonical_etypes
    all_eids = {ctype: etype_eids_uv_map[ctype[1]][0] for ctype in canonical}

    sg = g.edge_subgraph(all_eids)
    for ctype in canonical:
        assert sg.num_edges(ctype[1]) == len(all_eids[ctype])
        assert F.allclose(sg.edges[ctype].data[dgl.EID], all_eids[ctype])

    for src_type, etype, dst_type in canonical:
        endpoints = etype_eids_uv_map[etype][1]
        induced = F.cat(
            [
                sg.nodes[src_type].data[dgl.NID],
                sg.nodes[dst_type].data[dgl.NID],
            ],
            dim=0,
        )
        # Every endpoint of the sampled edges must appear in the subgraph.
        for node_id in endpoints:
            assert node_id in induced

180

181
def sample_neighbors_with_args(g, size, fanout):
    """Sample ``fanout`` neighbors for ``size`` random seeds of every node
    type and cross-check the sampled edges against g.find_edges."""
    counts = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
    seeds = {
        ntype: np.random.randint(0, n, size=size)
        for ntype, n in counts.items()
    }
    frontier = g.sample_neighbors(seeds, fanout)

    # The sampled frontier keeps the full node set of the distributed graph.
    for ntype, n in counts.items():
        assert frontier.num_nodes(ntype) == n

    for ctype in g.canonical_etypes:
        src, dst = frontier.edges(etype=ctype)
        eids = frontier.edges[ctype].data[dgl.EID]
        dist_u, dist_v = g.find_edges(eids, etype=ctype[1])
        assert F.allclose(dist_u, src)
        assert F.allclose(dist_v, dst)

202

203
204
205
206
207
208
209
210
def sample_neighbors_test(g):
    """Run sample_neighbors over a spread of seed counts and fanouts
    (fanout -1 means take all neighbors)."""
    for size, fanout in [
        (1024, 3),
        (1, 10),
        (1024, 2),
        (10, -1),
        (2**10, 1),
        (2**12, 1),
    ]:
        sample_neighbors_with_args(g, size=size, fanout=fanout)

211

212
def test_dist_graph_services(g):
    """Entry point for DistGraph service tests: degree queries (homogeneous
    graphs only), find_edges, edge_subgraph and neighbor sampling."""
    # in_degrees and out_degrees do not support heterographs.
    if len(g.etypes) == 1:
        nids = F.arange(0, 128)

        # Compare served degrees against the precomputed node data.
        expected_in = g.ndata["in_degrees"]
        F.allclose(g.in_degrees(nids), expected_in[nids])

        expected_out = g.ndata["out_degrees"]
        F.allclose(g.out_degrees(nids), expected_out[nids])

    # Map partitioned node IDs back to the IDs of the original graph,
    # saved alongside the partitions at partitioning time.
    id_dtype = g.edges[g.etypes[0]].data["edge_u"].dtype
    orig_nid_map = {
        ntype: F.tensor(
            np.load(graph_path + f"/orig_nid_array_{ntype}.npy"), id_dtype
        )
        for ntype in g.ntypes
    }

    etype_eids_uv_map = find_edges_test(g, orig_nid_map)
    edge_subgraph_test(g, etype_eids_uv_map)
    sample_neighbors_test(g)

241

242
243
244
245
##########################################
############### DistTensor ###############
##########################################

246

247
248
def dist_tensor_test_sanity(data_shape, name=None):
    """Write/read round-trip on a DistTensor.

    Clients on even-numbered partitions write a distinctive slice.  Since
    ``pos`` uses ``part_id // 2``, partitions 2k and 2k+1 compute the same
    slice, so clients on the odd partition read what their even peer wrote.
    """
    local_rank = dgl.distributed.get_rank() % num_client_per_machine
    tensor = dgl.distributed.DistTensor(
        data_shape, F.int32, init_func=zeros_init, name=name
    )

    # arbitrary slice width
    stride = 3

    pos = (part_id // 2) * num_client_per_machine + local_rank
    lo, hi = pos * stride, (pos + 1) * stride
    expected = F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos + 1)
    if part_id % 2 == 0:
        tensor[lo:hi] = expected

    dgl.distributed.client_barrier()
    assert F.allclose(tensor[lo:hi], expected)
265

266
267

def dist_tensor_test_destroy_recreate(data_shape, name):
    """Delete a named DistTensor, then recreate the same name with a wider
    shape; succeeding proves the old tensor was actually released."""
    tensor = dgl.distributed.DistTensor(
        data_shape, F.float32, name, init_func=zeros_init
    )
    del tensor

    # All trainers must drop their reference before the name can be reused.
    dgl.distributed.client_barrier()

    wider_shape = (data_shape[0], 4)
    tensor = dgl.distributed.DistTensor(
        wider_shape, F.float32, name, init_func=zeros_init
    )

280
281

def dist_tensor_test_persistent(data_shape):
    """Deleting the local handle of a persistent DistTensor must not allow
    re-attaching to the name without an init_func: the recreation below is
    expected to raise.

    BUG FIX: the original put ``raise Exception("")`` inside the same
    ``try`` whose ``except BaseException: pass`` swallowed it, so the test
    could never fail.  Track success with a flag and assert afterwards.
    """
    dist_ten_name = "persistent_dist_tensor"
    dist_ten = dgl.distributed.DistTensor(
        data_shape,
        F.float32,
        dist_ten_name,
        init_func=zeros_init,
        persistent=True,
    )
    del dist_ten

    raised = False
    try:
        dist_ten = dgl.distributed.DistTensor(
            data_shape, F.float32, dist_ten_name
        )
    except BaseException:
        raised = True
    assert raised, "recreating a deleted persistent DistTensor should raise"


300
def test_dist_tensor(g):
    """Entry point for the DistTensor test group."""
    shape = (g.num_nodes(g.ntypes[0]), 2)
    dist_tensor_test_sanity(shape)
    dist_tensor_test_sanity(shape, name="DistTensorSanity")
    dist_tensor_test_destroy_recreate(shape, name="DistTensorRecreate")
    dist_tensor_test_persistent(shape)


309
310
311
312
##########################################
############# DistEmbedding ##############
##########################################

313

314
def dist_embedding_check_sanity(num_nodes, optimizer, name=None):
    """One sparse-optimizer step on a DistEmbedding, checked for locality.

    Clients on even-numbered partitions take a gradient step on their own
    slice of rows; afterwards every client checks its slice, and rows past
    the last updated slice must still hold their zero initialization.
    """
    local_rank = dgl.distributed.get_rank() % num_client_per_machine

    emb = dgl.distributed.DistEmbedding(
        num_nodes, 1, name=name, init_func=zeros_init
    )
    lr = 0.001
    optim = optimizer(params=[emb], lr=lr)

    # arbitrary slice width, matches dist_tensor_test_sanity
    stride = 3

    pos = (part_id // 2) * num_client_per_machine + local_rank
    idx = F.arange(pos * stride, (pos + 1) * stride)

    if part_id % 2 == 0:
        with F.record_grad():
            value = emb(idx)
            optim.zero_grad()
            loss = F.sum(value + 1, 0)
        loss.backward()
        optim.step()

    dgl.distributed.client_barrier()
    value = emb(idx)
    # NOTE(review): the allclose result is deliberately not asserted here
    # (optimizer updates are only approximately -lr); kept as-is.
    F.allclose(value, F.ones((len(idx), 1), dtype=F.int32, ctx=F.cpu()) * -lr)

    # Rows beyond the last updated slice must be untouched.
    # BUG FIX: '/' produced a float start for F.arange; use integer ceil
    # division — only the ceil(num_part / 2) even-numbered partitions write.
    not_update_idx = F.arange(
        ((num_part + 1) // 2) * num_client_per_machine * stride, num_nodes
    )
    value = emb(not_update_idx)
    assert np.all(F.asnumpy(value) == np.zeros((len(not_update_idx), 1)))


def dist_embedding_check_existing(num_nodes):
    """Creating a second DistEmbedding under an existing name but with a
    different embedding size must be rejected.

    BUG FIX: the original raised ``Exception("")`` inside the same ``try``
    whose ``except BaseException: pass`` swallowed it, so the test could
    never fail.  Track success with a flag and assert afterwards.
    """
    dist_emb_name = "UniqueEmb"
    emb = dgl.distributed.DistEmbedding(
        num_nodes, 1, name=dist_emb_name, init_func=zeros_init
    )
    raised = False
    try:
        emb1 = dgl.distributed.DistEmbedding(
            num_nodes, 2, name=dist_emb_name, init_func=zeros_init
        )
    except BaseException:
        raised = True
    assert raised, "duplicate DistEmbedding name with new shape should raise"

360

361
def test_dist_embedding(g):
    """Entry point for the DistEmbedding test group."""
    n = g.num_nodes(g.ntypes[0])
    for opt, emb_name in [
        (dgl.distributed.optim.SparseAdagrad, None),
        (dgl.distributed.optim.SparseAdagrad, "SomeEmbedding"),
        (dgl.distributed.optim.SparseAdam, "SomeEmbedding"),
    ]:
        dist_embedding_check_sanity(n, opt, name=emb_name)
    dist_embedding_check_existing(n)

373

374
375
376
377
378
379
##########################################
############# DistOptimizer ##############
##########################################


def dist_optimizer_check_store(g):
    """Round-trip SparseAdam optimizer state through save()/load().

    Rank 0 overwrites every optimizer state tensor with random values; all
    ranks save, then a second optimizer (constructed with different
    hyperparameters) loads the checkpoint.  Rank 0 verifies the state
    tensors match exactly and that the loaded hyperparameters equal the
    saved optimizer's.
    """
    num_nodes = g.num_nodes(g.ntypes[0])
    rank = g.rank()
    try:
        emb = dgl.distributed.DistEmbedding(
            num_nodes, 1, name="optimizer_test", init_func=zeros_init
        )
        emb2 = dgl.distributed.DistEmbedding(
            num_nodes, 5, name="optimizer_test2", init_func=zeros_init
        )
        emb_optimizer = dgl.distributed.optim.SparseAdam([emb, emb2], lr=0.1)
        if rank == 0:
            # Fill each state tensor with known random data so load() has
            # something non-trivial to restore.
            name_to_state = {}
            for _, emb_states in emb_optimizer._state.items():
                for state in emb_states:
                    name_to_state[state.name] = F.uniform(
                        state.shape, F.float32, F.cpu(), 0, 1
                    )
                    state[
                        F.arange(0, num_nodes, F.int64, F.cpu())
                    ] = name_to_state[state.name]
        emb_optimizer.save("emb.pt")
        # NOTE(review): 000.1 parses to 0.1 (same lr as above), so only
        # eps/betas actually differ before load() restores the saved values.
        new_emb_optimizer = dgl.distributed.optim.SparseAdam(
            [emb, emb2], lr=000.1, eps=2e-08, betas=(0.1, 0.222)
        )
        new_emb_optimizer.load("emb.pt")
        if rank == 0:
            for _, emb_states in new_emb_optimizer._state.items():
                for new_state in emb_states:
                    state = name_to_state[new_state.name]
                    new_state = new_state[
                        F.arange(0, num_nodes, F.int64, F.cpu())
                    ]
                    # Exact equality: rtol=0.0, atol=0.0.
                    assert F.allclose(state, new_state, 0.0, 0.0)
            assert new_emb_optimizer._lr == emb_optimizer._lr
            assert new_emb_optimizer._eps == emb_optimizer._eps
            assert new_emb_optimizer._beta1 == emb_optimizer._beta1
            assert new_emb_optimizer._beta2 == emb_optimizer._beta2
        g.barrier()
    finally:
        # save() writes one shard file per trainer rank; clean up ours.
        file = f"emb.pt_{rank}"
        if os.path.exists(file):
            os.remove(file)

423

424
425
426
def test_dist_optimizer(g):
    """Entry point for the DistOptimizer test group."""
    dist_optimizer_check_store(g)

427

428
429
430
431
##########################################
############# DistDataLoader #############
##########################################

432

433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
class NeighborSampler(object):
    """Multi-layer neighbor sampler used as the collate_fn of the low-level
    DistDataLoader; records each frontier's edge IDs on its block."""

    def __init__(self, g, fanouts, sample_neighbors):
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors

    def sample_blocks(self, seeds):
        import torch as th

        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # Sample ``fanout`` neighbors per seed, then compact the
            # frontier into a bipartite block for message passing.
            frontier = self.sample_neighbors(
                self.g, seeds, fanout, replace=True
            )
            block = dgl.to_block(frontier, seeds)
            # The block's source nodes seed the next (outer) layer.
            seeds = block.srcdata[dgl.NID]
            block.edata["original_eids"] = frontier.edata[dgl.EID]
            blocks.append(block)
        # Blocks are consumed outermost-first.
        blocks.reverse()
        return blocks

459

460
461
def distdataloader_test(g, batch_size, drop_last, shuffle):
    """Exercise the low-level DistDataLoader with a custom NeighborSampler.

    Iterates the loader twice (testing re-iteration), maps sampled block
    endpoints back to original node IDs, and compares them against the
    stored per-edge endpoint arrays.  Without shuffling, also checks that
    drop_last keeps/drops the trailing partial batch as requested.  Only
    invoked for homogeneous graphs with num_workers == 0 (see
    test_dist_dataloader).
    """
    # We sample only a subset to minimize the test runtime
    num_nodes_to_sample = int(g.num_nodes() * 0.05)
    # To make sure that drop_last is tested
    if num_nodes_to_sample % batch_size == 0:
        num_nodes_to_sample -= 1

    # Partitioned-ID -> original-ID arrays, saved at partitioning time.
    orig_nid_map = dict()
    dtype = g.edges[g.etypes[0]].data["edge_u"].dtype
    for ntype in g.ntypes:
        orig_nid = F.tensor(
            np.load(graph_path + f"/orig_nid_array_{ntype}.npy"), dtype
        )
        orig_nid_map[ntype] = orig_nid

    # Original endpoint IDs stored per edge.
    orig_uv_map = dict()
    for etype in g.etypes:
        orig_uv_map[etype] = (
            g.edges[etype].data["edge_u"],
            g.edges[etype].data["edge_v"],
        )

    if len(g.ntypes) == 1:
        train_nid = F.arange(0, num_nodes_to_sample)
    else:
        train_nid = {g.ntypes[0]: F.arange(0, num_nodes_to_sample)}

    sampler = NeighborSampler(g, [5, 10], dgl.distributed.sample_neighbors)

    dataloader = dgl.dataloading.DistDataLoader(
        dataset=train_nid.numpy(),
        batch_size=batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=shuffle,
        drop_last=drop_last,
    )

    # Two passes: the loader must be re-iterable.
    for _ in range(2):
        max_nid = []
        for idx, blocks in zip(
            range(0, num_nodes_to_sample, batch_size), dataloader
        ):
            block = blocks[-1]
            for src_type, etype, dst_type in block.canonical_etypes:
                orig_u, orig_v = orig_uv_map[etype]
                o_src, o_dst = block.edges(etype=etype)
                src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]
                dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]
                max_nid.append(np.max(F.asnumpy(dst_nodes_id)))

                # Map back to original IDs and compare with stored endpoints.
                src_nodes_id = orig_nid_map[src_type][src_nodes_id]
                dst_nodes_id = orig_nid_map[dst_type][dst_nodes_id]
                eids = block.edata["original_eids"]
                F.allclose(src_nodes_id, orig_u[eids])
                F.allclose(dst_nodes_id, orig_v[eids])
        # Deterministic order only without shuffle: check which seed IDs
        # were actually served, depending on drop_last.
        if not shuffle and len(max_nid) > 0:
            if drop_last:
                assert (
                    np.max(max_nid)
                    == num_nodes_to_sample
                    - 1
                    - num_nodes_to_sample % batch_size
                )
            else:
                assert np.max(max_nid) == num_nodes_to_sample - 1
    del dataloader

527
528
529
530

def distnodedataloader_test(
    g, batch_size, drop_last, shuffle, num_workers, orig_nid_map, orig_uv_map
):
    """Exercise DistNodeDataLoader with a MultiLayerNeighborSampler.

    Iterates twice and verifies the sampled block endpoints, mapped back to
    original node IDs, against the stored per-edge endpoint arrays.
    """
    # We sample only a subset to minimize the test runtime
    num_nodes_to_sample = int(g.num_nodes(g.ntypes[-1]) * 0.05)
    # To make sure that drop_last is tested
    if num_nodes_to_sample % batch_size == 0:
        num_nodes_to_sample -= 1

    if len(g.ntypes) == 1:
        train_nid = F.arange(0, num_nodes_to_sample)
    else:
        train_nid = {g.ntypes[-1]: F.arange(0, num_nodes_to_sample)}

    # Heterographs take a per-etype fanout for the first layer to also
    # cover the dict form of the fanout argument.
    if len(g.etypes) > 1:
        sampler = dgl.dataloading.MultiLayerNeighborSampler(
            [
                {etype: 5 for etype in g.etypes},
                10,
            ]
        )
    else:
        sampler = dgl.dataloading.MultiLayerNeighborSampler(
            [
                5,
                10,
            ]
        )

    dataloader = dgl.dataloading.DistNodeDataLoader(
        g,
        train_nid,
        sampler,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )

    # Two passes: the loader must be re-iterable.
    for _ in range(2):
        for _, (_, _, blocks) in zip(
            range(0, num_nodes_to_sample, batch_size), dataloader
        ):
            block = blocks[-1]
            for src_type, etype, dst_type in block.canonical_etypes:
                orig_u, orig_v = orig_uv_map[etype]
                o_src, o_dst = block.edges(etype=etype)
                src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]
                dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]
                # Map back to original IDs and compare with stored endpoints.
                src_nodes_id = orig_nid_map[src_type][src_nodes_id]
                dst_nodes_id = orig_nid_map[dst_type][dst_nodes_id]
                eids = block.edges[etype].data[dgl.EID]
                F.allclose(src_nodes_id, orig_u[eids])
                F.allclose(dst_nodes_id, orig_v[eids])
    del dataloader


585
586
587
588
589
590
591
592
593
594
def distedgedataloader_test(
    g,
    batch_size,
    drop_last,
    shuffle,
    num_workers,
    orig_nid_map,
    orig_uv_map,
    num_negs,
):
    """Exercise DistEdgeDataLoader, optionally with uniform negative
    sampling (``num_negs`` > 0).

    Verifies sampled block endpoints against the stored original endpoint
    arrays, and that the positive (and negative) pair graphs share the
    block's destination node set.
    """
    # We sample only a subset to minimize the test runtime
    num_edges_to_sample = int(g.num_edges(g.etypes[-1]) * 0.05)
    # To make sure that drop_last is tested
    if num_edges_to_sample % batch_size == 0:
        num_edges_to_sample -= 1

    if len(g.etypes) == 1:
        train_eid = F.arange(0, num_edges_to_sample)
    else:
        train_eid = {g.etypes[-1]: F.arange(0, num_edges_to_sample)}

    sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])

    dataloader = dgl.dataloading.DistEdgeDataLoader(
        g,
        train_eid,
        sampler,
        batch_size=batch_size,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(num_negs)
        if num_negs > 0
        else None,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )
    # Two passes: the loader must be re-iterable.
    for _ in range(2):
        for _, sampled_data in zip(
            range(0, num_edges_to_sample, batch_size), dataloader
        ):
            # With a negative sampler each batch is (input_nodes, pos_graph,
            # neg_graph, blocks); without it (input_nodes, pair_graph, blocks).
            blocks = sampled_data[3 if num_negs > 0 else 2]
            block = blocks[-1]
            for src_type, etype, dst_type in block.canonical_etypes:
                orig_u, orig_v = orig_uv_map[etype]
                o_src, o_dst = block.edges(etype=etype)
                src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]
                dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]
                # Map back to original IDs and compare with stored endpoints.
                src_nodes_id = orig_nid_map[src_type][src_nodes_id]
                dst_nodes_id = orig_nid_map[dst_type][dst_nodes_id]
                eids = block.edges[etype].data[dgl.EID]
                F.allclose(src_nodes_id, orig_u[eids])
                F.allclose(dst_nodes_id, orig_v[eids])
                if num_negs == 0:
                    pos_pair_graph = sampled_data[1]
                    assert np.all(
                        F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])
                        == F.asnumpy(
                            pos_pair_graph.nodes[dst_type].data[dgl.NID]
                        )
                    )
                else:
                    pos_graph, neg_graph = sampled_data[1:3]
                    assert np.all(
                        F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])
                        == F.asnumpy(pos_graph.nodes[dst_type].data[dgl.NID])
                    )
                    assert np.all(
                        F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])
                        == F.asnumpy(neg_graph.nodes[dst_type].data[dgl.NID])
                    )
                    # Uniform(k) draws k negatives per positive edge.
                    assert (
                        pos_graph.num_edges() * num_negs
                        == neg_graph.num_edges()
                    )
    del dataloader

660

661
def multi_distdataloader_test(g, dataloader_class):
    """Drive several dataloaders of the same class concurrently, consuming
    their iterators in random interleaved order until all are exhausted."""
    is_node_loader = "Node" in dataloader_class.__name__
    if is_node_loader:
        total_num_items = g.num_nodes(g.ntypes[-1])
    else:
        total_num_items = g.num_edges(g.etypes[-1])

    num_dataloaders = 4
    batch_size = 32
    sampler = dgl.dataloading.NeighborSampler([-1])

    # We sample only a subset to minimize the test runtime.
    num_items_to_sample = int(total_num_items * 0.05)
    # Guarantee a trailing partial batch so drop_last is exercised.
    if num_items_to_sample % batch_size == 0:
        num_items_to_sample -= 1

    ids = F.arange(0, num_items_to_sample)
    if len(g.ntypes) == 1:
        train_ids = ids
    else:
        key = g.ntypes[-1] if is_node_loader else g.etypes[-1]
        train_ids = {key: ids}

    dataloaders = []
    dl_iters = []
    for _ in range(num_dataloaders):
        loader = dataloader_class(
            g, train_ids, sampler, batch_size=batch_size
        )
        dataloaders.append(loader)
        dl_iters.append(iter(loader))

    # Interleave the iterators in random order until each is exhausted.
    while dl_iters:
        pick = np.random.choice(len(dl_iters), 1)[0]
        try:
            next(dl_iters[pick])
        except StopIteration:
            dl_iters.pop(pick)
            del dataloaders[pick]

705

706
707
def test_dist_dataloader(g):
    """Entry point for the dataloader test group: sweeps batch/shuffle/
    drop_last/num_workers combinations over every loader flavor."""
    # Partitioned-ID -> original-ID arrays, saved at partitioning time.
    id_dtype = g.edges[g.etypes[0]].data["edge_u"].dtype
    orig_nid_map = {
        ntype: F.tensor(
            np.load(graph_path + f"/orig_nid_array_{ntype}.npy"), id_dtype
        )
        for ntype in g.ntypes
    }
    orig_uv_map = {
        etype: (
            g.edges[etype].data["edge_u"],
            g.edges[etype].data["edge_v"],
        )
        for etype in g.etypes
    }

    for batch_size, drop_last, shuffle, num_workers in product(
        [64], [False, True], [False, True], [0, 4]
    ):
        # The low-level DistDataLoader variant only supports homogeneous
        # graphs and in-process collation.
        if len(g.ntypes) == 1 and num_workers == 0:
            distdataloader_test(g, batch_size, drop_last, shuffle)
        distnodedataloader_test(
            g,
            batch_size,
            drop_last,
            shuffle,
            num_workers,
            orig_nid_map,
            orig_uv_map,
        )
        # Without negative sampling.
        distedgedataloader_test(
            g,
            batch_size,
            drop_last,
            shuffle,
            num_workers,
            orig_nid_map,
            orig_uv_map,
            num_negs=0,
        )
        # With 15 negative samples per positive edge.
        distedgedataloader_test(
            g,
            batch_size,
            drop_last,
            shuffle,
            num_workers,
            orig_nid_map,
            orig_uv_map,
            num_negs=15,
        )

    multi_distdataloader_test(g, dgl.dataloading.DistNodeDataLoader)
    multi_distdataloader_test(g, dgl.dataloading.DistEdgeDataLoader)

767

768
if mode == "server":
769
770
771
772
773
774
775
776
777
778
    shared_mem = bool(int(os.environ.get("DIST_DGL_TEST_SHARED_MEM")))
    server_id = int(os.environ.get("DIST_DGL_TEST_SERVER_ID"))
    run_server(
        graph_name,
        server_id,
        server_count=num_servers_per_machine,
        num_clients=num_part * num_client_per_machine,
        shared_mem=shared_mem,
        keep_alive=False,
    )
779
elif mode == "client":
780
    os.environ["DGL_NUM_SERVER"] = str(num_servers_per_machine)
781
    dgl.distributed.initialize(ip_config)
782

783
    gpb, graph_name, _, _ = load_partition_book(
784
        graph_path + "/{}.json".format(graph_name), part_id
785
    )
786
    g = dgl.distributed.DistGraph(graph_name, gpb=gpb)
787

788
    target_func_map = {
789
790
        "DistGraph": test_dist_graph,
        "DistGraphServices": test_dist_graph_services,
791
792
        "DistTensor": test_dist_tensor,
        "DistEmbedding": test_dist_embedding,
793
        "DistOptimizer": test_dist_optimizer,
794
        "DistDataLoader": test_dist_dataloader,
795
    }
796

797
    targets = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "")
798
    targets = targets.replace(" ", "").split(",") if targets else []
799
    blacklist = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE_BLACKLIST", "")
800
    blacklist = blacklist.replace(" ", "").split(",") if blacklist else []
801
802
803
804
805

    for to_bl in blacklist:
        target_func_map.pop(to_bl, None)

    if not targets:
806
807
        for test_func in target_func_map.values():
            test_func(g)
808
    else:
809
810
811
812
813
        for target in targets:
            if target in target_func_map:
                target_func_map[target](g)
            else:
                print(f"Tests not implemented for target '{target}'")
814

815
816
else:
    exit(1)