Unverified Commit 7897fa3b authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

add init api on kvserver (#975)

parent cccde032
......@@ -30,6 +30,8 @@ def start_client(args):
client.push(name='embed_0', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_0)
client.push(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1)
client.push(name='embed_1', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_1)
client.push(name='server_embed', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1)
client.push(name='server_embed', server_id=1, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1)
client.barrier()
......@@ -40,7 +42,6 @@ def start_client(args):
client.pull(name='embed_0', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
print("embed_0:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
......@@ -50,10 +51,18 @@ def start_client(args):
client.pull(name='embed_1', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
print("embed_1:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
client.pull(name='server_embed', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
client.pull(name='server_embed', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
print("server_embed:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
# Shut-down all the servers
if client.get_id() == 0:
client.shut_down()
......
......@@ -3,6 +3,7 @@
import dgl
import torch
import argparse
import mxnet as mx
server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt')
......@@ -12,6 +13,8 @@ def start_server(args):
client_namebook=client_namebook,
server_addr=server_namebook[args.id])
server.init_data(name='server_embed', data_tensor=mx.nd.array([0., 0., 0., 0., 0.]))
server.start()
if __name__ == '__main__':
......
......@@ -30,6 +30,8 @@ def start_client(args):
client.push(name='embed_0', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_0)
client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.push(name='embed_1', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_1)
client.push(name='server_embed', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.push(name='server_embed', server_id=1, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.barrier()
......@@ -40,7 +42,6 @@ def start_client(args):
client.pull(name='embed_0', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
print("embed_0:")
print(th.cat([new_tensor_0, new_tensor_1]))
......@@ -50,10 +51,18 @@ def start_client(args):
client.pull(name='embed_1', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
print("embed_1:")
print(th.cat([new_tensor_0, new_tensor_1]))
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
client.pull(name='server_embed', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
print("server_embed:")
print(th.cat([new_tensor_0, new_tensor_1]))
# Shut-down all the servers
if client.get_id() == 0:
client.shut_down()
......
......@@ -3,6 +3,7 @@
import dgl
import torch
import argparse
import torch as th
server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt')
......@@ -12,6 +13,8 @@ def start_server(args):
client_namebook=client_namebook,
server_addr=server_namebook[args.id])
server.init_data(name='server_embed', data_tensor=th.tensor([0., 0., 0., 0., 0.]))
server.start()
if __name__ == '__main__':
......
......@@ -94,6 +94,19 @@ class KVServer(object):
_finalize_sender(self._sender)
_finalize_receiver(self._receiver)
def init_data(self, name, data_tensor):
"""KVServer supports data initialization on server.
Parameters
----------
name : str
data name
data_tensor : tensor
data tensor
"""
self._data_store[name] = data_tensor
self._is_init.add(name)
def start(self):
"""Start service of KVServer
"""
......
......@@ -45,21 +45,36 @@ def start_client():
server_id, new_tensor = client.pull_wait()
assert server_id == 0
target_tensor = th.tensor(
target_tensor_0 = th.tensor(
[[ 0., 0., 0.],
[ 0., 0., 0.],
[ 5., 5., 5.],
[ 0., 0., 0.],
[10., 10., 10.]])
assert th.equal(new_tensor, target_tensor) == True
assert th.equal(new_tensor, target_tensor_0) == True
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor = client.pull_wait()
target_tensor = th.tensor([ 0., 0., 5., 0., 10.])
target_tensor_1 = th.tensor([ 0., 0., 5., 0., 10.])
assert th.equal(new_tensor, target_tensor) == True
assert th.equal(new_tensor, target_tensor_1) == True
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
_, tensor_0 = client.pull_wait()
_, tensor_1 = client.pull_wait()
_, tensor_2 = client.pull_wait()
_, tensor_3 = client.pull_wait()
assert th.equal(new_tensor_0, target_tensor_0) == True
assert th.equal(new_tensor_1, target_tensor_1) == True
assert th.equal(new_tensor_2, target_tensor_0) == True
assert th.equal(new_tensor_3, target_tensor_1) == True
client.shut_down()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment