"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "330b6c9b4c71f0ba455767ee7c661da98003b3e8"
Commit 372203f0 authored by Da Zheng, committed by Minjie Wang

[Test] Fix tests in test_shared_mem_store. (#588)

* fix test.

* better assert.

* more asserts.

* print to stderr.

* destroy g.

* fix tests.

* add timeout in sync_barrier.

* test _sync_barrier.

* fix.

* avoid printing messages.

* fix test.

* fix test.

* fix.
parent 05548248
 import os
+import sys
 import time
 import scipy
 from xmlrpc.server import SimpleXMLRPCServer
@@ -394,7 +395,7 @@ class SharedMemoryStoreServer(object):
         def all_enter(worker_id, barrier_id):
             return self._barrier.all_enter(worker_id, barrier_id)
 
-        self.server = SimpleXMLRPCServer(("127.0.0.1", port))
+        self.server = SimpleXMLRPCServer(("127.0.0.1", port), logRequests=False)
         self.server.register_function(register, "register")
         self.server.register_function(get_graph_info, "get_graph_info")
         self.server.register_function(init_ndata, "init_ndata")
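
A note on the logRequests=False change: SimpleXMLRPCServer logs every incoming request to stderr by default, and the busy-wait barrier below polls the server in a tight loop, so the default logging floods the test output. A minimal sketch of a quiet server; the ping function and port 8000 are illustrative placeholders, not part of this commit:

from xmlrpc.server import SimpleXMLRPCServer

def ping():
    return "pong"

# SimpleXMLRPCServer prints a log line to stderr for every request by
# default; logRequests=False turns that per-request logging off.
server = SimpleXMLRPCServer(("127.0.0.1", 8000), logRequests=False)
server.register_function(ping, "ping")
# server.serve_forever()   # blocking; the store runs this in its own process
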
@@ -616,14 +617,29 @@ class SharedMemoryDGLGraph(BaseGraphStore):
         """
         return self._worker_id
 
-    def _sync_barrier(self):
+    def _sync_barrier(self, timeout=None):
+        """This is a sync barrier among all workers.
+
+        Parameters
+        ----------
+        timeout : int
+            Timeout in seconds.
+        """
         # Here I manually implement a multi-processing barrier with RPC.
         # It uses busy wait with RPC. Whenever all_enter is called, there is
         # a context switch, so it doesn't burn CPUs too badly.
+        # If timeout isn't specified, we wait forever.
+        if timeout is None:
+            timeout = sys.maxsize
         bid = self.proxy.enter_barrier(self._worker_id)
-        while not self.proxy.all_enter(self._worker_id, bid):
+        start = time.time()
+        while not self.proxy.all_enter(self._worker_id, bid) and time.time() - start < timeout:
             continue
         self.proxy.leave_barrier(self._worker_id, bid)
+        if time.time() - start >= timeout and not self.proxy.all_enter(self._worker_id, bid):
+            raise TimeoutError("Left the sync barrier because of timeout.")
 
     def init_ndata(self, ndata_name, shape, dtype, ctx=F.cpu()):
         """Create node embedding.
...
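
For reference, the timeout-guarded busy-wait barrier added above can be exercised in isolation. The sketch below is a toy reconstruction: ToyBarrierProxy is a hypothetical in-process stand-in for the store's XML-RPC proxy, while enter_barrier, all_enter, and leave_barrier mirror the RPC names in the diff.

import sys
import threading
import time

class ToyBarrierProxy:
    # Hypothetical in-process stand-in for the store's XML-RPC proxy.
    def __init__(self, num_workers):
        self.num_workers = num_workers
        self.entered = set()
        self.lock = threading.Lock()

    def enter_barrier(self, worker_id):
        with self.lock:
            self.entered.add(worker_id)
        return 0  # barrier id

    def all_enter(self, worker_id, barrier_id):
        with self.lock:
            return len(self.entered) == self.num_workers

    def leave_barrier(self, worker_id, barrier_id):
        pass

def sync_barrier(proxy, worker_id, timeout=None):
    # If timeout isn't specified, wait (practically) forever.
    if timeout is None:
        timeout = sys.maxsize
    bid = proxy.enter_barrier(worker_id)
    start = time.time()
    # Busy-wait until everyone has entered or the deadline passes. In the
    # real code each all_enter() call is an RPC round trip, which yields
    # the CPU, so the loop doesn't spin at full speed.
    while not proxy.all_enter(worker_id, bid) and time.time() - start < timeout:
        continue
    proxy.leave_barrier(worker_id, bid)
    # Only raise if the loop ended because of the deadline.
    if time.time() - start >= timeout and not proxy.all_enter(worker_id, bid):
        raise TimeoutError("left the sync barrier because of timeout.")

sync_barrier(ToyBarrierProxy(num_workers=1), worker_id=0)   # returns at once
try:
    sync_barrier(ToyBarrierProxy(num_workers=2), worker_id=0, timeout=1)
except TimeoutError:
    print("timed out: the second worker never arrived")
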
@@ -10,6 +10,7 @@ import backend as F
 import unittest
 import dgl.function as fn
 import traceback
+from numpy.testing import assert_almost_equal
 
 num_nodes = 100
 num_edges = int(num_nodes * num_nodes * 0.1)
@@ -20,55 +21,61 @@ def check_array_shared_memory(g, worker_id, arrays):
     if worker_id == 0:
         for i, arr in enumerate(arrays):
             arr[0] = i
-        g._sync_barrier()
+        g._sync_barrier(60)
     else:
-        g._sync_barrier()
+        g._sync_barrier(60)
         for i, arr in enumerate(arrays):
-            assert np.all(arr[0].asnumpy() == i)
+            assert_almost_equal(arr[0].asnumpy(), i)
 
-def _check_init_func(worker_id, graph_name):
-    time.sleep(3)
-    print("worker starts")
-
-    np.random.seed(0)
-    csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
-
+def create_graph_store(graph_name):
     for _ in range(10):
         try:
             g = dgl.contrib.graph_store.create_graph_from_store(graph_name, "shared_mem",
                                                                 port=rand_port)
-            break
-        except:
-            print("fail to connect to the graph store server.")
+            return g
+        except ConnectionError as e:
+            traceback.print_exc()
             time.sleep(1)
-
-    # Verify the graph structure loaded from the shared memory.
-    src, dst = g.all_edges()
-    coo = csr.tocoo()
-    assert F.array_equal(dst, F.tensor(coo.row))
-    assert F.array_equal(src, F.tensor(coo.col))
-    assert F.array_equal(g.nodes[0].data['feat'], F.tensor(np.arange(10), dtype=np.float32))
-    assert F.array_equal(g.edges[0].data['feat'], F.tensor(np.arange(10), dtype=np.float32))
-
-    g.init_ndata('test4', (g.number_of_nodes(), 10), 'float32')
-    g.init_edata('test4', (g.number_of_edges(), 10), 'float32')
-    g._sync_barrier()
-    check_array_shared_memory(g, worker_id, [g.nodes[:].data['test4'], g.edges[:].data['test4']])
-
-    data = g.nodes[:].data['test4']
-    g.set_n_repr({'test4': mx.nd.ones((1, 10)) * 10}, u=[0])
-    assert np.all(data[0].asnumpy() == g.nodes[0].data['test4'].asnumpy())
-
-    data = g.edges[:].data['test4']
-    g.set_e_repr({'test4': mx.nd.ones((1, 10)) * 20}, edges=[0])
-    assert np.all(data[0].asnumpy() == g.edges[0].data['test4'].asnumpy())
-
-    g.destroy()
+    return None
 
 def check_init_func(worker_id, graph_name, return_dict):
+    time.sleep(3)
+    print("worker starts")
+
+    np.random.seed(0)
+    csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
+
+    # Verify the graph structure loaded from the shared memory.
     try:
-        _check_init_func(worker_id, graph_name)
+        g = create_graph_store(graph_name)
+        if g is None:
+            return_dict[worker_id] = -1
+            return
+
+        src, dst = g.all_edges()
+        coo = csr.tocoo()
+        assert F.array_equal(dst, F.tensor(coo.row))
+        assert F.array_equal(src, F.tensor(coo.col))
+        assert F.array_equal(g.nodes[0].data['feat'], F.tensor(np.arange(10), dtype=np.float32))
+        assert F.array_equal(g.edges[0].data['feat'], F.tensor(np.arange(10), dtype=np.float32))
+
+        g.init_ndata('test4', (g.number_of_nodes(), 10), 'float32')
+        g.init_edata('test4', (g.number_of_edges(), 10), 'float32')
+        g._sync_barrier(60)
+        check_array_shared_memory(g, worker_id, [g.nodes[:].data['test4'], g.edges[:].data['test4']])
+
+        data = g.nodes[:].data['test4']
+        g.set_n_repr({'test4': mx.nd.ones((1, 10)) * 10}, u=[0])
+        assert_almost_equal(data[0].asnumpy(), np.squeeze(g.nodes[0].data['test4'].asnumpy()))
+
+        data = g.edges[:].data['test4']
+        g.set_e_repr({'test4': mx.nd.ones((1, 10)) * 20}, edges=[0])
+        assert_almost_equal(data[0].asnumpy(), np.squeeze(g.edges[0].data['test4'].asnumpy()))
+
+        g.destroy()
         return_dict[worker_id] = 0
     except Exception as e:
         return_dict[worker_id] = -1
-        print(e)
+        g.destroy()
+        print(e, file=sys.stderr)
         traceback.print_exc()
 
 def server_func(num_workers, graph_name):
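
The new create_graph_store helper above replaces a bare except and break/fall-through with a bounded retry that returns None on failure, so callers can report a failed worker instead of crashing on an unbound variable. The same pattern in generic form; connect_with_retry is a hypothetical name, not part of the commit:

import time
import traceback

def connect_with_retry(connect, attempts=10, delay=1.0):
    # Same shape as create_graph_store() above: retry a flaky connection
    # a bounded number of times instead of failing on the first error,
    # and return None rather than raising when every attempt fails.
    for _ in range(attempts):
        try:
            return connect()
        except ConnectionError:
            traceback.print_exc()
            time.sleep(delay)
    return None
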
@@ -100,58 +107,53 @@ def test_init():
         assert return_dict[worker_id] == 0, "worker %d fails" % worker_id
 
-def _check_compute_func(worker_id, graph_name):
+def check_compute_func(worker_id, graph_name, return_dict):
     time.sleep(3)
     print("worker starts")
 
-    for _ in range(10):
-        try:
-            g = dgl.contrib.graph_store.create_graph_from_store(graph_name, "shared_mem",
-                                                                port=rand_port)
-            break
-        except:
-            print("fail to connect to the graph store server.")
-            time.sleep(1)
-    g._sync_barrier()
-    in_feats = g.nodes[0].data['feat'].shape[1]
-
-    # Test update all.
-    g.update_all(fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='preprocess'))
-    adj = g.adjacency_matrix()
-    tmp = mx.nd.dot(adj, g.nodes[:].data['feat'])
-    assert np.all((g.nodes[:].data['preprocess'] == tmp).asnumpy())
-    g._sync_barrier()
-    check_array_shared_memory(g, worker_id, [g.nodes[:].data['preprocess']])
-
-    # Test apply nodes.
-    data = g.nodes[:].data['feat']
-    g.apply_nodes(func=lambda nodes: {'feat': mx.nd.ones((1, in_feats)) * 10}, v=0)
-    assert np.all(data[0].asnumpy() == g.nodes[0].data['feat'].asnumpy())
-
-    # Test apply edges.
-    data = g.edges[:].data['feat']
-    g.apply_edges(func=lambda edges: {'feat': mx.nd.ones((1, in_feats)) * 10}, edges=0)
-    assert np.all(data[0].asnumpy() == g.edges[0].data['feat'].asnumpy())
-
-    g.init_ndata('tmp', (g.number_of_nodes(), 10), 'float32')
-    data = g.nodes[:].data['tmp']
-    # Test pull
-    g.pull(1, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
-    assert np.all(data[1].asnumpy() == g.nodes[1].data['preprocess'].asnumpy())
-
-    # Test send_and_recv
-    in_edges = g.in_edges(v=2)
-    g.send_and_recv(in_edges, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
-    assert np.all(data[2].asnumpy() == g.nodes[2].data['preprocess'].asnumpy())
-
-    g.destroy()
-
-def check_compute_func(worker_id, graph_name, return_dict):
     try:
-        _check_compute_func(worker_id, graph_name)
+        g = create_graph_store(graph_name)
+        if g is None:
+            return_dict[worker_id] = -1
+            return
+
+        g._sync_barrier(60)
+        in_feats = g.nodes[0].data['feat'].shape[1]
+
+        # Test update all.
+        g.update_all(fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='preprocess'))
+        adj = g.adjacency_matrix()
+        tmp = mx.nd.dot(adj, g.nodes[:].data['feat'])
+        assert_almost_equal(g.nodes[:].data['preprocess'].asnumpy(), tmp.asnumpy())
+        g._sync_barrier(60)
+        check_array_shared_memory(g, worker_id, [g.nodes[:].data['preprocess']])
+
+        # Test apply nodes.
+        data = g.nodes[:].data['feat']
+        g.apply_nodes(func=lambda nodes: {'feat': mx.nd.ones((1, in_feats)) * 10}, v=0)
+        assert_almost_equal(data[0].asnumpy(), np.squeeze(g.nodes[0].data['feat'].asnumpy()))
+
+        # Test apply edges.
+        data = g.edges[:].data['feat']
+        g.apply_edges(func=lambda edges: {'feat': mx.nd.ones((1, in_feats)) * 10}, edges=0)
+        assert_almost_equal(data[0].asnumpy(), np.squeeze(g.edges[0].data['feat'].asnumpy()))
+
+        g.init_ndata('tmp', (g.number_of_nodes(), 10), 'float32')
+        data = g.nodes[:].data['tmp']
+        # Test pull
+        g.pull(1, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
+        assert_almost_equal(data[1].asnumpy(), np.squeeze(g.nodes[1].data['preprocess'].asnumpy()))
+
+        # Test send_and_recv
+        in_edges = g.in_edges(v=2)
+        g.send_and_recv(in_edges, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
+        assert_almost_equal(data[2].asnumpy(), np.squeeze(g.nodes[2].data['preprocess'].asnumpy()))
+
+        g.destroy()
         return_dict[worker_id] = 0
     except Exception as e:
         return_dict[worker_id] = -1
-        print(e)
+        g.destroy()
+        print(e, file=sys.stderr)
         traceback.print_exc()
 
 def test_compute():
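
The tests also switch from assert np.all(a == b) to numpy.testing.assert_almost_equal, which tolerates floating-point rounding error and, on failure, prints both arrays with a mismatch summary instead of a bare AssertionError. A small self-contained illustration, with made-up values:

import numpy as np
from numpy.testing import assert_almost_equal

acc = np.zeros(3, dtype=np.float32)
for _ in range(10):
    acc += np.float32(0.1)          # ten float32 additions of 0.1
one = np.ones(3, dtype=np.float32)

print(np.all(acc == one))           # False: the sum is ~1.0000001, not 1.0
# assert_almost_equal compares to a given number of decimal places and
# prints a readable report when the arrays actually differ.
assert_almost_equal(acc, one, decimal=6)
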
@@ -169,6 +171,52 @@ def test_compute():
     for worker_id in return_dict.keys():
         assert return_dict[worker_id] == 0, "worker %d fails" % worker_id
 
+def check_sync_barrier(worker_id, graph_name, return_dict):
+    time.sleep(3)
+    print("worker starts")
+
+    try:
+        g = create_graph_store(graph_name)
+        if g is None:
+            return_dict[worker_id] = -1
+            return
+
+        # Worker 1 exits immediately, so worker 0 waits at the barrier
+        # until its 10-second timeout expires.
+        if worker_id == 1:
+            g.destroy()
+            return_dict[worker_id] = 0
+            return
+
+        start = time.time()
+        try:
+            g._sync_barrier(10)
+        except TimeoutError as e:
+            # This bound is very loose.
+            print("timeout: " + str(abs(time.time() - start)), file=sys.stderr)
+            assert 5 < abs(time.time() - start) < 15
+        g.destroy()
+        return_dict[worker_id] = 0
+    except Exception as e:
+        return_dict[worker_id] = -1
+        g.destroy()
+        print(e, file=sys.stderr)
+        traceback.print_exc()
+
+def test_sync_barrier():
+    manager = Manager()
+    return_dict = manager.dict()
+    serv_p = Process(target=server_func, args=(2, 'test_graph4'))
+    work_p1 = Process(target=check_sync_barrier, args=(0, 'test_graph4', return_dict))
+    work_p2 = Process(target=check_sync_barrier, args=(1, 'test_graph4', return_dict))
+    serv_p.start()
+    work_p1.start()
+    work_p2.start()
+    serv_p.join()
+    work_p1.join()
+    work_p2.join()
+    for worker_id in return_dict.keys():
+        assert return_dict[worker_id] == 0, "worker %d fails" % worker_id
 
 if __name__ == '__main__':
     test_init()
+    test_sync_barrier()
     test_compute()
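
All of these tests share one pattern: an assertion failure in a child process cannot propagate to the parent, so each worker records a status code in a multiprocessing.Manager dict and the parent asserts on the collected codes after join(). A stripped-down version of that skeleton, with a trivial placeholder worker:

from multiprocessing import Manager, Process

def worker(worker_id, return_dict):
    # Placeholder for the real checks; an uncaught assertion in a child
    # process never reaches the parent, so record a status code instead.
    try:
        assert worker_id >= 0
        return_dict[worker_id] = 0      # success
    except Exception:
        return_dict[worker_id] = -1     # failure, visible to the parent

if __name__ == '__main__':
    manager = Manager()
    return_dict = manager.dict()
    workers = [Process(target=worker, args=(i, return_dict)) for i in range(2)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()
    for worker_id in return_dict.keys():
        assert return_dict[worker_id] == 0, "worker %d fails" % worker_id
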