# test_shared_mem_store.py: tests for the DGL shared-memory graph store.
import dgl
2
3
import sys
import random
4
5
6
7
8
9
import time
import numpy as np
from multiprocessing import Process
from scipy import sparse as spsp
import mxnet as mx
import backend as F
10
import unittest
11
import dgl.function as fn
12
13
14

num_nodes = 100
# Expected number of edges in the random graph built with density=0.1 by the
# server and workers. scipy.sparse.random places int(round(density * m * n))
# nonzeros, so round the same way; plain int() truncation would disagree for
# some sizes (e.g. 7 nodes: 4.9 -> 4 instead of 5).
num_edges = int(round(num_nodes * num_nodes * 0.1))

# Random port for the graph-store server (presumably to reduce collisions
# between test runs); printed so a failing run can be correlated with it.
rand_port = random.randint(5000, 8000)
print('run graph store with port ' + str(rand_port), file=sys.stderr)
17

18
19
20
21
22
23
24
25
26
27
28
def check_array_shared_memory(g, worker_id, arrays):
    """Check that writes to shared arrays made by one worker are seen by others.

    Worker 0 writes ``i`` into row 0 of ``arrays[i]``; every other worker
    asserts, after a barrier, that it observes exactly those values.

    Every participating worker must call this function, because it issues
    two ``g._sync_barrier()`` calls that have to line up across workers.

    Args:
        g: graph-store client exposing ``_sync_barrier()``.
        worker_id: id of the calling worker; worker 0 is the writer.
        arrays: indexable arrays where ``arr[0] = value`` assigns row 0 and
            ``arr[0].asnumpy()`` reads it back (assumed to be views of the
            shared store — confirm against the caller).

    Raises:
        AssertionError: on a non-writer worker, if a written value is not seen.
    """
    if worker_id == 0:
        for i, arr in enumerate(arrays):
            arr[0] = i
    # Readers must not look until the writer has finished.
    g._sync_barrier()
    try:
        if worker_id != 0:
            for i, arr in enumerate(arrays):
                assert np.all(arr[0].asnumpy() == i)
    finally:
        # Without this barrier worker 0 could return and overwrite row 0 (as
        # check_init_func does right afterwards) while readers are still
        # checking it. It is reached even when an assertion fails so the
        # other workers are not left blocked.
        g._sync_barrier()

def check_init_func(worker_id, graph_name):
    """Worker body for ``test_init``: attach to the store and verify its contents.

    Regenerates the same random CSR matrix as ``server_func`` (same seed and
    density), checks the graph structure and the ``feat`` fields read from
    the store, creates a ``test4`` node/edge field, and checks that writes to
    it are visible across workers and through previously obtained views.

    Args:
        worker_id: index of this worker; worker 0 is the writer inside
            ``check_array_shared_memory``.
        graph_name: name under which the server registered the graph.
    """
    # Crude wait for the server process to come up before connecting.
    time.sleep(3)
    print("worker starts")
    np.random.seed(0)
    csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)

    g = dgl.contrib.graph_store.create_graph_from_store(graph_name, "shared_mem", port=rand_port)
    try:
        # Verify the graph structure loaded from the shared memory: CSR rows
        # are expected to be destinations and CSR columns sources.
        src, dst = g.all_edges()
        coo = csr.tocoo()
        assert F.array_equal(dst, F.tensor(coo.row))
        assert F.array_equal(src, F.tensor(coo.col))

        # The server fills row 0 of 'feat' with 0..9.
        assert F.array_equal(g.nodes[0].data['feat'], F.tensor(np.arange(10), dtype=np.float32))
        assert F.array_equal(g.edges[0].data['feat'], F.tensor(np.arange(10), dtype=np.float32))

        g.init_ndata('test4', (g.number_of_nodes(), 10), 'float32')
        g.init_edata('test4', (g.number_of_edges(), 10), 'float32')
        # Every worker must have created the fields before they are used.
        g._sync_barrier()
        check_array_shared_memory(g, worker_id, [g.nodes[:].data['test4'], g.edges[:].data['test4']])

        # A view taken before a write is expected to observe that write.
        data = g.nodes[:].data['test4']
        g.set_n_repr({'test4': mx.nd.ones((1, 10)) * 10}, u=[0])
        assert np.all(data[0].asnumpy() == g.nodes[0].data['test4'].asnumpy())

        data = g.edges[:].data['test4']
        g.set_e_repr({'test4': mx.nd.ones((1, 10)) * 20}, edges=[0])
        assert np.all(data[0].asnumpy() == g.edges[0].data['test4'].asnumpy())
    finally:
        # Detach from the store even when an assertion above fails.
        g.destroy()

57
def server_func(num_workers, graph_name):
    """Server body shared by the tests: publish a random graph and serve it.

    Builds a ``num_nodes`` x ``num_nodes`` random 0/1 CSR matrix with a fixed
    seed (the workers regenerate the same matrix to verify the structure),
    registers it under ``graph_name`` in a ``"shared_mem"`` graph store, fills
    a 10-column ``feat`` field on nodes and edges with consecutive values,
    and then calls ``g.run()``.

    Args:
        num_workers: number of worker processes that will attach.
        graph_name: name the workers use to look up the graph.
    """
    print("server starts")
    np.random.seed(0)
    csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)

    # NOTE(review): the positional False is presumably the multigraph flag — confirm.
    g = dgl.contrib.graph_store.create_graph_store_server(csr, graph_name, "shared_mem", num_workers,
                                                          False, edge_dir="in", port=rand_port)
    assert num_nodes == g._graph.number_of_nodes()
    assert num_edges == g._graph.number_of_edges()
    # Row i holds 10*i .. 10*i + 9, so workers can check row 0 == arange(10).
    g.ndata['feat'] = mx.nd.arange(num_nodes * 10).reshape((num_nodes, 10))
    g.edata['feat'] = mx.nd.arange(num_edges * 10).reshape((num_edges, 10))
    # The tests join this process, so run() is expected to return once the
    # workers are done — verify against the graph_store implementation.
    g.run()

70
def test_init():
    """Run one graph-store server and two ``check_init_func`` workers.

    Assertion failures inside a child process do not propagate to the
    parent; they only surface as a non-zero exit code, which is checked here
    so the test actually fails when a worker or the server fails.
    """
    serv_p = Process(target=server_func, args=(2, 'test_graph1'))
    work_p1 = Process(target=check_init_func, args=(0, 'test_graph1'))
    work_p2 = Process(target=check_init_func, args=(1, 'test_graph1'))
    procs = [serv_p, work_p1, work_p2]
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()
    for proc in procs:
        assert proc.exitcode == 0, (proc.name, proc.exitcode)


82
def check_compute_func(worker_id, graph_name):
    """Worker body for ``test_compute``: run message passing on the shared graph.

    Checks ``update_all`` against an explicit adjacency-matrix product, that
    its result is visible across workers, that ``apply_nodes`` /
    ``apply_edges`` writes show up through earlier views, and that ``pull``
    on node 1 reproduces the ``update_all`` result for that node.

    Args:
        worker_id: index of this worker; worker 0 is the writer inside
            ``check_array_shared_memory``.
        graph_name: name under which the server registered the graph.
    """
    # Crude wait for the server process to come up before connecting.
    time.sleep(3)
    print("worker starts")
    g = dgl.contrib.graph_store.create_graph_from_store(graph_name, "shared_mem", port=rand_port)
    try:
        g._sync_barrier()
        in_feats = g.nodes[0].data['feat'].shape[1]

        # Test update all.
        g.update_all(fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='preprocess'))
        adj = g.adjacency_matrix()
        tmp = mx.nd.dot(adj, g.nodes[:].data['feat'])
        assert np.all((g.nodes[:].data['preprocess'] == tmp).asnumpy())
        g._sync_barrier()
        check_array_shared_memory(g, worker_id, [g.nodes[:].data['preprocess']])

        # Test apply nodes.
        data = g.nodes[:].data['feat']
        g.apply_nodes(func=lambda nodes: {'feat': mx.nd.ones((1, in_feats)) * 10}, v=0)
        assert np.all(data[0].asnumpy() == g.nodes[0].data['feat'].asnumpy())

        # Test apply edges.
        data = g.edges[:].data['feat']
        g.apply_edges(func=lambda edges: {'feat': mx.nd.ones((1, in_feats)) * 10}, edges=0)
        assert np.all(data[0].asnumpy() == g.edges[0].data['feat'].asnumpy())

        g.init_ndata('tmp', (g.number_of_nodes(), 10), 'float32')
        # Every worker must have created 'tmp' before anyone reads it.
        g._sync_barrier()
        data = g.nodes[:].data['tmp']
        # Test pull
        differs_before_pull = np.all(data[1].asnumpy() != g.nodes[1].data['preprocess'].asnumpy())
        # 'tmp' is shared between workers: nobody may pull into it until every
        # worker has inspected the pre-pull state. The barrier comes before the
        # assertion so a failure here cannot leave the other worker blocked.
        g._sync_barrier()
        assert differs_before_pull
        g.pull(1, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
        assert np.all(data[1].asnumpy() == g.nodes[1].data['preprocess'].asnumpy())

        # Test send_and_recv
        # TODO(zhengda) it seems the test fails because send_and_recv has a bug
        #in_edges = g.in_edges(v=2)
        #assert np.all(data[2].asnumpy() != g.nodes[2].data['preprocess'].asnumpy())
        #g.send_and_recv(in_edges, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
        #assert np.all(data[2].asnumpy() == g.nodes[2].data['preprocess'].asnumpy())
    finally:
        # Detach from the store even when an assertion above fails.
        g.destroy()

123
def test_compute():
    """Run one graph-store server and two ``check_compute_func`` workers.

    Assertion failures inside a child process do not propagate to the
    parent; they only surface as a non-zero exit code, which is checked here
    so the test actually fails when a worker or the server fails.
    """
    serv_p = Process(target=server_func, args=(2, 'test_graph3'))
    work_p1 = Process(target=check_compute_func, args=(0, 'test_graph3'))
    work_p2 = Process(target=check_compute_func, args=(1, 'test_graph3'))
    procs = [serv_p, work_p1, work_p2]
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()
    for proc in procs:
        assert proc.exitcode == 0, (proc.name, proc.exitcode)

if __name__ == '__main__':
    # Allow running the tests directly without a test runner.
    test_init()
    test_compute()