"vscode:/vscode.git/clone" did not exist on "7c788f531c94dc00d9577102b9599100334b2ba0"
Unverified Commit a0193fd5 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

Small change for kvstore api (#981)

* Small change for kvstore api

* fix ci

* fix ci
parent 0b4935d4
...@@ -37,31 +37,31 @@ def start_client(args): ...@@ -37,31 +37,31 @@ def start_client(args):
if client.get_id() == 0: if client.get_id() == 0:
client.pull(name='embed_0', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64')) client.pull(name='embed_0', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_0 = client.pull_wait() msg_0 = client.pull_wait()
assert server_id == 0 assert msg_0.rank == 0
client.pull(name='embed_0', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64')) client.pull(name='embed_0', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait() msg_1 = client.pull_wait()
assert server_id == 1 assert msg_1.rank == 1
print("embed_0:") print("embed_0:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0)) print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))
client.pull(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64')) client.pull(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_0 = client.pull_wait() msg_0 = client.pull_wait()
assert server_id == 0 assert msg_0.rank == 0
client.pull(name='embed_1', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64')) client.pull(name='embed_1', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait() msg_1 = client.pull_wait()
assert server_id == 1 assert msg_1.rank == 1
print("embed_1:") print("embed_1:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0)) print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))
client.pull(name='server_embed', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64')) client.pull(name='server_embed', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_0 = client.pull_wait() msg_0 = client.pull_wait()
assert server_id == 0 assert msg_0.rank == 0
client.pull(name='server_embed', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64')) client.pull(name='server_embed', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait() msg_1 = client.pull_wait()
assert server_id == 1 assert msg_1.rank == 1
print("server_embed:") print("server_embed:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0)) print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))
# Shut-down all the servers # Shut-down all the servers
if client.get_id() == 0: if client.get_id() == 0:
......
...@@ -37,31 +37,31 @@ def start_client(args): ...@@ -37,31 +37,31 @@ def start_client(args):
if client.get_id() == 0: if client.get_id() == 0:
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait() msg_0 = client.pull_wait()
assert server_id == 0 assert msg_0.rank == 0
client.pull(name='embed_0', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5])) client.pull(name='embed_0', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
server_id, new_tensor_1 = client.pull_wait() msg_1 = client.pull_wait()
assert server_id == 1 assert msg_1.rank == 1
print("embed_0:") print("embed_0:")
print(th.cat([new_tensor_0, new_tensor_1])) print(th.cat([msg_0.data, msg_1.data]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait() msg_0 = client.pull_wait()
assert server_id == 0 assert msg_0.rank == 0
client.pull(name='embed_1', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5])) client.pull(name='embed_1', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
server_id, new_tensor_1 = client.pull_wait() msg_1 = client.pull_wait()
assert server_id == 1 assert msg_1.rank == 1
print("embed_1:") print("embed_1:")
print(th.cat([new_tensor_0, new_tensor_1])) print(th.cat([msg_0.data, msg_1.data]))
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait() msg_0 = client.pull_wait()
assert server_id == 0 assert msg_0.rank == 0
client.pull(name='server_embed', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='server_embed', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_1 = client.pull_wait() msg_1 = client.pull_wait()
assert server_id == 1 assert msg_1.rank == 1
print("server_embed:") print("server_embed:")
print(th.cat([new_tensor_0, new_tensor_1])) print(th.cat([msg_0.data, msg_1.data]))
# Shut-down all the servers # Shut-down all the servers
if client.get_id() == 0: if client.get_id() == 0:
......
...@@ -407,7 +407,7 @@ class KVClient(object): ...@@ -407,7 +407,7 @@ class KVClient(object):
""" """
msg = _recv_kv_msg(self._receiver) msg = _recv_kv_msg(self._receiver)
assert msg.type == KVMsgType.PULL_BACK, 'Recv kv msg error.' assert msg.type == KVMsgType.PULL_BACK, 'Recv kv msg error.'
return msg.rank, msg.data return msg
def barrier(self): def barrier(self):
"""Barrier for all client nodes """Barrier for all client nodes
......
...@@ -45,8 +45,8 @@ def start_client(server_embed): ...@@ -45,8 +45,8 @@ def start_client(server_embed):
client.barrier() client.barrier()
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor = client.pull_wait() msg = client.pull_wait()
assert server_id == 0 assert msg.rank == 0
target_tensor_0 = th.tensor( target_tensor_0 = th.tensor(
[[ 0., 0., 0.], [[ 0., 0., 0.],
...@@ -55,14 +55,14 @@ def start_client(server_embed): ...@@ -55,14 +55,14 @@ def start_client(server_embed):
[ 0., 0., 0.], [ 0., 0., 0.],
[10., 10., 10.]]) [10., 10., 10.]])
assert th.equal(new_tensor, target_tensor_0) == True assert th.equal(msg.data, target_tensor_0) == True
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor = client.pull_wait() msg = client.pull_wait()
target_tensor_1 = th.tensor([ 0., 0., 5., 0., 10.]) target_tensor_1 = th.tensor([ 0., 0., 5., 0., 10.])
assert th.equal(new_tensor, target_tensor_1) == True assert th.equal(msg.data, target_tensor_1) == True
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
...@@ -70,31 +70,32 @@ def start_client(server_embed): ...@@ -70,31 +70,32 @@ def start_client(server_embed):
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
_, tensor_0 = client.pull_wait() msg_0 = client.pull_wait()
_, tensor_1 = client.pull_wait() msg_1 = client.pull_wait()
_, tensor_2 = client.pull_wait() msg_2 = client.pull_wait()
_, tensor_3 = client.pull_wait() msg_3 = client.pull_wait()
_, tensor_4 = client.pull_wait() msg_4 = client.pull_wait()
target_tensor_2 = th.tensor([ 2., 2., 7., 2., 12.]) target_tensor_2 = th.tensor([ 2., 2., 7., 2., 12.])
assert th.equal(tensor_0, target_tensor_0) == True assert th.equal(msg_0.data, target_tensor_0) == True
assert th.equal(tensor_1, target_tensor_1) == True assert th.equal(msg_1.data, target_tensor_1) == True
assert th.equal(tensor_2, target_tensor_0) == True assert th.equal(msg_2.data, target_tensor_0) == True
assert th.equal(tensor_3, target_tensor_1) == True assert th.equal(msg_3.data, target_tensor_1) == True
assert th.equal(tensor_4, target_tensor_2) == True assert th.equal(msg_4.data, target_tensor_2) == True
server_embed += target_tensor_2 server_embed += target_tensor_2
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
_, tensor_5 = client.pull_wait() msg_5 = client.pull_wait()
assert th.equal(tensor_5, target_tensor_2 * 2) == True assert th.equal(msg_5.data, target_tensor_2 * 2) == True
client.shut_down() client.shut_down()
if __name__ == '__main__': if __name__ == '__main__':
server_embed = th.tensor([2., 2., 2., 2., 2.]) server_embed = th.tensor([2., 2., 2., 2., 2.])
# use pytorch shared memory
server_embed.share_memory_() server_embed.share_memory_()
pid = os.fork() pid = os.fork()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment