"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "15782fd506e8c4a7c2b288fc2e558bd77fdfa51a"
Unverified commit a0193fd5, authored by Chao Ma and committed by GitHub
Browse files

Small change for kvstore api (#981)

* Small change for kvstore api

* fix ci

* fix ci
parent 0b4935d4
......@@ -37,31 +37,31 @@ def start_client(args):
if client.get_id() == 0:
client.pull(name='embed_0', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_0', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_0:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))
client.pull(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_1', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_1:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))
client.pull(name='server_embed', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='server_embed', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("server_embed:")
print(mx.nd.concat(new_tensor_0, new_tensor_1, dim=0))
print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))
# Shut-down all the servers
if client.get_id() == 0:
......
......@@ -37,31 +37,31 @@ def start_client(args):
if client.get_id() == 0:
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_0', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_0:")
print(th.cat([new_tensor_0, new_tensor_1]))
print(th.cat([msg_0.data, msg_1.data]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_1', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_1:")
print(th.cat([new_tensor_0, new_tensor_1]))
print(th.cat([msg_0.data, msg_1.data]))
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_0 = client.pull_wait()
assert server_id == 0
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='server_embed', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor_1 = client.pull_wait()
assert server_id == 1
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("server_embed:")
print(th.cat([new_tensor_0, new_tensor_1]))
print(th.cat([msg_0.data, msg_1.data]))
# Shut-down all the servers
if client.get_id() == 0:
......
......@@ -407,7 +407,7 @@ class KVClient(object):
"""
msg = _recv_kv_msg(self._receiver)
assert msg.type == KVMsgType.PULL_BACK, 'Recv kv msg error.'
return msg.rank, msg.data
return msg
def barrier(self):
"""Barrier for all client nodes
......
......@@ -45,8 +45,8 @@ def start_client(server_embed):
client.barrier()
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor = client.pull_wait()
assert server_id == 0
msg = client.pull_wait()
assert msg.rank == 0
target_tensor_0 = th.tensor(
[[ 0., 0., 0.],
......@@ -55,14 +55,14 @@ def start_client(server_embed):
[ 0., 0., 0.],
[10., 10., 10.]])
assert th.equal(new_tensor, target_tensor_0) == True
assert th.equal(msg.data, target_tensor_0) == True
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
server_id, new_tensor = client.pull_wait()
msg = client.pull_wait()
target_tensor_1 = th.tensor([ 0., 0., 5., 0., 10.])
assert th.equal(new_tensor, target_tensor_1) == True
assert th.equal(msg.data, target_tensor_1) == True
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
......@@ -70,31 +70,32 @@ def start_client(server_embed):
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
_, tensor_0 = client.pull_wait()
_, tensor_1 = client.pull_wait()
_, tensor_2 = client.pull_wait()
_, tensor_3 = client.pull_wait()
_, tensor_4 = client.pull_wait()
msg_0 = client.pull_wait()
msg_1 = client.pull_wait()
msg_2 = client.pull_wait()
msg_3 = client.pull_wait()
msg_4 = client.pull_wait()
target_tensor_2 = th.tensor([ 2., 2., 7., 2., 12.])
assert th.equal(tensor_0, target_tensor_0) == True
assert th.equal(tensor_1, target_tensor_1) == True
assert th.equal(tensor_2, target_tensor_0) == True
assert th.equal(tensor_3, target_tensor_1) == True
assert th.equal(tensor_4, target_tensor_2) == True
assert th.equal(msg_0.data, target_tensor_0) == True
assert th.equal(msg_1.data, target_tensor_1) == True
assert th.equal(msg_2.data, target_tensor_0) == True
assert th.equal(msg_3.data, target_tensor_1) == True
assert th.equal(msg_4.data, target_tensor_2) == True
server_embed += target_tensor_2
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
_, tensor_5 = client.pull_wait()
msg_5 = client.pull_wait()
assert th.equal(tensor_5, target_tensor_2 * 2) == True
assert th.equal(msg_5.data, target_tensor_2 * 2) == True
client.shut_down()
if __name__ == '__main__':
server_embed = th.tensor([2., 2., 2., 2., 2.])
# use pytorch shared memory
server_embed.share_memory_()
pid = os.fork()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment