# test_distributed_sampling.py
# Tests for distributed neighbor sampling (dgl.distributed.sampling.sample_neighbors)
# against DistGraphServer processes over RPC. Originally committed by Jinjing Zhou.
import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
from dgl.distributed.sampling import sample_neighbors
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
import numpy as np
import backend as F
import time
from utils import get_local_usable_addr
from pathlib import Path

from dgl.distributed import DistGraphServer, DistGraph


def start_server(rank, tmpdir, disable_shared_mem):
    """Run one ``DistGraphServer`` for the ``test_sampling`` partitioned graph.

    Used as a ``multiprocessing`` process target by the sampling checks.
    ``g.start()`` presumably blocks until a client requests shutdown
    (TODO confirm against the DistGraphServer implementation).

    Args:
        rank: Server ID passed to ``DistGraphServer``.
        tmpdir: ``pathlib.Path`` of the directory containing the
            ``test_sampling.json`` partition config.
        disable_shared_mem: Forwarded unchanged to ``DistGraphServer``.
    """
    g = DistGraphServer(rank, "rpc_sampling_ip_config.txt", 1, "test_sampling",
                        tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem)
    g.start()


def start_client(rank, tmpdir, disable_shared_mem):
    """Connect to the sampling servers, sample neighbors, then shut the servers down.

    Args:
        rank: Partition ID whose partition book is loaded from disk when
            ``disable_shared_mem`` is true.
        tmpdir: ``pathlib.Path`` of the directory containing
            ``test_sampling.json``.
        disable_shared_mem: If true, the partition book is read with
            ``load_partition`` and passed to ``DistGraph`` explicitly;
            otherwise ``gpb=None`` is passed.

    Returns:
        The graph returned by ``sample_neighbors`` for the fixed seed nodes
        ``[0, 10, 99, 66, 1024, 2008]`` with fanout 3.
    """
    import dgl
    gpb = None
    if disable_shared_mem:
        # NOTE(review): presumably DistGraph cannot obtain the partition book from
        # shared memory in this mode, so it is loaded from the partition config.
        _, _, _, gpb = load_partition(tmpdir / 'test_sampling.json', rank)
    dist_graph = DistGraph("rpc_sampling_ip_config.txt", "test_sampling", gpb=gpb)
    try:
        sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
    finally:
        # Ask the servers to exit and release the client even when sampling
        # raises; otherwise the server processes are never told to stop.
        dgl.distributed.shutdown_servers()
        dgl.distributed.finalize_client()
    return sampled_graph


def check_rpc_sampling(tmpdir, num_server):
    """Sample neighbors over RPC from a partitioned (non-reshuffled) Cora graph.

    Writes ``rpc_sampling_ip_config.txt`` in the current working directory with
    one local address per server, partitions Cora into ``num_server`` parts
    under ``tmpdir`` with ``reshuffle=False``, spawns one server process per
    partition, samples from client rank 0, and checks that every sampled edge
    exists in the original graph with the same edge ID.

    Args:
        tmpdir: ``pathlib.Path`` directory that receives the partition files.
        num_server: Number of server processes (and graph partitions).
    """
    with open("rpc_sampling_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{} 1\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=False)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Give the servers time to start listening before the client connects.
    time.sleep(3)
    try:
        sampled_graph = start_client(0, tmpdir, num_server > 1)
    except BaseException:
        # The servers may never have received a shutdown request; stop them
        # explicitly so the join() below cannot block forever, then re-raise.
        for p in pserver_list:
            p.terminate()
        raise
    else:
        print("Done sampling")
    finally:
        for p in pserver_list:
            p.join()

    src, dst = sampled_graph.edges()
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    # Without reshuffling, sampled edge IDs must equal the original edge IDs.
    eids = g.edge_ids(src, dst)
    assert np.array_equal(
        F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))


@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_sampling():
    """Run the non-reshuffled RPC sampling check with two servers."""
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Use the fresh temporary directory. A fixed path such as /tmp/sampling
        # is shared between runs and users and is never cleaned up.
        check_rpc_sampling(Path(tmpdirname), 2)


def check_rpc_sampling_shuffle(tmpdir, num_server):
    """Sample neighbors over RPC from a partitioned, reshuffled Cora graph.

    Same flow as the non-reshuffled check, but partitions with
    ``reshuffle=True``. The sampled node and edge IDs are therefore in the
    shuffled ID space. They are mapped back to original Cora IDs through each
    partition's ``'orig_id'`` node/edge data before being compared against
    the original graph.

    Args:
        tmpdir: ``pathlib.Path`` directory that receives the partition files.
        num_server: Number of server processes (and graph partitions).
    """
    with open("rpc_sampling_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{} 1\n'.format(get_local_usable_addr()))

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    # Give the servers time to start listening before the client connects.
    time.sleep(3)
    try:
        sampled_graph = start_client(0, tmpdir, num_server > 1)
    except BaseException:
        # The servers may never have received a shutdown request; stop them
        # explicitly so the join() below cannot block forever, then re-raise.
        for p in pserver_list:
            p.terminate()
        raise
    else:
        print("Done sampling")
    finally:
        for p in pserver_list:
            p.join()

    # Build shuffled-ID -> original-ID lookup tables from every partition.
    orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64)
    orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64)
    for i in range(num_server):
        part, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

    src, dst = sampled_graph.edges()
    src = orig_nid[src]
    dst = orig_nid[dst]
    assert sampled_graph.number_of_nodes() == g.number_of_nodes()
    assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
    eids = g.edge_ids(src, dst)
    eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
    assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))

# TODO: presumably waiting on non-shared-memory graph store support
# (original note: "Wait non shared memory graph store"); confirm whether still relevant.
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_sampling_shuffle():
    """Run the reshuffled RPC sampling check with two servers, then with one."""
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Use the fresh temporary directory. A fixed path such as /tmp/sampling
        # is shared between runs and users and is never cleaned up.
        tmpdir = Path(tmpdirname)
        check_rpc_sampling_shuffle(tmpdir, 2)
        check_rpc_sampling_shuffle(tmpdir, 1)

if __name__ == "__main__":
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Run every check inside the temporary directory instead of a fixed,
        # never-cleaned path such as /tmp/sampling.
        tmpdir = Path(tmpdirname)
        check_rpc_sampling_shuffle(tmpdir, 1)
        check_rpc_sampling_shuffle(tmpdir, 2)
        check_rpc_sampling(tmpdir, 2)
        check_rpc_sampling(tmpdir, 1)