Unverified Commit 7897fa3b authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

add init api on kvserver (#975)

parent cccde032
...@@ -30,6 +30,8 @@ def start_client(args): ...@@ -30,6 +30,8 @@ def start_client(args):
client.push(name='embed_0', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_0) client.push(name='embed_0', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_0)
client.push(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1) client.push(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1)
client.push(name='embed_1', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_1) client.push(name='embed_1', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_1)
client.push(name='server_embed', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1)
client.push(name='server_embed', server_id=1, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1)
client.barrier() client.barrier()
...@@ -40,7 +42,6 @@ def start_client(args): ...@@ -40,7 +42,6 @@ def start_client(args):
client.pull(name='embed_0', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64')) client.pull(name='embed_0', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait() server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1 assert server_id == 1
print("embed_0:") print("embed_0:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0)) print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
...@@ -50,10 +51,18 @@ def start_client(args): ...@@ -50,10 +51,18 @@ def start_client(args):
client.pull(name='embed_1', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64')) client.pull(name='embed_1', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait() server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1 assert server_id == 1
print("embed_1:") print("embed_1:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0)) print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
client.pull(name='server_embed', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
client.pull(name='server_embed', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
print("server_embed:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
# Shut-down all the servers # Shut-down all the servers
if client.get_id() == 0: if client.get_id() == 0:
client.shut_down() client.shut_down()
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import dgl import dgl
import torch import torch
import argparse import argparse
import mxnet as mx
server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt') server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt')
...@@ -12,6 +13,8 @@ def start_server(args): ...@@ -12,6 +13,8 @@ def start_server(args):
client_namebook=client_namebook, client_namebook=client_namebook,
server_addr=server_namebook[args.id]) server_addr=server_namebook[args.id])
server.init_data(name='server_embed', data_tensor=mx.nd.array([0., 0., 0., 0., 0.]))
server.start() server.start()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -30,6 +30,8 @@ def start_client(args): ...@@ -30,6 +30,8 @@ def start_client(args):
client.push(name='embed_0', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_0) client.push(name='embed_0', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_0)
client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1) client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.push(name='embed_1', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_1) client.push(name='embed_1', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_1)
client.push(name='server_embed', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.push(name='server_embed', server_id=1, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.barrier() client.barrier()
...@@ -40,7 +42,6 @@ def start_client(args): ...@@ -40,7 +42,6 @@ def start_client(args):
client.pull(name='embed_0', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5])) client.pull(name='embed_0', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
server_id, new_tensor_1 = client.pull_wait() server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1 assert server_id == 1
print("embed_0:") print("embed_0:")
print(th.cat([new_tensor_0, new_tensor_1])) print(th.cat([new_tensor_0, new_tensor_1]))
...@@ -50,10 +51,18 @@ def start_client(args): ...@@ -50,10 +51,18 @@ def start_client(args):
client.pull(name='embed_1', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5])) client.pull(name='embed_1', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
server_id, new_tensor_1 = client.pull_wait() server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1 assert server_id == 1
print("embed_1:") print("embed_1:")
print(th.cat([new_tensor_0, new_tensor_1])) print(th.cat([new_tensor_0, new_tensor_1]))
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
client.pull(name='server_embed', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
print("server_embed:")
print(th.cat([new_tensor_0, new_tensor_1]))
# Shut-down all the servers # Shut-down all the servers
if client.get_id() == 0: if client.get_id() == 0:
client.shut_down() client.shut_down()
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import dgl import dgl
import torch import torch
import argparse import argparse
import torch as th
server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt') server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt')
...@@ -12,6 +13,8 @@ def start_server(args): ...@@ -12,6 +13,8 @@ def start_server(args):
client_namebook=client_namebook, client_namebook=client_namebook,
server_addr=server_namebook[args.id]) server_addr=server_namebook[args.id])
server.init_data(name='server_embed', data_tensor=th.tensor([0., 0., 0., 0., 0.]))
server.start() server.start()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -94,6 +94,19 @@ class KVServer(object): ...@@ -94,6 +94,19 @@ class KVServer(object):
_finalize_sender(self._sender) _finalize_sender(self._sender)
_finalize_receiver(self._receiver) _finalize_receiver(self._receiver)
def init_data(self, name, data_tensor):
"""KVServer supports data initialization on server.
Parameters
----------
name : str
data name
data_tensor : tensor
data tensor
"""
self._data_store[name] = data_tensor
self._is_init.add(name)
def start(self): def start(self):
"""Start service of KVServer """Start service of KVServer
""" """
......
...@@ -45,21 +45,36 @@ def start_client(): ...@@ -45,21 +45,36 @@ def start_client():
server_id, new_tensor = client.pull_wait() server_id, new_tensor = client.pull_wait()
assert server_id == 0 assert server_id == 0
target_tensor = th.tensor( target_tensor_0 = th.tensor(
[[ 0., 0., 0.], [[ 0., 0., 0.],
[ 0., 0., 0.], [ 0., 0., 0.],
[ 5., 5., 5.], [ 5., 5., 5.],
[ 0., 0., 0.], [ 0., 0., 0.],
[10., 10., 10.]]) [10., 10., 10.]])
assert th.equal(new_tensor, target_tensor) == True assert th.equal(new_tensor, target_tensor_0) == True
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor = client.pull_wait() server_id, new_tensor = client.pull_wait()
target_tensor = th.tensor([ 0., 0., 5., 0., 10.]) target_tensor_1 = th.tensor([ 0., 0., 5., 0., 10.])
assert th.equal(new_tensor, target_tensor) == True assert th.equal(new_tensor, target_tensor_1) == True
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
_, tensor_0 = client.pull_wait()
_, tensor_1 = client.pull_wait()
_, tensor_2 = client.pull_wait()
_, tensor_3 = client.pull_wait()
assert th.equal(new_tensor_0, target_tensor_0) == True
assert th.equal(new_tensor_1, target_tensor_1) == True
assert th.equal(new_tensor_2, target_tensor_0) == True
assert th.equal(new_tensor_3, target_tensor_1) == True
client.shut_down() client.shut_down()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment