import dgl
import sys
import random
import time
import numpy as np
from multiprocessing import Process, Manager
from scipy import sparse as spsp
import mxnet as mx
import backend as F
import unittest
import dgl.function as fn
import traceback

num_nodes = 100
num_edges = int(num_nodes * num_nodes * 0.1)
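# scipy's sparse random generator draws int(density * n * n) nonzero
# entries, so num_edges matches the edge count of the generated graph.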
rand_port = random.randint(5000, 8000)
print('run graph store with port ' + str(rand_port), file=sys.stderr)

def check_array_shared_memory(g, worker_id, arrays):
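    """Worker 0 writes a marker into each array and the other workers verify
    the writes after the barrier, proving the arrays live in shared memory
    rather than in per-process copies."""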
    if worker_id == 0:
        for i, arr in enumerate(arrays):
            arr[0] = i
        g._sync_barrier()
    else:
        g._sync_barrier()
        for i, arr in enumerate(arrays):
            assert np.all(arr[0].asnumpy() == i)

def _check_init_func(worker_id, graph_name):
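    """Attach to the shared-memory graph store and verify the graph
    structure, the preloaded 'feat' data, and newly initialized ndata/edata."""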
    # Give the server process a head start before connecting.
    time.sleep(3)
    print("worker starts")
    # Rebuild the same random graph as the server (identical seed) so the
    # shared structure can be validated locally.
    np.random.seed(0)
    csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)

    # The server may not be up yet; retry the connection a few times.
    for _ in range(10):
        try:
            g = dgl.contrib.graph_store.create_graph_from_store(graph_name, "shared_mem",
                                                                port=rand_port)
            break
        except Exception:
            print("failed to connect to the graph store server; retrying.")
            time.sleep(1)
    else:
        raise RuntimeError("could not connect to the graph store server.")
    # Verify the graph structure loaded from the shared memory.
    src, dst = g.all_edges()
    coo = csr.tocoo()
    assert F.array_equal(dst, F.tensor(coo.row))
    assert F.array_equal(src, F.tensor(coo.col))
    assert F.array_equal(g.nodes[0].data['feat'], F.tensor(np.arange(10), dtype=np.float32))
    assert F.array_equal(g.edges[0].data['feat'], F.tensor(np.arange(10), dtype=np.float32))
    g.init_ndata('test4', (g.number_of_nodes(), 10), 'float32')
    g.init_edata('test4', (g.number_of_edges(), 10), 'float32')
    g._sync_barrier()
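    # After the barrier, verify that 'test4' is backed by shared memory:
    # writes from worker 0 must be visible to worker 1.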
    check_array_shared_memory(g, worker_id, [g.nodes[:].data['test4'], g.edges[:].data['test4']])

    # Mutating node/edge data through set_n_repr/set_e_repr must show up in
    # previously fetched views, since both alias the same shared memory.
    data = g.nodes[:].data['test4']
    g.set_n_repr({'test4': mx.nd.ones((1, 10)) * 10}, u=[0])
    assert np.all(data[0].asnumpy() == g.nodes[0].data['test4'].asnumpy())

    data = g.edges[:].data['test4']
    g.set_e_repr({'test4': mx.nd.ones((1, 10)) * 20}, edges=[0])
    assert np.all(data[0].asnumpy() == g.edges[0].data['test4'].asnumpy())

    g.destroy()

def check_init_func(worker_id, graph_name, return_dict):
    try:
        _check_init_func(worker_id, graph_name)
        return_dict[worker_id] = 0
    except Exception as e:
        return_dict[worker_id] = -1
        print(e)
        traceback.print_exc()

def server_func(num_workers, graph_name):
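    """Build the test graph, register node/edge features, and run the
    shared-memory graph store server for num_workers workers."""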
    print("server starts")
    np.random.seed(0)
    csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)

    g = dgl.contrib.graph_store.create_graph_store_server(csr, graph_name, "shared_mem", num_workers,
                                                          False, edge_dir="in", port=rand_port)
    assert num_nodes == g._graph.number_of_nodes()
    assert num_edges == g._graph.number_of_edges()
    g.ndata['feat'] = mx.nd.arange(num_nodes * 10).reshape((num_nodes, 10))
    g.edata['feat'] = mx.nd.arange(num_edges * 10).reshape((num_edges, 10))
    g.run()

def test_init():
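    """Spawn one graph store server and two worker processes, then check
    that every worker completed the initialization checks."""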
    manager = Manager()
    return_dict = manager.dict()
    serv_p = Process(target=server_func, args=(2, 'test_graph1'))
    work_p1 = Process(target=check_init_func, args=(0, 'test_graph1', return_dict))
    work_p2 = Process(target=check_init_func, args=(1, 'test_graph1', return_dict))
    serv_p.start()
    work_p1.start()
    work_p2.start()
    serv_p.join()
    work_p1.join()
    work_p2.join()
    for worker_id in return_dict.keys():
        assert return_dict[worker_id] == 0, "worker %d failed" % worker_id


def _check_compute_func(worker_id, graph_name):
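    """Attach to the shared graph and exercise the message-passing APIs
    (update_all, apply_nodes, apply_edges, pull, send_and_recv), checking
    each result against the shared node/edge data."""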
    # Give the server process a head start before connecting.
    time.sleep(3)
    print("worker starts")
    # The server may not be up yet; retry the connection a few times.
    for _ in range(10):
        try:
            g = dgl.contrib.graph_store.create_graph_from_store(graph_name, "shared_mem",
                                                                port=rand_port)
            break
        except Exception:
            print("failed to connect to the graph store server; retrying.")
            time.sleep(1)
    else:
        raise RuntimeError("could not connect to the graph store server.")
    g._sync_barrier()
    in_feats = g.nodes[0].data['feat'].shape[1]

    # Test update all.
    g.update_all(fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='preprocess'))
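    # copy_src + sum over incoming edges equals an adjacency-feature product,
    # so the result can be checked against a direct sparse dot.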
    adj = g.adjacency_matrix()
    tmp = mx.nd.dot(adj, g.nodes[:].data['feat'])
    assert np.all((g.nodes[:].data['preprocess'] == tmp).asnumpy())
    g._sync_barrier()
    check_array_shared_memory(g, worker_id, [g.nodes[:].data['preprocess']])

    # Test apply nodes.
    data = g.nodes[:].data['feat']
    g.apply_nodes(func=lambda nodes: {'feat': mx.nd.ones((1, in_feats)) * 10}, v=0)
    assert np.all(data[0].asnumpy() == g.nodes[0].data['feat'].asnumpy())

    # Test apply edges.
    data = g.edges[:].data['feat']
    g.apply_edges(func=lambda edges: {'feat': mx.nd.ones((1, in_feats)) * 10}, edges=0)
    assert np.all(data[0].asnumpy() == g.edges[0].data['feat'].asnumpy())

    g.init_ndata('tmp', (g.number_of_nodes(), 10), 'float32')
    data = g.nodes[:].data['tmp']
    # Test pull
    g.pull(1, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
    assert np.all(data[1].asnumpy() == g.nodes[1].data['preprocess'].asnumpy())

    # Test send_and_recv
    in_edges = g.in_edges(v=2)
    g.send_and_recv(in_edges, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
    assert np.all(data[2].asnumpy() == g.nodes[2].data['preprocess'].asnumpy())

    g.destroy()

def check_compute_func(worker_id, graph_name, return_dict):
    try:
        _check_compute_func(worker_id, graph_name)
        return_dict[worker_id] = 0
    except Exception as e:
        return_dict[worker_id] = -1
        print(e)
        traceback.print_exc()

def test_compute():
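    """Spawn one graph store server and two worker processes that run
    message passing on the shared graph, then check that every worker
    succeeded."""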
    manager = Manager()
    return_dict = manager.dict()
    serv_p = Process(target=server_func, args=(2, 'test_graph3'))
    work_p1 = Process(target=check_compute_func, args=(0, 'test_graph3', return_dict))
    work_p2 = Process(target=check_compute_func, args=(1, 'test_graph3', return_dict))
    serv_p.start()
    work_p1.start()
    work_p2.start()
    serv_p.join()
    work_p1.join()
    work_p2.join()
    for worker_id in return_dict.keys():
        assert return_dict[worker_id] == 0, "worker %d failed" % worker_id

if __name__ == '__main__':
    test_init()
    test_compute()