Unverified commit ddf8c858, authored by Chao Ma, committed by GitHub
Browse files

[KVStore] Small fix on kvstore (#1284)

* update

* str to int

* update

* update

* update

* test remove sync

* update

* update

* update

* update

* update

* update

* update

* update

* clear
parent 97e79265
...@@ -44,7 +44,6 @@ def start_client(args): ...@@ -44,7 +44,6 @@ def start_client(args):
if my_client.get_id() % args.num_worker == 0: if my_client.get_id() % args.num_worker == 0:
my_client.set_partition_book(name='entity_embed', partition_book=partition) my_client.set_partition_book(name='entity_embed', partition_book=partition)
else: else:
time.sleep(3)
my_client.set_partition_book(name='entity_embed') my_client.set_partition_book(name='entity_embed')
my_client.print() my_client.print()
...@@ -53,8 +52,9 @@ def start_client(args): ...@@ -53,8 +52,9 @@ def start_client(args):
print("send request...") print("send request...")
for i in range(4): for i in range(100):
my_client.push(name='entity_embed', id_tensor=ID[i], data_tensor=DATA[i]) for i in range(4):
my_client.push(name='entity_embed', id_tensor=ID[i], data_tensor=DATA[i])
my_client.barrier() my_client.barrier()
...@@ -64,7 +64,8 @@ def start_client(args): ...@@ -64,7 +64,8 @@ def start_client(args):
my_client.barrier() my_client.barrier()
my_client.push(name='entity_embed', id_tensor=ID[my_client.get_machine_id()], data_tensor=mx.nd.array([[0.,0.,0.],[0.,0.,0.]])) for i in range(100):
my_client.push(name='entity_embed', id_tensor=ID[my_client.get_machine_id()], data_tensor=mx.nd.array([[0.,0.,0.],[0.,0.,0.]]))
my_client.barrier() my_client.barrier()
...@@ -72,6 +73,7 @@ def start_client(args): ...@@ -72,6 +73,7 @@ def start_client(args):
res = my_client.pull(name='entity_embed', id_tensor=mx.nd.array([0,1,2,3,4,5,6,7], dtype='int64')) res = my_client.pull(name='entity_embed', id_tensor=mx.nd.array([0,1,2,3,4,5,6,7], dtype='int64'))
print(res) print(res)
if my_client.get_id() == 0:
my_client.shut_down() my_client.shut_down()
......
0 172.31.6.94 30050 2 172.31.5.143 30050 2
1 172.31.4.10 30050 2 172.31.2.169 30050 2
2 172.31.11.99 30050 2 172.31.8.6 30050 2
3 172.31.2.252 30050 2 172.31.7.129 30050 2
\ No newline at end of file \ No newline at end of file
...@@ -43,7 +43,6 @@ def start_server(args): ...@@ -43,7 +43,6 @@ def start_server(args):
my_server.set_global2local(name='entity_embed', global2local=g2l[my_server.get_machine_id()]) my_server.set_global2local(name='entity_embed', global2local=g2l[my_server.get_machine_id()])
my_server.init_data(name='entity_embed', data_tensor=data[my_server.get_machine_id()]) my_server.init_data(name='entity_embed', data_tensor=data[my_server.get_machine_id()])
else: else:
time.sleep(3)
my_server.set_global2local(name='entity_embed') my_server.set_global2local(name='entity_embed')
my_server.init_data(name='entity_embed') my_server.init_data(name='entity_embed')
......
...@@ -44,7 +44,6 @@ def start_client(args): ...@@ -44,7 +44,6 @@ def start_client(args):
if my_client.get_id() % args.num_worker == 0: if my_client.get_id() % args.num_worker == 0:
my_client.set_partition_book(name='entity_embed', partition_book=partition) my_client.set_partition_book(name='entity_embed', partition_book=partition)
else: else:
time.sleep(3)
my_client.set_partition_book(name='entity_embed') my_client.set_partition_book(name='entity_embed')
my_client.print() my_client.print()
...@@ -53,8 +52,9 @@ def start_client(args): ...@@ -53,8 +52,9 @@ def start_client(args):
print("send request...") print("send request...")
for i in range(4): for i in range(100):
my_client.push(name='entity_embed', id_tensor=ID[i], data_tensor=DATA[i]) for i in range(4):
my_client.push(name='entity_embed', id_tensor=ID[i], data_tensor=DATA[i])
my_client.barrier() my_client.barrier()
...@@ -64,7 +64,8 @@ def start_client(args): ...@@ -64,7 +64,8 @@ def start_client(args):
my_client.barrier() my_client.barrier()
my_client.push(name='entity_embed', id_tensor=ID[my_client.get_machine_id()], data_tensor=th.tensor([[0.,0.,0.],[0.,0.,0.]])) for i in range(100):
my_client.push(name='entity_embed', id_tensor=ID[my_client.get_machine_id()], data_tensor=th.tensor([[0.,0.,0.],[0.,0.,0.]]))
my_client.barrier() my_client.barrier()
...@@ -72,6 +73,7 @@ def start_client(args): ...@@ -72,6 +73,7 @@ def start_client(args):
res = my_client.pull(name='entity_embed', id_tensor=th.tensor([0,1,2,3,4,5,6,7])) res = my_client.pull(name='entity_embed', id_tensor=th.tensor([0,1,2,3,4,5,6,7]))
print(res) print(res)
if my_client.get_id() == 0:
my_client.shut_down() my_client.shut_down()
......
0 172.31.6.94 30050 2 172.31.5.143 30050 2
1 172.31.4.10 30050 2 172.31.2.169 30050 2
2 172.31.11.99 30050 2 172.31.8.6 30050 2
3 172.31.2.252 30050 2 172.31.7.129 30050 2
\ No newline at end of file \ No newline at end of file
...@@ -43,7 +43,6 @@ def start_server(args): ...@@ -43,7 +43,6 @@ def start_server(args):
my_server.set_global2local(name='entity_embed', global2local=g2l[my_server.get_machine_id()]) my_server.set_global2local(name='entity_embed', global2local=g2l[my_server.get_machine_id()])
my_server.init_data(name='entity_embed', data_tensor=data[my_server.get_machine_id()]) my_server.init_data(name='entity_embed', data_tensor=data[my_server.get_machine_id()])
else: else:
time.sleep(3)
my_server.set_global2local(name='entity_embed') my_server.set_global2local(name='entity_embed')
my_server.init_data(name='entity_embed') my_server.init_data(name='entity_embed')
......
...@@ -21,22 +21,19 @@ if os.name != 'nt': ...@@ -21,22 +21,19 @@ if os.name != 'nt':
import struct import struct
GARBAGE_COLLECTION_COUNT = 2000 # Perform garbage collection when message count is larger than 2000
def read_ip_config(filename): def read_ip_config(filename):
"""Read network configuration information of kvstore from file. """Read network configuration information of kvstore from file.
The format of configuration file should be: The format of configuration file should be:
[machine_id] [ip] [base_port] [server_count] [ip] [base_port] [server_count]
0 172.31.40.143 30050 2 172.31.40.143 30050 2
1 172.31.36.140 30050 2 172.31.36.140 30050 2
2 172.31.47.147 30050 2 172.31.47.147 30050 2
3 172.31.30.180 30050 2 172.31.30.180 30050 2
Note that, DGL KVStore supports multiple servers that can shared data with each other Note that, DGL KVStore supports multiple servers that can share data with each other
on the same machine via shared-tensor. So the server_count should be >= 1. on the same machine via shared-tensor. So the server_count should be >= 1.
Parameters Parameters
...@@ -49,16 +46,16 @@ def read_ip_config(filename): ...@@ -49,16 +46,16 @@ def read_ip_config(filename):
dict dict
server namebook. e.g., server namebook. e.g.,
[server_id]:[machine_id, ip, port] [server_id]:[machine_id, ip, port, group_count]
{0:'[0, 172.31.40.143, 30050], {0:'[0, 172.31.40.143, 30050, 2],
1:'[0, 172.31.40.143, 30051], 1:'[0, 172.31.40.143, 30051, 2],
2:'[1, 172.31.36.140, 30050], 2:'[1, 172.31.36.140, 30050, 2],
3:'[1, 172.31.36.140, 30051], 3:'[1, 172.31.36.140, 30051, 2],
4:'[2, 172.31.47.147, 30050], 4:'[2, 172.31.47.147, 30050, 2],
5:'[2, 172.31.47.147, 30051], 5:'[2, 172.31.47.147, 30051, 2],
6:'[3, 172.31.30.180, 30050], 6:'[3, 172.31.30.180, 30050, 2],
7:'[3, 172.31.30.180, 30051]} 7:'[3, 172.31.30.180, 30051, 2]}
""" """
assert len(filename) > 0, 'filename cannot be empty.' assert len(filename) > 0, 'filename cannot be empty.'
...@@ -66,12 +63,14 @@ def read_ip_config(filename): ...@@ -66,12 +63,14 @@ def read_ip_config(filename):
try: try:
server_id = 0 server_id = 0
machine_id = 0
lines = [line.rstrip('\n') for line in open(filename)] lines = [line.rstrip('\n') for line in open(filename)]
for line in lines: for line in lines:
machine_id, ip, port, server_count = line.split(' ') ip, port, server_count = line.split(' ')
for s_count in range(int(server_count)): for s_count in range(int(server_count)):
server_namebook[server_id] = [int(machine_id), ip, int(port)+s_count] server_namebook[server_id] = [int(machine_id), ip, int(port)+s_count, int(server_count)]
server_id += 1 server_id += 1
machine_id += 1
except: except:
print("Error: data format on each line should be: [machine_id] [ip] [base_port] [server_count]") print("Error: data format on each line should be: [machine_id] [ip] [base_port] [server_count]")
...@@ -86,7 +85,7 @@ class KVServer(object): ...@@ -86,7 +85,7 @@ class KVServer(object):
and _pull_handler() API to support flexible algorithms. and _pull_handler() API to support flexible algorithms.
DGL kvstore supports multiple-servers on single-machine. That means we can launch many servers on the same machine and all of DGL kvstore supports multiple-servers on single-machine. That means we can launch many servers on the same machine and all of
these servers will share the same shared-memory tensor. these servers will share the same shared-memory tensor for load-balance.
Note that, DO NOT use KVServer in multiple threads on Python because this behavior is not defined. Note that, DO NOT use KVServer in multiple threads on Python because this behavior is not defined.
...@@ -113,20 +112,20 @@ class KVServer(object): ...@@ -113,20 +112,20 @@ class KVServer(object):
Total number of client nodes. Total number of client nodes.
queue_size : int queue_size : int
Size (bytes) of kvstore message queue buffer (~20 GB by default). Size (bytes) of kvstore message queue buffer (~20 GB by default).
Note that the 20 GB is just an upper-bound and system will not allocate 20GB memory at once. Note that the 20 GB is just an upper-bound number and DGL will not allocate 20GB memory.
net_type : str net_type : str
networking type, e.g., 'socket' (default) or 'mpi' (not supported yet). networking type, e.g., 'socket' (default) or 'mpi' (not supported yet).
""" """
def __init__(self, server_id, server_namebook, num_client, queue_size=2*1024*1024*1024, net_type='socket'): def __init__(self, server_id, server_namebook, num_client, queue_size=20*1024*1024*1024, net_type='socket'):
assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
assert len(server_namebook) > 0, 'server_namebook cannot be empty.' assert len(server_namebook) > 0, 'server_namebook cannot be empty.'
assert num_client >= 0, 'num_client (%d) cannot be a negative number.' % num_client assert num_client >= 0, 'num_client (%d) cannot be a negative number.' % num_client
assert queue_size > 0, 'queue_size cannot be a negative number.' assert queue_size > 0, 'queue_size (%d) cannot be a negative number.' % queue_size
assert net_type == 'socket' or net_type == 'mpi', 'net_type (%s) can only be \'socket\' or \'mpi\'.' % net_type assert net_type == 'socket' or net_type == 'mpi', 'net_type (%s) can only be \'socket\' or \'mpi\'.' % net_type
# check if target data has been initialized # check if target data has been initialized
self._has_data = set() self._has_data = set()
# Store the tensor data with data name # Store the tensor data with specified data name
self._data_store = {} self._data_store = {}
# Used for barrier() API on KVClient # Used for barrier() API on KVClient
self._barrier_count = 0 self._barrier_count = 0
...@@ -136,15 +135,14 @@ class KVServer(object): ...@@ -136,15 +135,14 @@ class KVServer(object):
self._machine_id = server_namebook[server_id][0] self._machine_id = server_namebook[server_id][0]
self._ip = server_namebook[server_id][1] self._ip = server_namebook[server_id][1]
self._port = server_namebook[server_id][2] self._port = server_namebook[server_id][2]
self._group_count = self._get_group_count() self._group_count = server_namebook[server_id][3]
# client_namebook will be sent from remote client nodes # client_namebook will be sent from remote client nodes
self._client_namebook = {} self._client_namebook = {}
self._client_count = num_client self._client_count = num_client
# Create C communicator of sender and receiver # Create C communicator of sender and receiver
self._sender = _create_sender(net_type, queue_size) self._sender = _create_sender(net_type, queue_size)
self._receiver = _create_receiver(net_type, queue_size) self._receiver = _create_receiver(net_type, queue_size)
# A naive garbage collocetion for kvstore # Delete temp file when kvstore service is closed
self._garbage_msg = []
self._open_file_list = [] self._open_file_list = []
# record for total message count # record for total message count
self._msg_count = 0 self._msg_count = 0
...@@ -156,14 +154,14 @@ class KVServer(object): ...@@ -156,14 +154,14 @@ class KVServer(object):
# Finalize C communicator of sender and receiver # Finalize C communicator of sender and receiver
_finalize_sender(self._sender) _finalize_sender(self._sender)
_finalize_receiver(self._receiver) _finalize_receiver(self._receiver)
# Delete temp file # Delete temp file when kvstore service is closed
for file in self._open_file_list: for file in self._open_file_list:
if(os.path.exists(file)): if (os.path.exists(file)):
os.remove(file) os.remove(file)
def set_global2local(self, name, global2local=None): def set_global2local(self, name, global2local=None):
"""Set data of global ID to local ID. """Set data mapping of global ID to local ID.
Parameters Parameters
---------- ----------
...@@ -184,9 +182,16 @@ class KVServer(object): ...@@ -184,9 +182,16 @@ class KVServer(object):
dlpack = shared_data.to_dlpack() dlpack = shared_data.to_dlpack()
self._data_store[name+'-g2l-'] = F.zerocopy_from_dlpack(dlpack) self._data_store[name+'-g2l-'] = F.zerocopy_from_dlpack(dlpack)
self._data_store[name+'-g2l-'][:] = global2local[:] self._data_store[name+'-g2l-'][:] = global2local[:]
# write data information to temp file that can be read by other processes
self._write_data_shape(name+'-g2l-shape', global2local) self._write_data_shape(name+'-g2l-shape', global2local)
self._open_file_list.append(name+'-g2l-shape') self._open_file_list.append(name+'-g2l-shape')
else: # Read shared-tensor else: # Read shared-tensor
while True:
if (os.path.exists(name+'-g2l-shape')):
time.sleep(2) # wait writing finish
break
else:
time.sleep(2) # wait until the file been created
data_shape = self._read_data_shape(name+'-g2l-shape') data_shape = self._read_data_shape(name+'-g2l-shape')
shared_data = empty_shared_mem(name+'-g2l-', False, data_shape, 'int64') shared_data = empty_shared_mem(name+'-g2l-', False, data_shape, 'int64')
dlpack = shared_data.to_dlpack() dlpack = shared_data.to_dlpack()
...@@ -217,6 +222,12 @@ class KVServer(object): ...@@ -217,6 +222,12 @@ class KVServer(object):
self._write_data_shape(name+'-data-shape', data_tensor) self._write_data_shape(name+'-data-shape', data_tensor)
self._open_file_list.append(name+'-data-shape') self._open_file_list.append(name+'-data-shape')
else: # Read shared-tensor else: # Read shared-tensor
while True:
if (os.path.exists(name+'-data-shape')):
time.sleep(2) # wait writing finish
break
else:
time.sleep(2) # wait until the file been created
data_shape = self._read_data_shape(name+'-data-shape') data_shape = self._read_data_shape(name+'-data-shape')
shared_data = empty_shared_mem(name+'-data-', False, data_shape, 'float32') shared_data = empty_shared_mem(name+'-data-', False, data_shape, 'float32')
dlpack = shared_data.to_dlpack() dlpack = shared_data.to_dlpack()
...@@ -268,6 +279,7 @@ class KVServer(object): ...@@ -268,6 +279,7 @@ class KVServer(object):
""" """
return self._group_count return self._group_count
def get_message_count(self): def get_message_count(self):
"""Get total message count on current KVServer """Get total message count on current KVServer
...@@ -282,17 +294,17 @@ class KVServer(object): ...@@ -282,17 +294,17 @@ class KVServer(object):
def print(self): def print(self):
"""Print server information (Used by debug) """Print server information (Used by debug)
""" """
print("----------") print("----- KVStore Info -----")
print("server id: %d" % self.get_id()) print("server id: %d" % self.get_id())
print("data:") print("data:")
for name, data in self._data_store.items(): for name, data in self._data_store.items():
print(name) print(name)
print(data) print(data)
print("---------") print("------------------------")
def start(self): def start(self):
"""Start service of KVServer """Start service of KVServer.
The start() api performs the following things: The start() api performs the following things:
...@@ -342,9 +354,7 @@ class KVServer(object): ...@@ -342,9 +354,7 @@ class KVServer(object):
shared_tensor = '' shared_tensor = ''
for name in self._has_data: for name in self._has_data:
shared_tensor += self._serialize_shared_tensor( shared_tensor += self._serialize_shared_tensor(
name, name, F.dtype(self._data_store[name]))
F.shape(self._data_store[name]),
F.dtype(self._data_store[name]))
shared_tensor += '|' shared_tensor += '|'
msg = KVStoreMsg( msg = KVStoreMsg(
...@@ -406,24 +416,18 @@ class KVServer(object): ...@@ -406,24 +416,18 @@ class KVServer(object):
else: else:
raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value) raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value)
# garbage collection _clear_kv_msg(msg)
self._garbage_msg.append(msg)
if len(self._garbage_msg) > GARBAGE_COLLECTION_COUNT:
_clear_kv_msg(self._garbage_msg)
self._garbage_msg = []
self._msg_count += 1 self._msg_count += 1
def _serialize_shared_tensor(self, name, shape, dtype): def _serialize_shared_tensor(self, name, dtype):
"""Serialize shared tensor information. """Serialize shared tensor information.
Parameters Parameters
---------- ----------
name : str name : str
tensor name tensor name
shape : tuple of int
tensor shape
dtype : str dtype : str
data type data type
...@@ -433,13 +437,9 @@ class KVServer(object): ...@@ -433,13 +437,9 @@ class KVServer(object):
serialized string serialized string
""" """
assert len(name) > 0, 'data name cannot be empty.' assert len(name) > 0, 'data name cannot be empty.'
assert len(shape) > 0, 'data shape cannot be empty.'
str_data = name str_data = name
str_data += '/' str_data += '/'
for s in shape:
str_data += str(s)
str_data += '/'
if 'float32' in str(dtype): if 'float32' in str(dtype):
str_data += 'float32' str_data += 'float32'
elif 'int64' in str(dtype): elif 'int64' in str(dtype):
...@@ -460,6 +460,8 @@ class KVServer(object): ...@@ -460,6 +460,8 @@ class KVServer(object):
data : tensor (mx.ndarray or torch.tensor) data : tensor (mx.ndarray or torch.tensor)
data tensor data tensor
""" """
assert len(filename) > 0, 'filename cannot be empty.'
if(os.path.exists(filename)): if(os.path.exists(filename)):
os.remove(filename) os.remove(filename)
...@@ -486,38 +488,19 @@ class KVServer(object): ...@@ -486,38 +488,19 @@ class KVServer(object):
tuple tuple
data shape data shape
""" """
assert len(filename) > 0, 'filename cannot be empty.'
f = open(filename, "r") f = open(filename, "r")
str_data = f.read() str_data = f.read()
data_list = str_data.split('|') data_list = str_data.split('|')
data_shape = [] data_shape = []
for i in range(len(data_list)-1): for i in range(len(data_list)-1):
data_shape.append(int(data_list[i])) data_shape.append(int(data_list[i]))
f.close() f.close()
return data_shape return data_shape
def _get_group_count(self):
"""Get count of backup server
Return
------
int
count of backup server
"""
group_count = 0
pre_id = 0
for ID, data in self._server_namebook.items():
machine_id = data[0]
if machine_id != pre_id:
break
group_count += 1
pre_id = machine_id
return group_count
def _push_handler(self, name, ID, data, target): def _push_handler(self, name, ID, data, target):
"""Default handler for PUSH message. """Default handler for PUSH message.
...@@ -561,7 +544,7 @@ class KVServer(object): ...@@ -561,7 +544,7 @@ class KVServer(object):
class KVClient(object): class KVClient(object):
"""KVClient is used to push/pull tensors to/from KVServer. If the server node and client node are on the """KVClient is used to push/pull tensors to/from KVServer. If the server node and client node are on the
same machine, they can communicate with each other using shared-memory tensor, instead of TCP/IP connections. same machine, they can communicate with each other using local shared-memory tensor, instead of TCP/IP connections.
Note that, DO NOT use KVClient in multiple threads on Python because this behavior is not defined. Note that, DO NOT use KVClient in multiple threads on Python because this behavior is not defined.
...@@ -571,25 +554,25 @@ class KVClient(object): ...@@ -571,25 +554,25 @@ class KVClient(object):
---------- ----------
server_namebook: dict server_namebook: dict
IP address namebook of KVServer, where key is the KVServer's ID IP address namebook of KVServer, where key is the KVServer's ID
(start from 0) and value is the server's machine_id, IP address and port, e.g., (start from 0) and value is the server's machine_id, IP address and port, and group_count, e.g.,
{0:'[0, 172.31.40.143, 30050], {0:'[0, 172.31.40.143, 30050, 2],
1:'[0, 172.31.40.143, 30051], 1:'[0, 172.31.40.143, 30051, 2],
2:'[1, 172.31.36.140, 30050], 2:'[1, 172.31.36.140, 30050, 2],
3:'[1, 172.31.36.140, 30051], 3:'[1, 172.31.36.140, 30051, 2],
4:'[2, 172.31.47.147, 30050], 4:'[2, 172.31.47.147, 30050, 2],
5:'[2, 172.31.47.147, 30051], 5:'[2, 172.31.47.147, 30051, 2],
6:'[3, 172.31.30.180, 30050], 6:'[3, 172.31.30.180, 30050, 2],
7:'[3, 172.31.30.180, 30051]} 7:'[3, 172.31.30.180, 30051, 2]}
queue_size : int queue_size : int
Size (bytes) of kvstore message queue buffer (~20 GB by default). Size (bytes) of kvstore message queue buffer (~20 GB by default).
net_type : str net_type : str
networking type, e.g., 'socket' (default) or 'mpi'. networking type, e.g., 'socket' (default) or 'mpi'.
""" """
def __init__(self, server_namebook, queue_size=2*1024*1024*1024, net_type='socket'): def __init__(self, server_namebook, queue_size=20*1024*1024*1024, net_type='socket'):
assert len(server_namebook) > 0, 'server_namebook cannot be empty.' assert len(server_namebook) > 0, 'server_namebook cannot be empty.'
assert queue_size > 0, 'queue_size cannot be a negative number.' assert queue_size > 0, 'queue_size (%d) cannot be a negative number.' % queue_size
assert net_type == 'socket' or net_type == 'mpi', 'net_type (%s) can only be \'socket\' or \'mpi\'.' % net_type assert net_type == 'socket' or net_type == 'mpi', 'net_type (%s) can only be \'socket\' or \'mpi\'.' % net_type
# check if target data has been initialized # check if target data has been initialized
...@@ -599,17 +582,18 @@ class KVClient(object): ...@@ -599,17 +582,18 @@ class KVClient(object):
# Server information # Server information
self._server_namebook = server_namebook self._server_namebook = server_namebook
self._server_count = len(server_namebook) self._server_count = len(server_namebook)
self._group_count = self._get_group_count() self._group_count = server_namebook[0][3]
# client ID will be assigned by server after connecting to server # client ID will be assigned by server after connecting to server
self._client_id = -1 self._client_id = -1
# Get local machine id via server_namebook # Get local machine id via server_namebook
self._machine_id = self._get_machine_id() self._machine_id = self._get_local_machine_id()
# create C communicator of sender and receiver # create C communicator of sender and receiver
self._sender = _create_sender(net_type, queue_size) self._sender = _create_sender(net_type, queue_size)
self._receiver = _create_receiver(net_type, queue_size) self._receiver = _create_receiver(net_type, queue_size)
# A naive garbage collocetion for kvstore # Delete temp file when kvstore service is closed
self._garbage_msg = []
self._open_file_list = [] self._open_file_list = []
# Garbage collection
self._garbage_msg = []
# Used load-balance # Used load-balance
random.seed(time.time()) random.seed(time.time())
...@@ -620,14 +604,14 @@ class KVClient(object): ...@@ -620,14 +604,14 @@ class KVClient(object):
# finalize C communicator of sender and receiver # finalize C communicator of sender and receiver
_finalize_sender(self._sender) _finalize_sender(self._sender)
_finalize_receiver(self._receiver) _finalize_receiver(self._receiver)
# Delete temp file # Delete temp file when kvstore service is closed
for file in self._open_file_list: for file in self._open_file_list:
if(os.path.exists(file)): if(os.path.exists(file)):
os.remove(file) os.remove(file)
def set_partition_book(self, name, partition_book=None): def set_partition_book(self, name, partition_book=None):
"""Partition book contains the mapping of global ID to machine ID. """Partition book contains the data mapping of global ID to machine ID.
Parameters Parameters
---------- ----------
...@@ -636,7 +620,7 @@ class KVClient(object): ...@@ -636,7 +620,7 @@ class KVClient(object):
partition_book : list or tensor (mx.ndarray or torch.tensor) partition_book : list or tensor (mx.ndarray or torch.tensor)
Mapping global ID to target machine ID. Mapping global ID to target machine ID.
Note that, if the partition_book is None KVClient will read shared-tensor. Note that, if the partition_book is None KVClient will read shared-tensor by name.
""" """
assert len(name) > 0, 'name connot be empty.' assert len(name) > 0, 'name connot be empty.'
...@@ -650,6 +634,12 @@ class KVClient(object): ...@@ -650,6 +634,12 @@ class KVClient(object):
self._write_data_shape(name+'-part-shape', partition_book) self._write_data_shape(name+'-part-shape', partition_book)
self._open_file_list.append(name+'-part-shape') self._open_file_list.append(name+'-part-shape')
else: # Read shared-tensor else: # Read shared-tensor
while True:
if (os.path.exists(name+'-part-shape')):
time.sleep(2) # wait writing finish
break
else:
time.sleep(2) # wait until the file been created
data_shape = self._read_data_shape(name+'-part-shape') data_shape = self._read_data_shape(name+'-part-shape')
shared_data = empty_shared_mem(name+'-part-', False, data_shape, 'int64') shared_data = empty_shared_mem(name+'-part-', False, data_shape, 'int64')
dlpack = shared_data.to_dlpack() dlpack = shared_data.to_dlpack()
...@@ -704,7 +694,14 @@ class KVClient(object): ...@@ -704,7 +694,14 @@ class KVClient(object):
data_str = msg.name.split('|') data_str = msg.name.split('|')
for data in data_str: for data in data_str:
if data != '': if data != '':
tensor_name, shape, dtype = self._deserialize_shared_tensor(data) tensor_name, dtype = self._deserialize_shared_tensor(data)
while True:
if (os.path.exists(tensor_name+'shape')):
time.sleep(2) # wait writing finish
break
else:
time.sleep(2) # wait until the file been created
shape = self._read_data_shape(tensor_name+'shape')
shared_data = empty_shared_mem(tensor_name, False, shape, dtype) shared_data = empty_shared_mem(tensor_name, False, shape, dtype)
dlpack = shared_data.to_dlpack() dlpack = shared_data.to_dlpack()
self._data_store[tensor_name] = F.zerocopy_from_dlpack(dlpack) self._data_store[tensor_name] = F.zerocopy_from_dlpack(dlpack)
...@@ -716,13 +713,13 @@ class KVClient(object): ...@@ -716,13 +713,13 @@ class KVClient(object):
def print(self): def print(self):
"""Print client information (Used by debug) """Print client information (Used by debug)
""" """
print("----------") print("----- KVClient Info -----")
print("client id: %d" % self.get_id()) print("client id: %d" % self.get_id())
print("data:") print("data:")
for name, data in self._data_store.items(): for name, data in self._data_store.items():
print(name) print(name)
print(data) print(data)
print("----------") print("-------------------------")
def get_id(self): def get_id(self):
...@@ -794,6 +791,8 @@ class KVClient(object): ...@@ -794,6 +791,8 @@ class KVClient(object):
partial_id = id_tensor[start:end] partial_id = id_tensor[start:end]
partial_data = data_tensor[start:end] partial_data = data_tensor[start:end]
if machine[idx] == self._machine_id: # local push if machine[idx] == self._machine_id: # local push
# Note that DO NOT push local data right now because we can overlap
# communication-local_push here
if (name+'-g2l-' in self._has_data) == True: if (name+'-g2l-' in self._has_data) == True:
local_id = self._data_store[name+'-g2l-'][partial_id] local_id = self._data_store[name+'-g2l-'][partial_id]
else: else:
...@@ -813,7 +812,7 @@ class KVClient(object): ...@@ -813,7 +812,7 @@ class KVClient(object):
start += count[idx] start += count[idx]
if local_id is not None: if local_id is not None: # local push
self._push_handler(name+'-data-', local_id, local_data, self._data_store) self._push_handler(name+'-data-', local_id, local_data, self._data_store)
...@@ -835,9 +834,9 @@ class KVClient(object): ...@@ -835,9 +834,9 @@ class KVClient(object):
assert len(name) > 0, 'name cannot be empty.' assert len(name) > 0, 'name cannot be empty.'
assert F.ndim(id_tensor) == 1, 'ID must be a vector.' assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
if len(self._garbage_msg) > GARBAGE_COLLECTION_COUNT: for msg in self._garbage_msg:
_clear_kv_msg(self._garbage_msg) _clear_kv_msg(msg)
self._garbage_msg = [] self._garbage_msg = []
# partition data # partition data
machine_id = self._data_store[name+'-part-'][id_tensor] machine_id = self._data_store[name+'-part-'][id_tensor]
...@@ -850,13 +849,14 @@ class KVClient(object): ...@@ -850,13 +849,14 @@ class KVClient(object):
start = 0 start = 0
pull_count = 0 pull_count = 0
local_id = None local_id = None
for idx in range(len(machine)): for idx in range(len(machine)):
end = start + count[idx] end = start + count[idx]
if start == end: # No data for target machine if start == end: # No data for target machine
continue continue
partial_id = id_tensor[start:end] partial_id = id_tensor[start:end]
if machine[idx] == self._machine_id: # local pull if machine[idx] == self._machine_id: # local pull
# Note that DO NOT pull local data right now because we can overlap
# communication-local_pull here
if (name+'-g2l-' in self._has_data) == True: if (name+'-g2l-' in self._has_data) == True:
local_id = self._data_store[name+'-g2l-'][partial_id] local_id = self._data_store[name+'-g2l-'][partial_id]
else: else:
...@@ -869,6 +869,7 @@ class KVClient(object): ...@@ -869,6 +869,7 @@ class KVClient(object):
id=partial_id, id=partial_id,
data=None, data=None,
c_ptr=None) c_ptr=None)
# randomly select a server node in target machine for load-balance
s_id = random.randint(machine[idx]*self._group_count, (machine[idx]+1)*self._group_count-1) s_id = random.randint(machine[idx]*self._group_count, (machine[idx]+1)*self._group_count-1)
_send_kv_msg(self._sender, msg, s_id) _send_kv_msg(self._sender, msg, s_id)
pull_count += 1 pull_count += 1
...@@ -876,8 +877,7 @@ class KVClient(object): ...@@ -876,8 +877,7 @@ class KVClient(object):
start += count[idx] start += count[idx]
msg_list = [] msg_list = []
if local_id is not None: # local pull
if local_id is not None:
local_data = self._pull_handler(name+'-data-', local_id, self._data_store) local_data = self._pull_handler(name+'-data-', local_id, self._data_store)
s_id = random.randint(self._machine_id*self._group_count, (self._machine_id+1)*self._group_count-1) s_id = random.randint(self._machine_id*self._group_count, (self._machine_id+1)*self._group_count-1)
local_msg = KVStoreMsg( local_msg = KVStoreMsg(
...@@ -896,7 +896,7 @@ class KVClient(object): ...@@ -896,7 +896,7 @@ class KVClient(object):
msg_list.append(remote_msg) msg_list.append(remote_msg)
self._garbage_msg.append(remote_msg) self._garbage_msg.append(remote_msg)
# sort msg by server id # sort msg by server id and merge tensor together
msg_list.sort(key=self._takeId) msg_list.sort(key=self._takeId)
data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0) data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0)
...@@ -967,27 +967,7 @@ class KVClient(object): ...@@ -967,27 +967,7 @@ class KVClient(object):
return IP + ':' + str(port) return IP + ':' + str(port)
def _get_group_count(self): def _get_local_machine_id(self):
"""Get count of backup server
Return
------
int
count of backup server
"""
group_count = 0
pre_id = 0
for ID, data in self._server_namebook.items():
machine_id = data[0]
if machine_id != pre_id:
break
group_count += 1
pre_id = machine_id
return group_count
def _get_machine_id(self):
"""Get local machine ID from server_namebook """Get local machine ID from server_namebook
Return Return
...@@ -1035,20 +1015,14 @@ class KVClient(object): ...@@ -1035,20 +1015,14 @@ class KVClient(object):
------- -------
str str
tensor name tensor name
tuple of int
tensor shape
str str
data type data type
""" """
data_list = data.split('/') data_list = data.split('/')
tensor_name = data_list[0] tensor_name = data_list[0]
data_type = data_list[-1] data_type = data_list[-1]
tensor_shape = []
for i in range(1, len(data_list)-1):
tensor_shape.append(int(data_list[i]))
tensor_shape = tuple(tensor_shape)
return tensor_name, tensor_shape, data_type return tensor_name, data_type
def _write_data_shape(self, filename, data): def _write_data_shape(self, filename, data):
...@@ -1061,6 +1035,8 @@ class KVClient(object): ...@@ -1061,6 +1035,8 @@ class KVClient(object):
data : tensor (mx.ndarray or torch.tensor) data : tensor (mx.ndarray or torch.tensor)
data tensor data tensor
""" """
assert len(filename) > 0, 'filename cannot be empty.'
if(os.path.exists(filename)): if(os.path.exists(filename)):
os.remove(filename) os.remove(filename)
...@@ -1087,6 +1063,8 @@ class KVClient(object): ...@@ -1087,6 +1063,8 @@ class KVClient(object):
tuple tuple
data shape data shape
""" """
assert len(filename) > 0, 'filename cannot be empty.'
f = open(filename, "r") f = open(filename, "r")
str_data = f.read() str_data = f.read()
data_list = str_data.split('|') data_list = str_data.split('|')
......
...@@ -321,11 +321,9 @@ def _recv_kv_msg(receiver): ...@@ -321,11 +321,9 @@ def _recv_kv_msg(receiver):
raise RuntimeError('Unknown message type: %d' % msg_type.value) raise RuntimeError('Unknown message type: %d' % msg_type.value)
def _clear_kv_msg(garbage_msg): def _clear_kv_msg(msg):
"""Clear data of kvstore message """Clear data of kvstore message
""" """
F.sync() F.sync()
for msg in garbage_msg: if msg.c_ptr is not None:
if msg.c_ptr is not None: _CAPI_DeleteKVMsg(msg.c_ptr)
_CAPI_DeleteKVMsg(msg.c_ptr)
garbage_msg = []
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment