Unverified Commit afb7f5b8 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Distributed] fix init_data in kvstore. (#1750)



* fix

* fix.

* update.

* fix

* add assert
Co-authored-by: default avatarChao Ma <mctt90@gmail.com>
parent 2f6ab43a
...@@ -404,7 +404,8 @@ class DistGraph: ...@@ -404,7 +404,8 @@ class DistGraph:
assert shape[0] == self.number_of_nodes() assert shape[0] == self.number_of_nodes()
if init_func is None: if init_func is None:
init_func = _default_init_data init_func = _default_init_data
self._client.init_data(_get_ndata_name(name), shape, dtype, 'node', self._gpb, init_func) policy = PartitionPolicy('node', self._gpb)
self._client.init_data(_get_ndata_name(name), shape, dtype, policy, init_func)
self._ndata._add(name) self._ndata._add(name)
def init_edata(self, name, shape, dtype, init_func=None): def init_edata(self, name, shape, dtype, init_func=None):
...@@ -434,7 +435,8 @@ class DistGraph: ...@@ -434,7 +435,8 @@ class DistGraph:
assert shape[0] == self.number_of_edges() assert shape[0] == self.number_of_edges()
if init_func is None: if init_func is None:
init_func = _default_init_data init_func = _default_init_data
self._client.init_data(_get_edata_name(name), shape, dtype, 'edge', self._gpb, init_func) policy = PartitionPolicy('edge', self._gpb)
self._client.init_data(_get_edata_name(name), shape, dtype, policy, init_func)
self._edata._add(name) self._edata._add(name)
@property @property
......
...@@ -156,7 +156,7 @@ class InitDataRequest(rpc.Request): ...@@ -156,7 +156,7 @@ class InitDataRequest(rpc.Request):
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
dtype = F.data_type_dict[self.dtype] dtype = F.data_type_dict[self.dtype]
if kv_store.is_backup_server() is False: if not kv_store.is_backup_server():
data_tensor = self.init_func(self.shape, dtype) data_tensor = self.init_func(self.shape, dtype)
kv_store.init_data(name=self.name, kv_store.init_data(name=self.name,
policy_str=self.policy_str, policy_str=self.policy_str,
...@@ -403,7 +403,7 @@ class GetPartShapeRequest(rpc.Request): ...@@ -403,7 +403,7 @@ class GetPartShapeRequest(rpc.Request):
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
if kv_store.data_store.__contains__(self.name) is False: if self.name not in kv_store.data_store:
raise RuntimeError("KVServer Cannot find data tensor with name: %s" % self.name) raise RuntimeError("KVServer Cannot find data tensor with name: %s" % self.name)
data_shape = F.shape(kv_store.data_store[self.name]) data_shape = F.shape(kv_store.data_store[self.name])
res = GetPartShapeResponse(data_shape) res = GetPartShapeResponse(data_shape)
...@@ -652,7 +652,7 @@ class KVServer(object): ...@@ -652,7 +652,7 @@ class KVServer(object):
read shared-memory when client invoking get_shared_data(). read shared-memory when client invoking get_shared_data().
""" """
assert len(name) > 0, 'name cannot be empty.' assert len(name) > 0, 'name cannot be empty.'
if self._data_store.__contains__(name): if name in self._data_store:
raise RuntimeError("Data %s has already exists!" % name) raise RuntimeError("Data %s has already exists!" % name)
if data_tensor is not None: # Create shared-tensor if data_tensor is not None: # Create shared-tensor
data_type = F.reverse_data_type_dict[F.dtype(data_tensor)] data_type = F.reverse_data_type_dict[F.dtype(data_tensor)]
...@@ -839,7 +839,7 @@ class KVClient(object): ...@@ -839,7 +839,7 @@ class KVClient(object):
self._pull_handlers[name] = func self._pull_handlers[name] = func
self.barrier() self.barrier()
def init_data(self, name, shape, dtype, policy_str, partition_book, init_func): def init_data(self, name, shape, dtype, part_policy, init_func):
"""Send message to kvserver to initialize new data tensor and mapping this """Send message to kvserver to initialize new data tensor and mapping this
data from server side to client side. data from server side to client side.
...@@ -851,57 +851,46 @@ class KVClient(object): ...@@ -851,57 +851,46 @@ class KVClient(object):
data shape data shape
dtype : dtype dtype : dtype
data type data type
policy_str : str part_policy : PartitionPolicy
partition-policy string, e.g., 'edge' or 'node'. partition policy.
partition_book : GraphPartitionBook or RangePartitionBook
Store the partition information
init_func : func init_func : func
UDF init function UDF init function
""" """
assert len(name) > 0, 'name cannot be empty.' assert len(name) > 0, 'name cannot be empty.'
assert len(shape) > 0, 'shape cannot be empty' assert len(shape) > 0, 'shape cannot be empty'
assert policy_str in ('edge', 'node'), 'policy_str must be \'edge\' or \'node\'.'
assert name not in self._data_name_list, 'data name: %s already exists.' % name assert name not in self._data_name_list, 'data name: %s already exists.' % name
self.barrier() self.barrier()
shape = list(shape) shape = list(shape)
if self._client_id == 0: # One of the clients in each machine will issue requests to the local server.
for machine_id in range(self._machine_count): assert rpc.get_num_client() % part_policy.partition_book.num_partitions() == 0, \
if policy_str == 'edge': '#clients ({}) is not divisable by #partitions ({})'.format(
part_dim = partition_book.get_edge_size() rpc.get_num_client(), part_policy.partition_book.num_partitions())
elif policy_str == 'node': num_clients_per_part = rpc.get_num_client() / part_policy.partition_book.num_partitions()
part_dim = partition_book.get_node_size() if self._client_id % num_clients_per_part == 0:
else: part_shape = shape.copy()
raise RuntimeError("Cannot support policy: %s" % policy_str) part_shape[0] = part_policy.get_data_size()
part_shape = shape.copy() request = InitDataRequest(name,
part_shape[0] = part_dim tuple(part_shape),
request = InitDataRequest(name, F.reverse_data_type_dict[dtype],
tuple(part_shape), part_policy.policy_str,
F.reverse_data_type_dict[dtype], init_func)
policy_str, for n in range(self._group_count):
init_func) server_id = part_policy.part_id * self._group_count + n
for n in range(self._group_count): rpc.send_request(server_id, request)
server_id = machine_id * self._group_count + n for _ in range(self._group_count):
rpc.send_request(server_id, request)
for _ in range(self._server_count):
response = rpc.recv_response() response = rpc.recv_response()
assert response.msg == INIT_MSG assert response.msg == INIT_MSG
self.barrier() self.barrier()
# Create local shared-data # Create local shared-data
if policy_str == 'edge':
local_dim = partition_book.get_edge_size()
elif policy_str == 'node':
local_dim = partition_book.get_node_size()
else:
raise RuntimeError("Cannot support policy: %s" % policy_str)
local_shape = shape.copy() local_shape = shape.copy()
local_shape[0] = local_dim local_shape[0] = part_policy.get_data_size()
if self._part_policy.__contains__(name): if name in self._part_policy:
raise RuntimeError("Policy %s has already exists!" % name) raise RuntimeError("Policy %s has already exists!" % name)
if self._data_store.__contains__(name): if name in self._data_store:
raise RuntimeError("Data %s has already exists!" % name) raise RuntimeError("Data %s has already exists!" % name)
if self._full_data_shape.__contains__(name): if name in self._full_data_shape:
raise RuntimeError("Data shape %s has already exists!" % name) raise RuntimeError("Data shape %s has already exists!" % name)
self._part_policy[name] = PartitionPolicy(policy_str, partition_book) self._part_policy[name] = part_policy
shared_data = empty_shared_mem(name+'-kvdata-', False, \ shared_data = empty_shared_mem(name+'-kvdata-', False, \
local_shape, F.reverse_data_type_dict[dtype]) local_shape, F.reverse_data_type_dict[dtype])
dlpack = shared_data.to_dlpack() dlpack = shared_data.to_dlpack()
......
...@@ -107,7 +107,7 @@ class SparseAdagrad: ...@@ -107,7 +107,7 @@ class SparseAdagrad:
policy = emb._tensor.part_policy policy = emb._tensor.part_policy
kvstore.init_data(name + "_sum", kvstore.init_data(name + "_sum",
(emb._tensor.shape[0],), emb._tensor.dtype, (emb._tensor.shape[0],), emb._tensor.dtype,
policy.policy_str, policy.partition_book, _init_state) policy, _init_state)
kvstore.register_push_handler(name, SparseAdagradUDF(self._lr)) kvstore.register_push_handler(name, SparseAdagradUDF(self._lr))
def step(self): def step(self):
......
...@@ -141,14 +141,12 @@ def start_client(num_clients): ...@@ -141,14 +141,12 @@ def start_client(num_clients):
kvclient.init_data(name='data_1', kvclient.init_data(name='data_1',
shape=F.shape(data_1), shape=F.shape(data_1),
dtype=F.dtype(data_1), dtype=F.dtype(data_1),
policy_str='edge', part_policy=edge_policy,
partition_book=gpb,
init_func=init_zero_func) init_func=init_zero_func)
kvclient.init_data(name='data_2', kvclient.init_data(name='data_2',
shape=F.shape(data_2), shape=F.shape(data_2),
dtype=F.dtype(data_2), dtype=F.dtype(data_2),
policy_str='node', part_policy=node_policy,
partition_book=gpb,
init_func=init_zero_func) init_func=init_zero_func)
kvclient.map_shared_data(partition_book=gpb) kvclient.map_shared_data(partition_book=gpb)
...@@ -243,8 +241,7 @@ def start_client(num_clients): ...@@ -243,8 +241,7 @@ def start_client(num_clients):
kvclient.init_data(name='data_3', kvclient.init_data(name='data_3',
shape=F.shape(data_2), shape=F.shape(data_2),
dtype=F.dtype(data_2), dtype=F.dtype(data_2),
policy_str='node', part_policy=node_policy,
partition_book=gpb,
init_func=init_zero_func) init_func=init_zero_func)
kvclient.register_push_handler('data_3', add_push) kvclient.register_push_handler('data_3', add_push)
kvclient.map_shared_data(partition_book=gpb) kvclient.map_shared_data(partition_book=gpb)
...@@ -265,26 +262,27 @@ def start_client(num_clients): ...@@ -265,26 +262,27 @@ def start_client(num_clients):
@unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet') @unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
def test_kv_store(): def test_kv_store():
# start 10 server and 10 client
ip_config = open("kv_ip_config.txt", "w") ip_config = open("kv_ip_config.txt", "w")
num_servers = 2
num_clients = 2
ip_addr = get_local_usable_addr() ip_addr = get_local_usable_addr()
ip_config.write('%s 10\n' % ip_addr) ip_config.write('{} {}\n'.format(ip_addr, num_servers))
ip_config.close() ip_config.close()
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
pserver_list = [] pserver_list = []
pclient_list = [] pclient_list = []
for i in range(10): for i in range(num_servers):
pserver = ctx.Process(target=start_server, args=(i, 10)) pserver = ctx.Process(target=start_server, args=(i, num_clients))
pserver.start() pserver.start()
pserver_list.append(pserver) pserver_list.append(pserver)
time.sleep(2) time.sleep(2)
for i in range(10): for i in range(num_clients):
pclient = ctx.Process(target=start_client, args=(10,)) pclient = ctx.Process(target=start_client, args=(num_clients,))
pclient.start() pclient.start()
pclient_list.append(pclient) pclient_list.append(pclient)
for i in range(10): for i in range(num_clients):
pclient_list[i].join() pclient_list[i].join()
for i in range(10): for i in range(num_servers):
pserver_list[i].join() pserver_list[i].join()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment