# test_kvstore.py
import dgl
import argparse
import torch as th
import time
import backend as F

from multiprocessing import Process
# Fixtures shared by the client and server processes.
#
# There are 8 global ids (0..7). Client k pushes ids 2k and 2k + 1, and the
# partition books below assign exactly those ids to partition k.

# ID[k]: the global ids client k pushes.
ID = [th.tensor([2 * k, 2 * k + 1]) for k in range(4)]

# DATA[k]: the 2x3 float rows client k pushes; every entry equals k + 1.
DATA = [th.tensor([[float(k + 1)] * 3] * 2) for k in range(4)]

# Partition books passed to every client: entry i is gid // 2, i.e.
# [0, 0, 1, 1, 2, 2, 3, 3] -- presumably the partition (server) owning id i.
edata_partition_book = {'edata': th.tensor([gid // 2 for gid in range(8)])}
ndata_partition_book = {'ndata': th.tensor([gid // 2 for gid in range(8)])}

# Per-server global-to-local id maps (length 8), indexed by server id.
# For server s, entry 2s + 1 is 1 and every other entry is 0; presumably this
# maps global ids 2s / 2s + 1 onto the server's local rows 0 / 1 (each server
# starts with two rows), and entries for other servers' ids are unused --
# confirm against dgl.contrib.
ndata_g2l = [{'ndata': th.tensor([int(gid == 2 * s + 1) for gid in range(8)])}
             for s in range(4)]
edata_g2l = [{'edata': th.tensor([int(gid == 2 * s + 1) for gid in range(8)])}
             for s in range(4)]


def start_client(flag):
    """Run one kvstore client: push this client's rows, pull all 8 back, verify.

    `flag` is forwarded to dgl.contrib.start_client as `close_shared_mem`.
    The client whose id is 0 calls shut_down() after the final barrier.
    """
    # Give the server processes a head start before connecting (presumably
    # so they are listening by the time the client starts).
    time.sleep(3)

    client = dgl.contrib.start_client(ip_config='ip_config.txt',
                                      ndata_partition_book=ndata_partition_book,
                                      edata_partition_book=edata_partition_book,
                                      close_shared_mem=flag)
    my_id = client.get_id()
    store_names = ('edata', 'ndata')

    # Push the ids/rows this client owns (ID[my_id] / DATA[my_id]) to both stores.
    for store in store_names:
        client.push(name=store, id_tensor=ID[my_id], data_tensor=DATA[my_id])

    # Sync with the other clients so every push is in before anyone pulls
    # (assumes barrier() blocks until all clients arrive).
    client.barrier()

    # Pull every global id 0..7 from each store, edata first.
    pulled = {store: client.pull(name=store, id_tensor=th.arange(8))
              for store in store_names}

    # Rows 2k and 2k + 1 must hold client k's DATA, whose entries are all k + 1.
    expected = th.tensor([[float(v)] * 3 for v in (1, 1, 2, 2, 3, 3, 4, 4)])
    for store in store_names:
        assert F.array_equal(pulled[store], expected)

    # Everyone finishes checking before client 0 shuts the servers down.
    client.barrier()

    if my_id == 0:
        client.shut_down()


def start_server(server_id, num_client):
    """Launch kvstore server `server_id`, expecting `num_client` clients.

    The server starts with two all-zero rows of width 3 in both the 'ndata'
    and 'edata' stores, and uses this server's entries from the module-level
    ndata_g2l / edata_g2l fixtures as its global-to-local id maps.
    """
    initial_ndata = {'ndata': th.zeros(2, 3)}
    initial_edata = {'edata': th.zeros(2, 3)}

    dgl.contrib.start_server(server_id=server_id,
                             ip_config='ip_config.txt',
                             num_client=num_client,
                             ndata=initial_ndata,
                             edata=initial_edata,
                             ndata_g2l=ndata_g2l[server_id],
                             edata_g2l=edata_g2l[server_id])


if __name__ == '__main__':
    # Upper bound, in seconds, on the whole run. Anything still alive after
    # this is treated as hung and terminated. It is deliberately generous:
    # the happy path only takes a few seconds (clients sleep 3s first).
    join_timeout = 300

    num_servers = 4
    # close_shared_mem value for each client: two clients close shared
    # memory, two do not.
    client_flags = (True, True, False, False)

    servers = [Process(target=start_server, args=(server_id, len(client_flags)))
               for server_id in range(num_servers)]
    clients = [Process(target=start_client, args=(flag,))
               for flag in client_flags]
    everyone = servers + clients

    # Start servers before clients (clients additionally sleep before connecting).
    for proc in everyone:
        proc.start()

    # Join against one shared deadline so the total wait is bounded by
    # join_timeout rather than len(everyone) * join_timeout.
    deadline = time.monotonic() + join_timeout
    for proc in everyone:
        proc.join(max(0.0, deadline - time.monotonic()))

    # Still-alive children are hung. For example, if one client exits early,
    # the others may block in barrier() and client 0 never calls shut_down(),
    # leaving the servers running forever.
    hung = [proc for proc in everyone if proc.is_alive()]
    for proc in hung:
        proc.terminate()
        proc.join()

    # An exception (including a failed assert) inside a child only sets that
    # child's exit code. Surface it so this script itself fails.
    failed = [(proc.name, proc.exitcode) for proc in everyone
              if proc.exitcode != 0]
    if failed:
        raise RuntimeError(
            'kvstore test processes failed or hung (name, exitcode): {}'.format(failed))