# test_kvstore.py
import backend as F
import numpy as np
import scipy as sp
import dgl
import torch as th
from dgl import utils

import os
import time

client_namebook = { 0:'127.0.0.1:50061' }

server_namebook = { 0:'127.0.0.1:50062' }

def start_server():
    """Create the KV server with id 0 and start it.

    The server listens on the address registered for id 0 in
    ``server_namebook`` and is told about the clients in ``client_namebook``.
    """
    kv_server = dgl.contrib.KVServer(
        server_id=0,
        client_namebook=client_namebook,
        server_addr=server_namebook[0],
    )
    kv_server.start()

def start_client():
    """Run the KVStore client against server 0 and verify push/pull results.

    Initializes two tensors on the server, pushes the same rows five times,
    waits at a barrier, then pulls every row back and compares it against the
    expected values. The server is always asked to shut down afterwards, even
    if an assertion fails, so the forked server process is not left running.
    """
    client = dgl.contrib.KVClient(
        client_id=0,
        server_namebook=server_namebook,
        client_addr=client_namebook[0])

    client.connect()

    try:
        # Initialize data on server. embed_1 uses uniform(low=0.0, high=0.0);
        # the expected tensor below relies on that yielding all zeros.
        client.init_data(name='embed_0', server_id=0, shape=[5, 3], init_type='zero')
        client.init_data(name='embed_1', server_id=0, shape=[5], init_type='uniform', low=0.0, high=0.0)

        data_0 = th.tensor([[0., 0., 0.], [1., 1., 1.], [2., 2., 2.]])
        data_1 = th.tensor([0., 1., 2.])

        # Push the same rows 5 times. The expected values (5x the pushed data
        # at ids 0, 2, 4) assume the server accumulates pushes.
        for _ in range(5):
            client.push(name='embed_0', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_0)
            client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)

        client.barrier()

        client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
        server_id, new_tensor = client.pull_wait()
        assert server_id == 0

        target_tensor = th.tensor(
            [[ 0., 0., 0.],
             [ 0., 0., 0.],
             [ 5., 5., 5.],
             [ 0., 0., 0.],
             [10., 10., 10.]])

        assert th.equal(new_tensor, target_tensor)

        client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4]))
        server_id, new_tensor = client.pull_wait()
        # Check the reply origin here too, matching the embed_0 check above.
        assert server_id == 0

        target_tensor = th.tensor([ 0., 0., 5., 0., 10.])

        assert th.equal(new_tensor, target_tensor)
    finally:
        # Always release the server, otherwise a failed assertion leaves the
        # forked server process running.
        client.shut_down()

if __name__ == '__main__':
    # The child process runs the server; the parent runs the client.
    pid = os.fork()
    if pid == 0:
        start_server()
    else:
        time.sleep(2) # wait server start
        start_client()
        # Reap the server child after the client has shut it down, so it is
        # not left behind as a zombie/orphan process.
        os.waitpid(pid, 0)