Unverified commit 9eb0efcf, authored by Chao Ma and committed by GitHub

[KVStore] Test multi-server and multi-client (#1732)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
parent 2c88f7c0
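For context, here is a minimal sketch of the multi-server / multi-client launch pattern this commit's test exercises: several server processes are started first, then an equal number of client processes connect to them. The helper names run_server and run_client are illustrative placeholders standing in for the test's start_server and start_client; only the multiprocessing calls mirror the test shown in the diff below.

import multiprocessing as mp
import time

NUM_SERVERS = 10   # matches the server count written to kv_ip_config.txt in the test
NUM_CLIENTS = 10

def run_server(server_id, num_clients):
    # placeholder: the test builds a KVServer here and calls dgl.distributed.start_server()
    pass

def run_client(num_clients):
    # placeholder: the test connects, builds a KVClient, and runs push/pull checks here
    pass

if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    servers = [ctx.Process(target=run_server, args=(i, NUM_CLIENTS)) for i in range(NUM_SERVERS)]
    for p in servers:
        p.start()
    time.sleep(2)  # let the servers come up before any client tries to connect
    clients = [ctx.Process(target=run_client, args=(NUM_CLIENTS,)) for _ in range(NUM_CLIENTS)]
    for p in clients:
        p.start()
    for p in clients:
        p.join()
    for p in servers:
        p.join()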
@@ -792,6 +792,7 @@ class KVClient(object):
         func : callable
             The function to be called.
         """
+        self.barrier()
         if self._client_id == 0:
             request = RegisterPushHandlerRequest(name, func)
             # send request to all the server nodes
@@ -824,6 +825,7 @@ class KVClient(object):
         func : callable
             The function to be called.
         """
+        self.barrier()
         if self._client_id == 0:
             request = RegisterPullHandlerRequest(name, func)
             # send request to all the server nodes
@@ -859,6 +861,7 @@ class KVClient(object):
         assert len(shape) > 0, 'shape cannot be empty'
         assert policy_str in ('edge', 'node'), 'policy_str must be \'edge\' or \'node\'.'
         assert name not in self._data_name_list, 'data name: %s already exists.' % name
+        self.barrier()
         shape = list(shape)
         if self._client_id == 0:
             for machine_id in range(self._machine_count):
@@ -906,6 +909,7 @@ class KVClient(object):
         self._full_data_shape[name] = tuple(shape)
         self._pull_handlers[name] = default_pull_handler
         self._push_handlers[name] = default_push_handler
+        self.barrier()

     def map_shared_data(self, partition_book):
         """Mapping shared-memory tensor from server to client.
@@ -916,6 +920,7 @@ class KVClient(object):
             Store the partition information
         """
         # Get shared data from server side
+        self.barrier()
         request = GetSharedDataRequest(GET_SHARED_MSG)
         rpc.send_request(self._main_server_id, request)
         response = rpc.recv_response()
@@ -957,6 +962,7 @@ class KVClient(object):
             response = rpc.recv_response()
             assert response.msg == SEND_META_TO_BACKUP_MSG
             self._data_name_list.add(name)
+        self.barrier()

     def data_name_list(self):
         """Get all the data name"""
...
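The barriers added above turn init_data, register_push_handler, register_pull_handler, and map_shared_data into collective calls: every client reaches the barrier before the request is issued (for init_data and the two register calls, only client 0 then sends the request to the servers). A rough client-side usage sketch, assuming the KVClient API exactly as it appears in this diff; the data name 'demo_data', the shape, and the surrounding setup (gpb, init_zero_func, the backend module F) are illustrative and taken for granted from the test file:

# run the same sequence in every client process; the internal barriers keep the clients in step
dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt')
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt')

def add_push(target, name, id_tensor, data_tensor):
    # additive user-defined push handler, same shape as the one used in the test
    target[name][id_tensor] += data_tensor

kvclient.init_data(name='demo_data',          # illustrative name
                   shape=(100, 2),            # illustrative shape
                   dtype=F.float32,           # F is the backend module used by the test
                   policy_str='node',
                   partition_book=gpb,        # partition book created elsewhere in the test
                   init_func=init_zero_func)
kvclient.register_push_handler('demo_data', add_push)
kvclient.map_shared_data(partition_book=gpb)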
@@ -88,7 +88,10 @@ def init_zero_func(shape, dtype):
     return F.zeros(shape, dtype, F.cpu())

 def udf_push(target, name, id_tensor, data_tensor):
     target[name][id_tensor] = data_tensor * data_tensor

+def add_push(target, name, id_tensor, data_tensor):
+    target[name][id_tensor] += data_tensor
+
 @unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
 def test_partition_policy():
@@ -107,30 +110,36 @@ def test_partition_policy():
     assert node_policy.get_data_size() == len(node_map)
     assert edge_policy.get_data_size() == len(edge_map)

-def start_server():
+def start_server(server_id, num_clients):
     # Init kvserver
-    kvserver = dgl.distributed.KVServer(server_id=0,
+    kvserver = dgl.distributed.KVServer(server_id=server_id,
                                         ip_config='kv_ip_config.txt',
-                                        num_clients=1)
+                                        num_clients=num_clients)
     kvserver.add_part_policy(node_policy)
     kvserver.add_part_policy(edge_policy)
-    kvserver.init_data('data_0', 'node', data_0)
-    kvserver.init_data('data_0_1', 'node', data_0_1)
-    kvserver.init_data('data_0_2', 'node', data_0_2)
-    kvserver.init_data('data_0_3', 'node', data_0_3)
+    if kvserver.is_backup_server():
+        kvserver.init_data('data_0', 'node')
+        kvserver.init_data('data_0_1', 'node')
+        kvserver.init_data('data_0_2', 'node')
+        kvserver.init_data('data_0_3', 'node')
+    else:
+        kvserver.init_data('data_0', 'node', data_0)
+        kvserver.init_data('data_0_1', 'node', data_0_1)
+        kvserver.init_data('data_0_2', 'node', data_0_2)
+        kvserver.init_data('data_0_3', 'node', data_0_3)
     # start server
     server_state = dgl.distributed.ServerState(kv_store=kvserver, local_g=None, partition_book=None)
-    dgl.distributed.start_server(server_id=0,
+    dgl.distributed.start_server(server_id=server_id,
                                  ip_config='kv_ip_config.txt',
-                                 num_clients=1,
+                                 num_clients=num_clients,
                                  server_state=server_state)

-def start_client():
+def start_client(num_clients):
     # Note: connect to server first !
     dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt')
     # Init kvclient
     kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt')
-    assert dgl.distributed.get_num_client() == 1
+    assert dgl.distributed.get_num_client() == num_clients
     kvclient.init_data(name='data_1',
                        shape=F.shape(data_1),
                        dtype=F.dtype(data_1),
@@ -224,6 +233,7 @@ def start_client():
     kvclient.push(name='data_2',
                   id_tensor=id_tensor,
                   data_tensor=data_tensor)
+    kvclient.barrier()
     data_tensor = data_tensor * data_tensor
     res = kvclient.pull(name='data_0', id_tensor=id_tensor)
     assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
@@ -231,24 +241,53 @@ def start_client():
     assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
     res = kvclient.pull(name='data_2', id_tensor=id_tensor)
     assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
+    # Register new push handler
+    kvclient.init_data(name='data_3',
+                       shape=F.shape(data_2),
+                       dtype=F.dtype(data_2),
+                       policy_str='node',
+                       partition_book=gpb,
+                       init_func=init_zero_func)
+    kvclient.register_push_handler('data_3', add_push)
+    kvclient.map_shared_data(partition_book=gpb)
+    data_tensor = F.tensor([[6.,6.],[6.,6.],[6.,6.]], F.float32)
+    time.sleep(kvclient.client_id + 1)
+    print("add...")
+    kvclient.push(name='data_3',
+                  id_tensor=id_tensor,
+                  data_tensor=data_tensor)
+    kvclient.barrier()
+    res = kvclient.pull(name='data_3', id_tensor=id_tensor)
+    data_tensor = data_tensor * num_clients
+    assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
     # clean up
+    kvclient.barrier()
     dgl.distributed.shutdown_servers()
     dgl.distributed.finalize_client()
 @unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
 def test_kv_store():
+    # start 10 server and 10 client
     ip_config = open("kv_ip_config.txt", "w")
     ip_addr = get_local_usable_addr()
-    ip_config.write('%s 1\n' % ip_addr)
+    ip_config.write('%s 10\n' % ip_addr)
     ip_config.close()
     ctx = mp.get_context('spawn')
-    pserver = ctx.Process(target=start_server)
-    pclient = ctx.Process(target=start_client)
-    pserver.start()
-    time.sleep(1)
-    pclient.start()
-    pserver.join()
-    pclient.join()
+    pserver_list = []
+    pclient_list = []
+    for i in range(10):
+        pserver = ctx.Process(target=start_server, args=(i, 10))
+        pserver.start()
+        pserver_list.append(pserver)
+    time.sleep(2)
+    for i in range(10):
+        pclient = ctx.Process(target=start_client, args=(10,))
+        pclient.start()
+        pclient_list.append(pclient)
+    for i in range(10):
+        pclient_list[i].join()
+    for i in range(10):
+        pserver_list[i].join()

 if __name__ == '__main__':
     test_partition_policy()
...
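As a sanity check on the final data_3 assertion in the client code above: data_3 is initialized to zeros, each of the 10 clients pushes the same 6-valued block through the additive add_push handler, and the barrier after the pushes keeps the clients in step before the pull, so the pulled rows should equal the pushed tensor times num_clients. A tiny NumPy-only sketch of that accumulation (no DGL involved, array sizes are made up):

import numpy as np

def add_push(target, name, id_tensor, data_tensor):
    # same additive handler as in the test
    target[name][id_tensor] += data_tensor

num_clients = 10
target = {'data_3': np.zeros((6, 2), dtype=np.float32)}   # stand-in for the server-side store
id_tensor = np.array([0, 2, 4])                           # unique row ids
data_tensor = np.full((3, 2), 6.0, dtype=np.float32)

for _ in range(num_clients):   # every client pushes once
    add_push(target, 'data_3', id_tensor, data_tensor)

assert np.array_equal(target['data_3'][id_tensor], data_tensor * num_clients)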