# test_kvstore.py
import backend as F
import numpy as np
import scipy as sp
import dgl
import torch as th
from dgl import utils

import os
import time

# Address book of the (single) KV client, keyed by client id.
client_namebook = {
    0: '127.0.0.1:50061',
}

# Address book of the (single) KV server, keyed by server id. It uses a
# different local port than the client.
server_namebook = {
    0: '127.0.0.1:50062',
}

def start_server(server_embed):
    """Run the KV server (id 0) used by this test.

    Creates a ``dgl.contrib.KVServer`` listening on ``server_namebook[0]`` and
    knowing the client addresses in ``client_namebook``. The caller-provided
    tensor is registered under the name 'server_embed', and then the server
    is started.

    Args:
        server_embed: tensor exposed to clients as 'server_embed'. The
            ``__main__`` block passes a tensor moved to shared memory before
            forking, so updates are expected to be visible across processes.
    """
    server = dgl.contrib.KVServer(
        server_id=0,
        client_namebook=client_namebook,
        server_addr=server_namebook[0])

    # Register the caller's tensor directly (rather than letting a client
    # create it) so the test can check server-side data owned by the caller.
    server.init_data(name='server_embed', data_tensor=server_embed)

    # NOTE(review): presumably blocks serving requests until the client shuts
    # down -- TODO confirm against the KVServer implementation.
    server.start()

def start_client(server_embed):
    """Run the client side of the KVStore test and verify the pulled values.

    Connects a ``dgl.contrib.KVClient`` (id 0) to the server in
    ``server_namebook``, then:

    - creates 'embed_0' and 'embed_1' on that server;
    - pushes the same rows five times to them and to the server-registered
      'server_embed';
    - pulls the data back and asserts that it matches the expected
      accumulated values.

    Args:
        server_embed: the tensor the server registered as 'server_embed'. The
            ``__main__`` block places it in shared memory before forking, so
            the in-place update made here is expected to be seen by the server.

    Raises:
        AssertionError: if any pulled tensor does not match the expected value.
    """
    client = dgl.contrib.KVClient(
        client_id=0,
        server_namebook=server_namebook,
        client_addr=client_namebook[0])

    client.connect()

    # Always shut the client down, even when a check fails. Presumably the
    # server only stops once its client shuts down (TODO confirm), so skipping
    # this on failure could leave the forked server process hanging.
    try:
        # Initialize data on server. The uniform init with low == high == 0.0
        # is expected to give all zeros, keeping the targets deterministic.
        client.init_data(name='embed_0', server_id=0, shape=[5, 3], init_type='zero')
        client.init_data(name='embed_1', server_id=0, shape=[5], init_type='uniform', low=0.0, high=0.0)

        push_ids = th.tensor([0, 2, 4])
        all_ids = th.tensor([0, 1, 2, 3, 4])

        data_0 = th.tensor([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])
        data_1 = th.tensor([0., 1., 2.])

        # Push identical rows five times. The targets below assume the server
        # adds pushed rows into the stored ones (e.g. row 4 of 'embed_0' ends
        # up as 5 * 2. = 10.).
        for _ in range(5):
            client.push(name='embed_0', server_id=0, id_tensor=push_ids, data_tensor=data_0)
            client.push(name='embed_1', server_id=0, id_tensor=push_ids, data_tensor=data_1)
            client.push(name='server_embed', server_id=0, id_tensor=push_ids, data_tensor=data_1)

        # Presumably waits until all pushes are applied before reading back.
        client.barrier()

        client.pull(name='embed_0', server_id=0, id_tensor=all_ids)
        msg = client.pull_wait()
        assert msg.rank == 0

        target_tensor_0 = th.tensor(
            [[ 0., 0., 0.],
             [ 0., 0., 0.],
             [ 5., 5., 5.],
             [ 0., 0., 0.],
             [10., 10., 10.]])
        assert th.equal(msg.data, target_tensor_0)

        client.pull(name='embed_1', server_id=0, id_tensor=all_ids)
        msg = client.pull_wait()

        target_tensor_1 = th.tensor([0., 0., 5., 0., 10.])
        assert th.equal(msg.data, target_tensor_1)

        # 'server_embed' started at 2. everywhere, plus 5 * [0, 1, 2] at ids
        # [0, 2, 4].
        target_tensor_2 = th.tensor([2., 2., 7., 2., 12.])

        # Issue several pulls before waiting on any of them. The replies are
        # assumed to arrive in the order the pulls were issued.
        batched = [
            ('embed_0', target_tensor_0),
            ('embed_1', target_tensor_1),
            ('embed_0', target_tensor_0),
            ('embed_1', target_tensor_1),
            ('server_embed', target_tensor_2),
        ]
        for name, _ in batched:
            client.pull(name=name, server_id=0, id_tensor=all_ids)
        replies = [client.pull_wait() for _ in batched]
        for reply, (_, expected) in zip(replies, batched):
            assert th.equal(reply.data, expected)

        # In-place add on the shared-memory tensor. The server is expected to
        # observe it, so the next pull should return twice target_tensor_2.
        server_embed += target_tensor_2

        client.pull(name='server_embed', server_id=0, id_tensor=all_ids)
        msg = client.pull_wait()
        assert th.equal(msg.data, target_tensor_2 * 2)
    finally:
        client.shut_down()

if __name__ == '__main__':
    server_embed = th.tensor([2., 2., 2., 2., 2.])
    # Use PyTorch shared memory so that, after fork(), the server (child) and
    # the client (parent) operate on the same storage for 'server_embed'.
    server_embed.share_memory_()

    pid = os.fork()
    if pid == 0:
        # Child process: run the server.
        start_server(server_embed)
    else:
        # Parent process: run the client.
        # NOTE(review): fixed sleep to let the server come up; this is racy
        # on slow machines.
        time.sleep(2)  # wait server start
        start_client(server_embed)

    # Reached by the parent after start_client() returns, and by the child
    # only if start_server() returns. Expected value: 2. everywhere, plus
    # five pushes of [0, 1, 2] at ids [0, 2, 4], then doubled by the client's
    # in-place add.
    # NOTE(review): the parent never reaps the child (no os.waitpid).
    assert th.equal(server_embed, th.tensor([4., 4., 14., 4., 24.]))