Unverified Commit cccde032 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[kvstore] Performance improvement for distributed kvstore (#972)

* Performance improvment for distributed kvstore

* update

* update
parent fdd0fe65
...@@ -17,46 +17,42 @@ def start_client(args): ...@@ -17,46 +17,42 @@ def start_client(args):
client.connect() client.connect()
# Initialize data on server # Initialize data on server
client.init_data(name='embed_0', shape=[10, 3], init_type='zero') client.init_data(name='embed_0', server_id=0, shape=[5, 3], init_type='zero')
client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0) client.init_data(name='embed_0', server_id=1, shape=[6, 3], init_type='zero')
client.init_data(name='embed_2', shape=[11], init_type='zero') client.init_data(name='embed_1', server_id=0, shape=[5], init_type='uniform', low=0.0, high=0.0)
client.init_data(name='embed_1', server_id=1, shape=[6], init_type='uniform', low=0.0, high=0.0)
tensor_id = mx.nd.array([0, 1, 2], dtype='int64') data_0 = mx.nd.array([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])
tensor_data = mx.nd.array([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]]) data_1 = mx.nd.array([0., 1., 2.])
for i in range(5): for i in range(5):
client.push('embed_0', tensor_id, tensor_data) client.push(name='embed_0', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_0)
client.push('embed_1', tensor_id, tensor_data) client.push(name='embed_0', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_0)
client.push('embed_2', tensor_id, mx.nd.array([2., 2., 2.])) client.push(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1)
client.push(name='embed_1', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_1)
tensor_id = mx.nd.array([6, 7, 8], dtype='int64')
for i in range(5):
client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, mx.nd.array([3., 3., 3.]))
client.barrier() client.barrier()
if client.get_id() == 0: if client.get_id() == 0:
tensor_id = mx.nd.array([0,1,2,3,4,5,6,7,8,9], dtype='int64') client.pull(name='embed_0', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
new_tensor_0 = client.pull('embed_0', tensor_id) server_id, new_tensor_0 = client.pull_wait()
tensor_id = mx.nd.array([0,1,2,3,4,5,6,7,8,9,10], dtype='int64') assert server_id == 0
new_tensor_1 = client.pull('embed_1', tensor_id) client.pull(name='embed_0', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
new_tensor_2 = client.pull('embed_2', tensor_id) server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
print("embed_0:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
client.push_all('embed_0', new_tensor_0) client.pull(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
client.push_all('embed_1', new_tensor_1) server_id, new_tensor_0 = client.pull_wait()
client.push_all('embed_2', new_tensor_2) assert server_id == 0
client.pull(name='embed_1', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
new_tensor_3 = client.pull_all('embed_0') print("embed_1:")
new_tensor_4 = client.pull_all('embed_1') print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
new_tensor_5 = client.pull_all('embed_2')
print("embed_0: ")
print(new_tensor_3)
print("embed_1: ")
print(new_tensor_4)
print("embed_2: ")
print(new_tensor_5)
# Shut-down all the servers # Shut-down all the servers
if client.get_id() == 0: if client.get_id() == 0:
......
# This is a simple pytorch client demo shows how to use DGL distributed kvstore. # This is a simple pytorch client demo shows how to use DGL distributed kvstore.
# In this demo, we initialize two embeddings on server and push/pull data to/from it. # In this demo, we initialize two embeddings on server and push/pull data to/from it.
import dgl import dgl
import torch
import time import time
import argparse import argparse
import torch as th import torch as th
...@@ -18,46 +17,42 @@ def start_client(args): ...@@ -18,46 +17,42 @@ def start_client(args):
client.connect() client.connect()
# Initialize data on server # Initialize data on server
client.init_data(name='embed_0', shape=[10, 3], init_type='zero') client.init_data(name='embed_0', server_id=0, shape=[5, 3], init_type='zero')
client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0) client.init_data(name='embed_0', server_id=1, shape=[6, 3], init_type='zero')
client.init_data(name='embed_2', shape=[11], init_type='zero') client.init_data(name='embed_1', server_id=0, shape=[5], init_type='uniform', low=0.0, high=0.0)
client.init_data(name='embed_1', server_id=1, shape=[6], init_type='uniform', low=0.0, high=0.0)
tensor_id = torch.tensor([0, 1, 2]) data_0 = th.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])
tensor_data = torch.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]]) data_1 = th.tensor([0., 1., 2.])
for i in range(5): for i in range(5):
client.push('embed_0', tensor_id, tensor_data) client.push(name='embed_0', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_0)
client.push('embed_1', tensor_id, tensor_data) client.push(name='embed_0', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_0)
client.push('embed_2', tensor_id, th.tensor([2., 2., 2.])) client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.push(name='embed_1', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_1)
tensor_id = torch.tensor([6, 7, 8])
for i in range(5):
client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, th.tensor([3., 3., 3.]))
client.barrier() client.barrier()
if client.get_id() == 0: if client.get_id() == 0:
tensor_id = torch.tensor([0,1,2,3,4,5,6,7,8,9]) client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
new_tensor_0 = client.pull('embed_0', tensor_id) server_id, new_tensor_0 = client.pull_wait()
tensor_id = torch.tensor([0,1,2,3,4,5,6,7,8,9,10]) assert server_id == 0
new_tensor_1 = client.pull('embed_1', tensor_id) client.pull(name='embed_0', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
new_tensor_2 = client.pull('embed_2', tensor_id) server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
client.push_all('embed_0', new_tensor_0)
client.push_all('embed_1', new_tensor_1)
client.push_all('embed_2', new_tensor_2)
new_tensor_3 = client.pull_all('embed_0')
new_tensor_4 = client.pull_all('embed_1')
new_tensor_5 = client.pull_all('embed_2')
print("embed_0:") print("embed_0:")
print(new_tensor_3) print(th.cat([new_tensor_0, new_tensor_1]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
client.pull(name='embed_1', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
print("embed_1:") print("embed_1:")
print(new_tensor_4) print(th.cat([new_tensor_0, new_tensor_1]))
print("embed_2:")
print(new_tensor_5)
# Shut-down all the servers # Shut-down all the servers
if client.get_id() == 0: if client.get_id() == 0:
......
...@@ -53,15 +53,10 @@ class KVServer(object): ...@@ -53,15 +53,10 @@ class KVServer(object):
"""KVServer is a lightweight key-value store service for DGL distributed training. """KVServer is a lightweight key-value store service for DGL distributed training.
In practice, developers can use KVServer to hold large-scale graph features or In practice, developers can use KVServer to hold large-scale graph features or
graph embeddings across machines in a distributed setting or storing them in one standalone graph embeddings across machines in a distributed setting. User can re-wriite _push_handler
machine with big memory capability. DGL KVServer uses a very simple range-partition scheme to and _pull_handler to support flexibale models.
partition data into different KVServer nodes. For example, if the total embedding size is 200 and
we have two KVServer nodes, the data (0~99) will be stored in kvserver_0, and the data (100~199) will
be stored in kvserver_1.
For KVServer, user can re-wriite UDF function for _push_handler and _pull_handler. Note that, DO NOT use KVServer in multiple threads!
DO NOT use KVServer in multiple threads!
Parameters Parameters
---------- ----------
...@@ -77,14 +72,14 @@ class KVServer(object): ...@@ -77,14 +72,14 @@ class KVServer(object):
server_addr : str server_addr : str
IP address of current KVServer node, e.g., '127.0.0.1:50051' IP address of current KVServer node, e.g., '127.0.0.1:50051'
net_type : str net_type : str
networking type, e.g., 'socket' (default) or 'mpi'. networking type, e.g., 'socket' (default) or 'mpi' (do not support yet).
""" """
def __init__(self, server_id, client_namebook, server_addr, net_type='socket'): def __init__(self, server_id, client_namebook, server_addr, net_type='socket'):
assert server_id >= 0, 'server_id cannot be a negative number.' assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
assert len(client_namebook) > 0, 'client_namebook cannot be empty.' assert len(client_namebook) > 0, 'client_namebook cannot be empty.'
assert len(server_addr.split(':')) == 2, 'Incorrect IP format.' assert len(server_addr.split(':')) == 2, 'Incorrect IP format: %s' % server_addr
self._is_init = set() # Contains tensor name self._is_init = set() # Contains tensor name
self._data_store = {} # Key is name string and value is tensor self._data_store = {} # Key is name (string) and value is data (tensor)
self._barrier_count = 0; self._barrier_count = 0;
self._server_id = server_id self._server_id = server_id
self._client_namebook = client_namebook self._client_namebook = client_namebook
...@@ -130,13 +125,9 @@ class KVServer(object): ...@@ -130,13 +125,9 @@ class KVServer(object):
high=row_1[1]) high=row_1[1])
self._is_init.add(msg.name) self._is_init.add(msg.name)
elif msg.type == KVMsgType.PUSH: elif msg.type == KVMsgType.PUSH:
# convert global ID to local ID self._push_handler(msg.name, msg.id, msg.data)
local_id = self._remap_id(msg.name, msg.id)
self._push_handler(msg.name, local_id, msg.data)
elif msg.type == KVMsgType.PULL: elif msg.type == KVMsgType.PULL:
# convert global ID to local ID res_tensor = self._pull_handler(msg.name, msg.id)
local_id = self._remap_id(msg.name, msg.id)
res_tensor = self._pull_handler(msg.name, local_id)
back_msg = KVStoreMsg( back_msg = KVStoreMsg(
type=KVMsgType.PULL_BACK, type=KVMsgType.PULL_BACK,
rank=self._server_id, rank=self._server_id,
...@@ -157,7 +148,7 @@ class KVServer(object): ...@@ -157,7 +148,7 @@ class KVServer(object):
_send_kv_msg(self._sender, back_msg, i) _send_kv_msg(self._sender, back_msg, i)
self._barrier_count = 0 self._barrier_count = 0
elif msg.type == KVMsgType.FINAL: elif msg.type == KVMsgType.FINAL:
print("Exit KVStore service, server ID: %d" % self.get_id()) print("Exit KVStore service, server ID: %d" % self._server_id)
break # exit loop break # exit loop
else: else:
raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value) raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value)
...@@ -204,7 +195,7 @@ class KVServer(object): ...@@ -204,7 +195,7 @@ class KVServer(object):
raise RuntimeError('Unknown initial method') raise RuntimeError('Unknown initial method')
def _push_handler(self, name, ID, data): def _push_handler(self, name, ID, data):
"""User-defined handler for PUSH message. """Default handler for PUSH message.
On default, _push_handler perform ADD operation for the tensor. On default, _push_handler perform ADD operation for the tensor.
...@@ -213,15 +204,15 @@ class KVServer(object): ...@@ -213,15 +204,15 @@ class KVServer(object):
name : str name : str
data name data name
ID : tensor (mx.ndarray or torch.tensor) ID : tensor (mx.ndarray or torch.tensor)
a vector storing the IDs that has been re-mapped to local id. a vector storing the ID list.
data : tensor (mx.ndarray or torch.tensor) data : tensor (mx.ndarray or torch.tensor)
a matrix with the same row size of id a tensor with the same row size of id
""" """
for idx in range(ID.shape[0]): # For each row for idx in range(ID.shape[0]):
self._data_store[name][ID[idx]] += data[idx] self._data_store[name][ID[idx]] += data[idx]
def _pull_handler(self, name, ID): def _pull_handler(self, name, ID):
"""User-defined handler for PULL operation. """Default handler for PULL operation.
On default, _pull_handler perform gather_row() operation for the tensor. On default, _pull_handler perform gather_row() operation for the tensor.
...@@ -235,40 +226,26 @@ class KVServer(object): ...@@ -235,40 +226,26 @@ class KVServer(object):
Return Return
------ ------
tensor tensor
a matrix with the same row size of ID a tensor with the same row size of ID
""" """
new_tensor = F.gather_row(self._data_store[name], ID) new_tensor = F.gather_row(self._data_store[name], ID)
return new_tensor return new_tensor
def _remap_id(self, name, ID):
"""Re-mapping global-ID to local-ID.
Parameters
----------
name : str
data name
ID : tensor (mx.ndarray or torch.tensor)
a vector storing the global data ID
Return
------
tensor
re-mapped lcoal ID
"""
row_size = self._data_store[name].shape[0]
return ID % row_size
class KVClient(object): class KVClient(object):
"""KVClient is used to push/pull tensors to/from KVServer on DGL trainer. """KVClient is used to push/pull tensors to/from KVServer on DGL trainer.
There are five operations supported by KVClient: There are five operations supported by KVClient:
* init_data(name, shape, init_type, low, high): initialize tensor on KVServer * init_data(name, server_id, shape, init_type, low, high):
* push(name, id, data): push sparse data to KVServer given specified IDs initialize tensor on target KVServer.
* pull(name, id): pull sparse data from KVServer given specified IDs * push(name, server_id, id_tensor, data_tensor):
* push_all(name, data): push dense data to KVServer push sparse data to KVServer given specified ID.
* pull_all(name): pull sense data from KVServer * pull(name, server_id, id_tensor):
* shut_down(): shut down all KVServer nodes pull sparse data from KVServer given specified ID.
* pull_wait():
wait scheduled pull operation finish its job.
* shut_down():
shut down all KVServer nodes.
Note that, DO NOT use KVClient in multiple threads! Note that, DO NOT use KVClient in multiple threads!
...@@ -292,10 +269,6 @@ class KVClient(object): ...@@ -292,10 +269,6 @@ class KVClient(object):
assert client_id >= 0, 'client_id (%d) cannot be a nagative number.' % client_id assert client_id >= 0, 'client_id (%d) cannot be a nagative number.' % client_id
assert len(server_namebook) > 0, 'server_namebook cannot be empty.' assert len(server_namebook) > 0, 'server_namebook cannot be empty.'
assert len(client_addr.split(':')) == 2, 'Incorrect IP format: %s' % client_addr assert len(client_addr.split(':')) == 2, 'Incorrect IP format: %s' % client_addr
# self._data_size is a key-value store where the key is data name
# and value is the size of tensor. It is used to partition data into
# different KVServer nodes.
self._data_size = {}
self._client_id = client_id self._client_id = client_id
self._server_namebook = server_namebook self._server_namebook = server_namebook
self._server_count = len(server_namebook) self._server_count = len(server_namebook)
...@@ -319,13 +292,20 @@ class KVClient(object): ...@@ -319,13 +292,20 @@ class KVClient(object):
client_ip, client_port = self._addr.split(':') client_ip, client_port = self._addr.split(':')
_receiver_wait(self._receiver, client_ip, int(client_port), self._server_count) _receiver_wait(self._receiver, client_ip, int(client_port), self._server_count)
def init_data(self, name, shape, init_type='zero', low=0.0, high=0.0): def init_data(self, name, server_id, shape, init_type='zero', low=0.0, high=0.0):
"""Initialize kvstore tensor """Initialize kvstore tensor
we hack the msg format here: msg.id store the shape of target tensor,
msg.data has two row, and the first row is the init_type,
[0, 0] means 'zero' and [1,1] means 'uniform'.
The second row is the min & max threshold.
Parameters Parameters
---------- ----------
name : str name : str
data name data name
server_id : int
target server id
shape : list of int shape : list of int
shape of tensor shape of tensor
init_type : str init_type : str
...@@ -335,158 +315,86 @@ class KVClient(object): ...@@ -335,158 +315,86 @@ class KVClient(object):
high : float high : float
max threshold, if use 'uniform' max threshold, if use 'uniform'
""" """
self._data_size[name] = shape[0] tensor_shape = F.tensor(shape)
count = math.ceil(shape[0] / self._server_count)
# We hack the msg format here
init_type = 0.0 if init_type == 'zero' else 1.0 init_type = 0.0 if init_type == 'zero' else 1.0
threshold = F.tensor([[init_type, init_type], [low, high]]) threshold = F.tensor([[init_type, init_type], [low, high]])
# partition shape on server msg = KVStoreMsg(
for server_id in range(self._server_count): type=KVMsgType.INIT,
par_shape = shape.copy() rank=self._client_id,
if shape[0] - server_id*count >= count: name=name,
par_shape[0] = count id=tensor_shape,
else: data=threshold)
par_shape[0] = shape[0] - server_id*count _send_kv_msg(self._sender, msg, server_id)
tensor_shape = F.tensor(par_shape)
msg = KVStoreMsg(
type=KVMsgType.INIT,
rank=self._client_id,
name=name,
id=tensor_shape,
data=threshold)
_send_kv_msg(self._sender, msg, server_id)
def push(self, name, ID, data): def push(self, name, server_id, id_tensor, data_tensor):
"""Push sparse message to KVServer """Push sparse message to target KVServer.
The push() API will partition message into different Note that push() is an async operation that will return immediately after calling.
KVServer nodes automatically.
Note that we assume the row Ids in ID is in the ascending order.
Parameters Parameters
---------- ----------
name : str name : str
data name data name
ID : tensor (mx.ndarray or torch.tensor) server_id : int
a vector storing the global IDs target server id
data : tensor (mx.ndarray or torch.tensor) id_tensor : tensor (mx.ndarray or torch.tensor)
a vector storing the ID list
data_tensor : tensor (mx.ndarray or torch.tensor)
a tensor with the same row size of id a tensor with the same row size of id
""" """
assert F.ndim(ID) == 1, 'ID must be a vector.' assert server_id >= 0, 'server_id (%d) cannot be a negative number' % server_id
assert F.shape(ID)[0] == F.shape(data)[0], 'The data must has the same row size with ID.' assert server_id < self._server_count, 'server_id (%d) must be smaller than server_count' % server_id
group_size = [0] * self._server_count assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
numpy_id = F.asnumpy(ID) assert F.shape(id_tensor)[0] == F.shape(data_tensor)[0], 'The data must has the same row size with ID.'
count = math.ceil(self._data_size[name] / self._server_count) msg = KVStoreMsg(
server_id = numpy_id / count type=KVMsgType.PUSH,
id_list, id_count = np.unique(server_id, return_counts=True) rank=self._client_id,
for idx in range(len(id_list)): name=name,
group_size[int(id_list[idx])] += id_count[idx] id=id_tensor,
min_idx = 0 data=data_tensor)
max_idx = 0 _send_kv_msg(self._sender, msg, server_id)
for idx in range(self._server_count):
if group_size[idx] == 0:
continue
max_idx += group_size[idx]
range_id = ID[min_idx:max_idx]
range_data = data[min_idx:max_idx]
min_idx = max_idx
msg = KVStoreMsg(
type=KVMsgType.PUSH,
rank=self._client_id,
name=name,
id=range_id,
data=range_data)
_send_kv_msg(self._sender, msg, idx)
def push_all(self, name, data):
"""Push the whole data to KVServer
The push_all() API will partition message into different
KVServer nodes automatically.
Note that we assume the row Ids in ID is in the ascending order.
Parameters
----------
name : str
data name
data : tensor (mx.ndarray or torch.tensor)
data tensor
"""
ID = F.zerocopy_from_numpy(np.arange(F.shape(data)[0]))
self.push(name, ID, data)
def pull(self, name, ID): def pull(self, name, server_id, id_tensor):
"""Pull sparse message from KVServer """Pull sparse message from KVServer
Note that we assume the row Ids in ID is in the ascending order. Note that pull() is async operation that will return immediately after calling.
User can use pull_wait() to get the real data pulled from the kvserver. The order
of received data that comes from the same server is deterministic.
Parameters Parameters
---------- ----------
name : str name : str
data name data name
ID : tensor (mx.ndarray or torch.tensor) server_id : int
a vector storing the IDs target server id
id_tensor : tensor (mx.ndarray or torch.tensor)
a vector storing the ID list
Return
------
tensor
a tensor with the same row size of ID
""" """
assert F.ndim(ID) == 1, 'ID must be a vector.' assert server_id >= 0, 'server_id (%d) cannot be a negative number' % server_id
group_size = [0] * self._server_count assert server_id < self._server_count, 'server_id (%d) must be smaller than server_count' % server_id
numpy_id = F.asnumpy(ID) assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
count = math.ceil(self._data_size[name] / self._server_count) msg = KVStoreMsg(
server_id = numpy_id / count type=KVMsgType.PULL,
id_list, id_count = np.unique(server_id, return_counts=True) rank=self._client_id,
for idx in range(len(id_list)): name=name,
group_size[int(id_list[idx])] += id_count[idx] id=id_tensor,
min_idx = 0 data=None)
max_idx = 0 _send_kv_msg(self._sender, msg, server_id)
server_count = 0
for idx in range(self._server_count):
if group_size[idx] == 0:
continue
server_count += 1
max_idx += group_size[idx]
range_id = ID[min_idx:max_idx]
min_idx = max_idx
msg = KVStoreMsg(
type=KVMsgType.PULL,
rank=self._client_id,
name=name,
id=range_id,
data=None)
_send_kv_msg(self._sender, msg, idx)
# Recv back message
msg_list = []
for idx in range(self._server_count):
if group_size[idx] == 0:
continue
msg = _recv_kv_msg(self._receiver)
assert msg.type == KVMsgType.PULL_BACK, 'Recv kv msg error.'
msg_list.append(msg)
return self._merge_msg(msg_list)
def pull_all(self, name): def pull_wait(self):
"""Pull the whole data from KVServer """Wait pull() finish its job.
Note that we assume the row Ids in ID is in the ascending order. Returns
-------
Parameters msg.rank
---------- server_id
name : str msg.data
data name
Return
------
tensor
target data tensor target data tensor
""" """
ID = F.zerocopy_from_numpy(np.arange(self._data_size[name])) msg = _recv_kv_msg(self._receiver)
return self.pull(name, ID) assert msg.type == KVMsgType.PULL_BACK, 'Recv kv msg error.'
return msg.rank, msg.data
def barrier(self): def barrier(self):
"""Barrier for all client nodes """Barrier for all client nodes
...@@ -528,29 +436,3 @@ class KVClient(object): ...@@ -528,29 +436,3 @@ class KVClient(object):
KVClient ID KVClient ID
""" """
return self._client_id return self._client_id
def _sort_func(self, msg):
"""Sort function for KVStoreMsg: sort message by rank
Parameters
----------
msg : KVStoreMsg
KVstore message
"""
return msg.rank
def _merge_msg(self, msg_list):
"""Merge separated message to a big matrix
Parameters
----------
msg_list : list
a list of KVStoreMsg
Return
------
tensor (mx.ndarray or torch.tensor)
a merged data matrix
"""
msg_list.sort(key=self._sort_func)
return F.cat([msg.data for msg in msg_list], 0)
\ No newline at end of file
...@@ -2,7 +2,7 @@ import backend as F ...@@ -2,7 +2,7 @@ import backend as F
import numpy as np import numpy as np
import scipy as sp import scipy as sp
import dgl import dgl
import torch import torch as th
from dgl import utils from dgl import utils
import os import os
...@@ -28,70 +28,38 @@ def start_client(): ...@@ -28,70 +28,38 @@ def start_client():
client.connect() client.connect()
client.init_data(name='embed_0', shape=[10, 3], init_type='zero') # Initialize data on server
client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0) client.init_data(name='embed_0', server_id=0, shape=[5, 3], init_type='zero')
client.init_data(name='embed_2', shape=[11], init_type='zero') client.init_data(name='embed_1', server_id=0, shape=[5], init_type='uniform', low=0.0, high=0.0)
tensor_id = torch.tensor([0, 1, 2]) data_0 = th.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])
tensor_data = torch.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]]) data_1 = th.tensor([0., 1., 2.])
# Push
for i in range(5): for i in range(5):
client.push('embed_0', tensor_id, tensor_data) client.push(name='embed_0', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_0)
client.push('embed_1', tensor_id, tensor_data) client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.push('embed_2', tensor_id, torch.tensor([2., 2., 2.]))
tensor_id = torch.tensor([6, 7, 8]) client.barrier()
for i in range(5):
client.push('embed_0', tensor_id, tensor_data) client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.push('embed_1', tensor_id, tensor_data) server_id, new_tensor = client.pull_wait()
client.push('embed_2', tensor_id, torch.tensor([3., 3., 3.])) assert server_id == 0
# Pull target_tensor = th.tensor(
tensor_id = torch.tensor([0, 1, 2, 6, 7, 8]) [[ 0., 0., 0.],
new_tensor_0 = client.pull('embed_0', tensor_id) [ 0., 0., 0.],
new_tensor_1 = client.pull('embed_1', tensor_id) [ 5., 5., 5.],
new_tensor_2 = client.pull('embed_2', tensor_id) [ 0., 0., 0.],
[10., 10., 10.]])
target_tensor = torch.tensor(
[[ 0., 0., 0.], assert th.equal(new_tensor, target_tensor) == True
[ 5., 5., 5.],
[10., 10., 10.], client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
[ 0., 0., 0.], server_id, new_tensor = client.pull_wait()
[ 5., 5., 5.],
[10., 10., 10.]]) target_tensor = th.tensor([ 0., 0., 5., 0., 10.])
assert torch.equal(new_tensor_0, target_tensor) == True assert th.equal(new_tensor, target_tensor) == True
assert torch.equal(new_tensor_1, target_tensor) == True
target_tensor = tensor.tensor([10., 10., 10., 15., 15., 15.])
assert torch.equal(new_tensor_2, target_tensor) == True
client.push_all('embed_0', client.pull_all('embed_0'))
client.push_all('embed_1', client.pull_all('embed_1'))
client.push_all('embed_2', client.pull_all('embed_2'))
# Pull
tensor_id = torch.tensor([0, 1, 2, 6, 7, 8])
new_tensor_0 = client.pull('embed_0', tensor_id)
new_tensor_1 = client.pull('embed_1', tensor_id)
new_tensor_2 = client.pull('embed_2', tensor_id)
target_tensor = torch.tensor(
[[ 0., 0., 0.],
[ 10., 10., 10.],
[20., 20., 20.],
[ 0., 0., 0.],
[ 10., 10., 10.],
[20., 20., 20.]])
assert torch.equal(new_tensor_0, target_tensor) == True
assert torch.equal(new_tensor_1, target_tensor) == True
target_tensor = tensor.tensor([20., 20., 20., 30., 30., 30.])
assert torch.equal(new_tensor_2, target_tensor) == True
client.shut_down() client.shut_down()
......
Subproject commit 0f3ddbc7240efa05bfffd5bca808ec262ce3630e Subproject commit 7ce90a342b0bda9b7f88e707a326496324d60efd
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment