client.py 2.09 KB
Newer Older
1
import os
Chao Ma's avatar
Chao Ma committed
2
import argparse
3
import time
Chao Ma's avatar
Chao Ma committed
4

5
6
7
8
9
10
11
import dgl
from dgl.contrib import KVClient

import torch as th

# Partition book for 'entity_embed': entry i presumably names the partition
# (machine) that owns global row i -- two consecutive rows per partition,
# four partitions in total.
partition = th.repeat_interleave(th.arange(4), 2)

# ID[k] holds the two global row ids [2k, 2k + 1] that partition k owns.
ID = [th.arange(2 * k, 2 * k + 2) for k in range(4)]

# DATA[k] is a 2x3 float block filled with the value k + 1; row j of DATA[k]
# is pushed for id ID[k][j].
DATA = [th.ones(2, 3) * float(k + 1) for k in range(4)]

Chao Ma's avatar
Chao Ma committed
24

25
26
27
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the KVStore client demo.

    Options:
      --ip_config   path of the IP configuration file passed to
                    ``dgl.contrib.read_ip_config`` (default ``ip_config.txt``).
      --num_worker  number of client processes per machine (default 2).
    """

    def __init__(self):
        super().__init__()
        # (flag, type, default, help) for every supported option.
        option_specs = (
            ('--ip_config', str, 'ip_config.txt',
             'IP configuration file of kvstore.'),
            ('--num_worker', int, 2,
             'Number of worker (client nodes) on single-machine.'),
        )
        for flag, kind, fallback, description in option_specs:
            self.add_argument(flag, type=kind, default=fallback, help=description)
33
34


35
36
37
38
def start_client(args):
    """Start one KVStore client and run the push/pull demo on 'entity_embed'.

    Phases, separated by ``my_client.barrier()`` calls:
      1. Read the server list from ``args.ip_config``, connect, and register the
         partition book: "leader" clients (``get_id() % num_worker == 0``)
         supply ``partition``; the others register without passing one.
      2. Push every ``(ID[k], DATA[k])`` pair; leaders then pull and print
         all rows covered by ``partition``.
      3. Push zeros for this machine's rows (``ID[get_machine_id()]``);
         leaders pull and print all rows again, then call ``shut_down()``.

    Args:
        args: namespace from ``ArgParser``; reads ``ip_config`` and ``num_worker``.
    """
    server_namebook = dgl.contrib.read_ip_config(filename=args.ip_config)

    my_client = KVClient(server_namebook=server_namebook)

    my_client.connect()

    # NOTE(review): presumably client ids are laid out so that this selects
    # exactly one client per machine -- confirm against KVClient's id scheme.
    is_leader = my_client.get_id() % args.num_worker == 0

    if is_leader:
        my_client.set_partition_book(name='entity_embed', partition_book=partition)
    else:
        # NOTE(review): the fixed delay presumably gives the leader time to
        # publish the partition book before it is looked up here; it is a
        # timing heuristic, not a guarantee.
        time.sleep(3)
        my_client.set_partition_book(name='entity_embed')

    my_client.print()

    my_client.barrier()

    print("send request...")

    # Iterate the tables themselves rather than a hard-coded count so the
    # demo stays correct if ID/DATA are extended.
    for ids, rows in zip(ID, DATA):
        my_client.push(name='entity_embed', id_tensor=ids, data_tensor=rows)

    my_client.barrier()

    # Every global id covered by the partition book (one entry per row).
    all_ids = th.arange(partition.shape[0])

    if is_leader:
        res = my_client.pull(name='entity_embed', id_tensor=all_ids)
        print(res)

    my_client.barrier()

    # Overwrite this machine's rows with zeros of the same shape/dtype as its data.
    machine_id = my_client.get_machine_id()
    my_client.push(name='entity_embed',
                   id_tensor=ID[machine_id],
                   data_tensor=th.zeros_like(DATA[machine_id]))

    my_client.barrier()

    if is_leader:
        res = my_client.pull(name='entity_embed', id_tensor=all_ids)
        print(res)

        # Only leader clients call shut_down().
        my_client.shut_down()


if __name__ == '__main__':
    # Script entry point: parse the command line and run the client demo.
    start_client(ArgParser().parse_args())