"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "3fef5d27d32ef2cec0871e6676f5c09c7e91fe02"
Unverified Commit 0b4935d4 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

Test shared-mem on kvstore (#976)

parent 7897fa3b
...@@ -12,15 +12,17 @@ client_namebook = { 0:'127.0.0.1:50061' } ...@@ -12,15 +12,17 @@ client_namebook = { 0:'127.0.0.1:50061' }
server_namebook = { 0:'127.0.0.1:50062' } server_namebook = { 0:'127.0.0.1:50062' }
def start_server(): def start_server(server_embed):
server = dgl.contrib.KVServer( server = dgl.contrib.KVServer(
server_id=0, server_id=0,
client_namebook=client_namebook, client_namebook=client_namebook,
server_addr=server_namebook[0]) server_addr=server_namebook[0])
server.init_data(name='server_embed', data_tensor=server_embed)
server.start() server.start()
def start_client(): def start_client(server_embed):
client = dgl.contrib.KVClient( client = dgl.contrib.KVClient(
client_id=0, client_id=0,
server_namebook=server_namebook, server_namebook=server_namebook,
...@@ -38,6 +40,7 @@ def start_client(): ...@@ -38,6 +40,7 @@ def start_client():
for i in range(5): for i in range(5):
client.push(name='embed_0', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_0) client.push(name='embed_0', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_0)
client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1) client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.push(name='server_embed', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.barrier() client.barrier()
...@@ -65,23 +68,40 @@ def start_client(): ...@@ -65,23 +68,40 @@ def start_client():
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
_, tensor_0 = client.pull_wait() _, tensor_0 = client.pull_wait()
_, tensor_1 = client.pull_wait() _, tensor_1 = client.pull_wait()
_, tensor_2 = client.pull_wait() _, tensor_2 = client.pull_wait()
_, tensor_3 = client.pull_wait() _, tensor_3 = client.pull_wait()
_, tensor_4 = client.pull_wait()
target_tensor_2 = th.tensor([ 2., 2., 7., 2., 12.])
assert th.equal(tensor_0, target_tensor_0) == True
assert th.equal(tensor_1, target_tensor_1) == True
assert th.equal(tensor_2, target_tensor_0) == True
assert th.equal(tensor_3, target_tensor_1) == True
assert th.equal(tensor_4, target_tensor_2) == True
server_embed += target_tensor_2
assert th.equal(new_tensor_0, target_tensor_0) == True client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
assert th.equal(new_tensor_1, target_tensor_1) == True _, tensor_5 = client.pull_wait()
assert th.equal(new_tensor_2, target_tensor_0) == True
assert th.equal(new_tensor_3, target_tensor_1) == True assert th.equal(tensor_5, target_tensor_2 * 2) == True
client.shut_down() client.shut_down()
if __name__ == '__main__': if __name__ == '__main__':
server_embed = th.tensor([2., 2., 2., 2., 2.])
server_embed.share_memory_()
pid = os.fork() pid = os.fork()
if pid == 0: if pid == 0:
start_server() start_server(server_embed)
else: else:
time.sleep(2) # wait server start time.sleep(2) # wait server start
start_client() start_client(server_embed)
assert th.equal(server_embed, th.tensor([ 4., 4., 14., 4., 24.])) == True
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment