test_shared_mem_store.py 10.6 KB
Newer Older
1
2
3
4
5
""" NOTE(zihao) The unittest on shared memory store is temporally disabled because we 
have not fixed the bug described in https://github.com/dmlc/dgl/issues/755 yet.
The bug causes CI failures occasionally but does not affect other parts of DGL.
As a result, we decide to disable this test until we fixed the bug.
"""
6
import dgl
7
import sys
8
import os
9
import random
10
11
import time
import numpy as np
12
from numpy.testing import assert_array_equal
13
from multiprocessing import Process, Manager, Condition, Value
14
15
from scipy import sparse as spsp
import backend as F
16
import unittest
17
import dgl.function as fn
Da Zheng's avatar
Da Zheng committed
18
import traceback
19
from numpy.testing import assert_almost_equal
20

21

22
23
# Graph size shared by every test in this file (~10% edge density).
num_nodes = 100
num_edges = int(num_nodes * num_nodes * 0.1)

# Randomize the server port so concurrent test runs do not collide.
rand_port = random.randint(5000, 8000)
print('run graph store with port %d' % rand_port, file=sys.stderr)
26

27
28
29
def check_array_shared_memory(g, worker_id, arrays):
    """Worker 0 writes to each array's first row; the others verify the
    writes are visible, proving the tensors live in shared memory."""
    if worker_id == 0:
        # Writer: mutate every tensor, then release the peers at the barrier.
        for idx, arr in enumerate(arrays):
            arr[0] = idx + 10
        g._sync_barrier(60)
    else:
        # Readers: wait for the writer, then check the values came through.
        g._sync_barrier(60)
        for idx, arr in enumerate(arrays):
            assert_almost_equal(F.asnumpy(arr[0]), idx + 10)

37
def create_graph_store(graph_name):
    """Attach to a shared-memory graph store, retrying while the server
    comes up.

    Returns the client graph on success, or None after 10 failed attempts.
    """
    for _ in range(10):
        try:
            return dgl.contrib.graph_store.create_graph_from_store(
                graph_name, "shared_mem", port=rand_port)
        except ConnectionError:
            # Server may not be listening yet; log and retry after a pause.
            traceback.print_exc()
            time.sleep(1)
    return None

Da Zheng's avatar
Da Zheng committed
48
def check_init_func(worker_id, graph_name, return_dict):
    """Worker process: attach to the shared-memory store and verify graph
    structure, server-written features, and shared-memory tensor semantics.

    Writes 0 into return_dict[worker_id] on success, -1 on any failure.
    """
    # Rebuild the same random graph the server was created from, for comparison.
    np.random.seed(0)
    csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
    tmp_g = dgl.DGLGraph(csr, readonly=True, multigraph=False)

    # Verify the graph structure loaded from the shared memory.
    g = None  # sentinel so the except-handler knows whether we attached
    try:
        g = create_graph_store(graph_name)
        if g is None:
            return_dict[worker_id] = -1
            return

        # Edge lists of the store must match the locally built graph.
        src, dst = g.all_edges(order='srcdst')
        src1, dst1 = tmp_g.all_edges(order='srcdst')
        assert_array_equal(F.asnumpy(dst), F.asnumpy(dst1))
        assert_array_equal(F.asnumpy(src), F.asnumpy(src1))

        # Features written by the server (arange rows) must be visible here.
        feat = F.asnumpy(g.nodes[0].data['feat'])
        assert_array_equal(np.squeeze(feat), np.arange(10, dtype=feat.dtype))
        feat = F.asnumpy(g.edges[0].data['feat'])
        assert_array_equal(np.squeeze(feat), np.arange(10, dtype=feat.dtype))

        # Tensors created via init_ndata/init_edata must be shared across workers.
        g.init_ndata('test4', (g.number_of_nodes(), 10), 'float32')
        g.init_edata('test4', (g.number_of_edges(), 10), 'float32')
        g._sync_barrier(60)
        check_array_shared_memory(g, worker_id, [g.nodes[:].data['test4'], g.edges[:].data['test4']])
        g._sync_barrier(60)

        # Writes via set_n_repr must be reflected in a previously taken view.
        data = g.nodes[:].data['test4']
        g.set_n_repr({'test4': F.ones((1, 10)) * 10}, u=[0])
        assert_almost_equal(F.asnumpy(data[0]), np.squeeze(F.asnumpy(g.nodes[0].data['test4'])))

        # Same for set_e_repr on edge data.
        data = g.edges[:].data['test4']
        g.set_e_repr({'test4': F.ones((1, 10)) * 20}, edges=[0])
        assert_almost_equal(F.asnumpy(data[0]), np.squeeze(F.asnumpy(g.edges[0].data['test4'])))

        g.destroy()
        return_dict[worker_id] = 0
    except Exception as e:
        return_dict[worker_id] = -1
        # BUGFIX: `g` was unbound (NameError) here when create_graph_store
        # itself raised; only tear down a store we actually attached to.
        if g is not None:
            g.destroy()
        print(e, file=sys.stderr)
        traceback.print_exc()

90
def server_func(num_workers, graph_name, server_init):
    """Server process: build the shared-memory graph store, populate node and
    edge features, flag readiness via `server_init`, then serve requests."""
    # Same seed as the workers so both sides agree on the graph structure.
    np.random.seed(0)
    adj = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)

    store = dgl.contrib.graph_store.create_graph_store_server(
        adj, graph_name, "shared_mem", num_workers, False, port=rand_port)
    assert store._graph.number_of_nodes() == num_nodes
    assert store._graph.number_of_edges() == num_edges

    # Deterministic features: row i of ndata holds [10*i, ..., 10*i + 9].
    node_feat = np.arange(0, num_nodes * 10).astype('float32').reshape((num_nodes, 10))
    edge_feat = np.arange(0, num_edges * 10).astype('float32').reshape((num_edges, 10))
    store.ndata['feat'] = F.tensor(node_feat)
    store.edata['feat'] = F.tensor(edge_feat)
    server_init.value = 1  # tell the waiting workers the server is ready
    store.run()

105
@unittest.skipIf(True, reason="skip this test")
def test_init():
    """Two worker processes attach to one server and validate graph
    initialization (structure, features, shared tensors)."""
    manager = Manager()
    return_dict = manager.dict()

    # make server init before worker
    # Consistency fix: seed the 'i'-typed flag with the int 0 (as the other
    # tests do) rather than the bool False.
    server_init = Value('i', 0)
    serv_p = Process(target=server_func, args=(2, 'test_graph1', server_init))
    serv_p.start()
    while server_init.value == 0:
        time.sleep(1)

    work_p1 = Process(target=check_init_func, args=(0, 'test_graph1', return_dict))
    work_p2 = Process(target=check_init_func, args=(1, 'test_graph1', return_dict))

    work_p1.start()
    work_p2.start()
    serv_p.join()
    work_p1.join()
    work_p2.join()

    # Every worker must have reported success.
    for worker_id in return_dict.keys():
        assert return_dict[worker_id] == 0, "worker %d fails" % worker_id
125

126
def check_compute_func(worker_id, graph_name, return_dict):
    """Worker process: exercise message-passing ops (update_all, apply_nodes,
    apply_edges, pull, send_and_recv) against the shared graph store.

    Writes 0 into return_dict[worker_id] on success, -1 on any failure.
    """
    g = None  # sentinel so the except-handler knows whether we attached
    try:
        g = create_graph_store(graph_name)
        if g is None:
            return_dict[worker_id] = -1
            return

        g._sync_barrier(60)
        in_feats = g.nodes[0].data['feat'].shape[1]

        # Test update all: summed in-neighbor features equal A^T @ feat.
        g.update_all(fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='preprocess'))
        adj = g.adjacency_matrix(transpose=True)
        tmp = F.spmm(adj, g.nodes[:].data['feat'])
        assert_almost_equal(F.asnumpy(g.nodes[:].data['preprocess']), F.asnumpy(tmp))
        g._sync_barrier(60)
        check_array_shared_memory(g, worker_id, [g.nodes[:].data['preprocess']])
        g._sync_barrier(60)

        # Test apply nodes.
        data = g.nodes[:].data['feat']
        g.apply_nodes(func=lambda nodes: {'feat': F.ones((1, in_feats)) * 10}, v=0)
        assert_almost_equal(F.asnumpy(data[0]), np.squeeze(F.asnumpy(g.nodes[0].data['feat'])))

        # Test apply edges.
        data = g.edges[:].data['feat']
        g.apply_edges(func=lambda edges: {'feat': F.ones((1, in_feats)) * 10}, edges=0)
        assert_almost_equal(F.asnumpy(data[0]), np.squeeze(F.asnumpy(g.edges[0].data['feat'])))

        g.init_ndata('tmp', (g.number_of_nodes(), 10), 'float32')
        data = g.nodes[:].data['tmp']
        # Test pull
        g.pull(1, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
        assert_almost_equal(F.asnumpy(data[1]), np.squeeze(F.asnumpy(g.nodes[1].data['preprocess'])))

        # Test send_and_recv
        in_edges = g.in_edges(v=2)
        g.send_and_recv(in_edges, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
        assert_almost_equal(F.asnumpy(data[2]), np.squeeze(F.asnumpy(g.nodes[2].data['preprocess'])))

        g.destroy()
        return_dict[worker_id] = 0
    except Exception as e:
        return_dict[worker_id] = -1
        # BUGFIX: `g` was unbound (NameError) here when create_graph_store
        # itself raised; only tear down a store we actually attached to.
        if g is not None:
            g.destroy()
        print(e, file=sys.stderr)
        traceback.print_exc()

173
@unittest.skipIf(True, reason="skip this test")
def test_compute():
    """Two worker processes run message-passing computations against one
    shared-memory graph store server."""
    manager = Manager()
    results = manager.dict()

    # Launch the server first and block until it reports readiness.
    ready = Value('i', 0)
    server = Process(target=server_func, args=(2, 'test_graph3', ready))
    server.start()
    while not ready.value:
        time.sleep(1)

    workers = [Process(target=check_compute_func, args=(wid, 'test_graph3', results))
               for wid in (0, 1)]
    for w in workers:
        w.start()
    server.join()
    for w in workers:
        w.join()

    # Every worker must have reported success.
    for wid in results.keys():
        assert results[wid] == 0, "worker %d fails" % wid
193

194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def check_sync_barrier(worker_id, graph_name, return_dict):
    """Worker process: verify that _sync_barrier times out when a peer leaves.

    Worker 1 detaches immediately, so worker 0's barrier can never complete
    and must raise TimeoutError in roughly the requested 10 seconds.
    Writes 0 into return_dict[worker_id] on success, -1 on any failure.
    """
    g = None  # sentinel so the except-handler knows whether we attached
    try:
        g = create_graph_store(graph_name)
        if g is None:
            return_dict[worker_id] = -1
            return

        if worker_id == 1:
            g.destroy()
            return_dict[worker_id] = 0
            return

        start = time.time()
        try:
            g._sync_barrier(10)
        except TimeoutError as e:
            # this is very loose.
            print("timeout: " + str(abs(time.time() - start)), file=sys.stderr)
            assert 5 < abs(time.time() - start) < 15
        g.destroy()
        return_dict[worker_id] = 0
    except Exception as e:
        return_dict[worker_id] = -1
        # BUGFIX: `g` is None when create_graph_store itself raised; avoid a
        # masking NameError/AttributeError in the error path.
        if g is not None:
            g.destroy()
        print(e, file=sys.stderr)
        traceback.print_exc()

221
@unittest.skipIf(True, reason="skip this test")
def test_sync_barrier():
    """A worker stuck at the barrier must time out when its peer quits."""
    manager = Manager()
    results = manager.dict()

    # Launch the server first and block until it reports readiness.
    ready = Value('i', 0)
    server = Process(target=server_func, args=(2, 'test_graph4', ready))
    server.start()
    while not ready.value:
        time.sleep(1)

    workers = [Process(target=check_sync_barrier, args=(wid, 'test_graph4', results))
               for wid in (0, 1)]
    for w in workers:
        w.start()
    server.join()
    for w in workers:
        w.join()
    # Every worker must have reported success.
    for wid in results.keys():
        assert results[wid] == 0, "worker %d fails" % wid

242
243
244
def create_mem(gidx, cond_v, shared_v):
    """Publish `gidx` into shared memory, then keep it alive until the
    checker process signals it is done."""
    # serialize create_mem before check_mem
    with cond_v:
        gidx1 = gidx.copyto_shared_mem("test_graph5")
        shared_v.value = 1
        cond_v.notify()

    # sync for exit: wait until check_mem resets the flag
    with cond_v:
        while shared_v.value == 1:
            cond_v.wait()

def check_mem(gidx, cond_v, shared_v):
    """Attach to the shared-memory graph index and verify both CSR layouts
    match the original `gidx`, then signal the creator to exit."""
    # check_mem should run after create_mem has published the graph
    with cond_v:
        while shared_v.value == 0:
            cond_v.wait()

    gidx1 = dgl.graph_index.from_shared_mem_graph_index("test_graph5")

    in_csr = gidx.adjacency_matrix_scipy(True, "csr")
    out_csr = gidx.adjacency_matrix_scipy(False, "csr")

    # The in-CSR read back from shared memory must match the original.
    in_csr1 = gidx1.adjacency_matrix_scipy(True, "csr")
    assert_array_equal(in_csr.indptr, in_csr1.indptr)
    assert_array_equal(in_csr.indices, in_csr1.indices)
    # Likewise for the out-CSR.
    out_csr1 = gidx1.adjacency_matrix_scipy(False, "csr")
    assert_array_equal(out_csr.indptr, out_csr1.indptr)
    assert_array_equal(out_csr.indices, out_csr1.indices)

    # Round-trip the attached index back into shared memory.
    gidx1 = gidx1.copyto_shared_mem("test_graph5")

    # sync for exit
    with cond_v:
        shared_v.value = 0
        cond_v.notify()

282
@unittest.skipIf(True, reason="skip this test")
def test_copy_shared_mem():
    """A graph index must survive a round trip through shared memory."""
    adj = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
    gidx = dgl.graph_index.create_graph_index(adj, True)

    # Condition + flag serialize the creator and checker processes.
    cond_v = Condition()
    shared_v = Value('i', 0)
    procs = [Process(target=create_mem, args=(gidx, cond_v, shared_v)),
             Process(target=check_mem, args=(gidx, cond_v, shared_v))]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

296
297
298
299
300
301
# Skip test this file
#if __name__ == '__main__':
#    test_copy_shared_mem()
#    test_init()
#    test_sync_barrier()
#    test_compute()