Unverified Commit 9eb0efcf authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[KVStore] Test multi-server and multi-client (#1732)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
parent 2c88f7c0
......@@ -792,6 +792,7 @@ class KVClient(object):
func : callable
The function to be called.
"""
self.barrier()
if self._client_id == 0:
request = RegisterPushHandlerRequest(name, func)
# send request to all the server nodes
......@@ -824,6 +825,7 @@ class KVClient(object):
func : callable
The function to be called.
"""
self.barrier()
if self._client_id == 0:
request = RegisterPullHandlerRequest(name, func)
# send request to all the server nodes
......@@ -859,6 +861,7 @@ class KVClient(object):
assert len(shape) > 0, 'shape cannot be empty'
assert policy_str in ('edge', 'node'), 'policy_str must be \'edge\' or \'node\'.'
assert name not in self._data_name_list, 'data name: %s already exists.' % name
self.barrier()
shape = list(shape)
if self._client_id == 0:
for machine_id in range(self._machine_count):
......@@ -906,6 +909,7 @@ class KVClient(object):
self._full_data_shape[name] = tuple(shape)
self._pull_handlers[name] = default_pull_handler
self._push_handlers[name] = default_push_handler
self.barrier()
def map_shared_data(self, partition_book):
"""Mapping shared-memory tensor from server to client.
......@@ -916,6 +920,7 @@ class KVClient(object):
Store the partition information
"""
# Get shared data from server side
self.barrier()
request = GetSharedDataRequest(GET_SHARED_MSG)
rpc.send_request(self._main_server_id, request)
response = rpc.recv_response()
......@@ -957,6 +962,7 @@ class KVClient(object):
response = rpc.recv_response()
assert response.msg == SEND_META_TO_BACKUP_MSG
self._data_name_list.add(name)
self.barrier()
def data_name_list(self):
"""Get all the data name"""
......
......@@ -88,7 +88,10 @@ def init_zero_func(shape, dtype):
return F.zeros(shape, dtype, F.cpu())
def udf_push(target, name, id_tensor, data_tensor):
target[name][id_tensor] = data_tensor * data_tensor
target[name][id_tensor] = data_tensor * data_tensor
def add_push(target, name, id_tensor, data_tensor):
target[name][id_tensor] += data_tensor
@unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
def test_partition_policy():
......@@ -107,30 +110,36 @@ def test_partition_policy():
assert node_policy.get_data_size() == len(node_map)
assert edge_policy.get_data_size() == len(edge_map)
def start_server():
def start_server(server_id, num_clients):
# Init kvserver
kvserver = dgl.distributed.KVServer(server_id=0,
kvserver = dgl.distributed.KVServer(server_id=server_id,
ip_config='kv_ip_config.txt',
num_clients=1)
num_clients=num_clients)
kvserver.add_part_policy(node_policy)
kvserver.add_part_policy(edge_policy)
kvserver.init_data('data_0', 'node', data_0)
kvserver.init_data('data_0_1', 'node', data_0_1)
kvserver.init_data('data_0_2', 'node', data_0_2)
kvserver.init_data('data_0_3', 'node', data_0_3)
if kvserver.is_backup_server():
kvserver.init_data('data_0', 'node')
kvserver.init_data('data_0_1', 'node')
kvserver.init_data('data_0_2', 'node')
kvserver.init_data('data_0_3', 'node')
else:
kvserver.init_data('data_0', 'node', data_0)
kvserver.init_data('data_0_1', 'node', data_0_1)
kvserver.init_data('data_0_2', 'node', data_0_2)
kvserver.init_data('data_0_3', 'node', data_0_3)
# start server
server_state = dgl.distributed.ServerState(kv_store=kvserver, local_g=None, partition_book=None)
dgl.distributed.start_server(server_id=0,
dgl.distributed.start_server(server_id=server_id,
ip_config='kv_ip_config.txt',
num_clients=1,
num_clients=num_clients,
server_state=server_state)
def start_client():
def start_client(num_clients):
# Note: connect to server first !
dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt')
# Init kvclient
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt')
assert dgl.distributed.get_num_client() == 1
assert dgl.distributed.get_num_client() == num_clients
kvclient.init_data(name='data_1',
shape=F.shape(data_1),
dtype=F.dtype(data_1),
......@@ -224,6 +233,7 @@ def start_client():
kvclient.push(name='data_2',
id_tensor=id_tensor,
data_tensor=data_tensor)
kvclient.barrier()
data_tensor = data_tensor * data_tensor
res = kvclient.pull(name='data_0', id_tensor=id_tensor)
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
......@@ -231,24 +241,53 @@ def start_client():
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
res = kvclient.pull(name='data_2', id_tensor=id_tensor)
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
# Register new push handler
kvclient.init_data(name='data_3',
shape=F.shape(data_2),
dtype=F.dtype(data_2),
policy_str='node',
partition_book=gpb,
init_func=init_zero_func)
kvclient.register_push_handler('data_3', add_push)
kvclient.map_shared_data(partition_book=gpb)
data_tensor = F.tensor([[6.,6.],[6.,6.],[6.,6.]], F.float32)
time.sleep(kvclient.client_id + 1)
print("add...")
kvclient.push(name='data_3',
id_tensor=id_tensor,
data_tensor=data_tensor)
kvclient.barrier()
res = kvclient.pull(name='data_3', id_tensor=id_tensor)
data_tensor = data_tensor * num_clients
assert_array_equal(F.asnumpy(res), F.asnumpy(data_tensor))
# clean up
kvclient.barrier()
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
@unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
def test_kv_store():
# start 10 server and 10 client
ip_config = open("kv_ip_config.txt", "w")
ip_addr = get_local_usable_addr()
ip_config.write('%s 1\n' % ip_addr)
ip_config.write('%s 10\n' % ip_addr)
ip_config.close()
ctx = mp.get_context('spawn')
pserver = ctx.Process(target=start_server)
pclient = ctx.Process(target=start_client)
pserver.start()
time.sleep(1)
pclient.start()
pserver.join()
pclient.join()
pserver_list = []
pclient_list = []
for i in range(10):
pserver = ctx.Process(target=start_server, args=(i, 10))
pserver.start()
pserver_list.append(pserver)
time.sleep(2)
for i in range(10):
pclient = ctx.Process(target=start_client, args=(10,))
pclient.start()
pclient_list.append(pclient)
for i in range(10):
pclient_list[i].join()
for i in range(10):
pserver_list[i].join()
if __name__ == '__main__':
test_partition_policy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment