# Page note: "git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9e9ed353a2859550815c71b45e801efbccc350c2"
# rpc_basic.py (4.5 KB)
# Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import dgl
import backend as F
from numpy.testing import assert_array_equal

# Integer payload placed in every HelloRequest; both the server-side handler
# and the client-side checks assert it comes through unchanged.
INTEGER = 2
# String payload placed in every HelloRequest; asserted the same way as INTEGER.
STR = 'hello world!'
# Service id under which HelloRequest/HelloResponse are registered on both
# the server and the client.
HELLO_SERVICE_ID = 901231
# 1000 x 1000 int64 all-zero tensor created on the CPU context. Because every
# entry is zero, doubling it in tensor_func yields an equal tensor, which is
# what lets the client compare responses directly against TENSOR.
TENSOR = F.zeros((1000, 1000), F.int64, F.cpu())


def tensor_func(tensor):
    """Return ``tensor`` multiplied by two.

    This is the callable shipped inside ``HelloRequest`` and applied to the
    request's tensor by ``process_request``.
    """
    doubled = tensor * 2
    return doubled


class HelloResponse(dgl.distributed.Response):
    """Response carrying a string, an integer and a tensor back to the client.

    The state used for serialization is the 3-tuple
    ``(hello_str, integer, tensor)``.
    """

    def __init__(self, hello_str, integer, tensor):
        self.hello_str, self.integer, self.tensor = hello_str, integer, tensor

    def __getstate__(self):
        """Return the payload as a ``(hello_str, integer, tensor)`` tuple."""
        state = (self.hello_str, self.integer, self.tensor)
        return state

    def __setstate__(self, state):
        """Restore the payload from a ``(hello_str, integer, tensor)`` tuple."""
        hello_str, integer, tensor = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor


class HelloRequest(dgl.distributed.Request):
    """Request carrying a string, an integer, a tensor and a callable.

    The state used for serialization is the 4-tuple
    ``(hello_str, integer, tensor, func)``.
    """

    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str, self.integer = hello_str, integer
        self.tensor, self.func = tensor, func

    def __getstate__(self):
        """Return the payload as a ``(hello_str, integer, tensor, func)`` tuple."""
        state = (self.hello_str, self.integer, self.tensor, self.func)
        return state

    def __setstate__(self, state):
        """Restore the payload from a ``(hello_str, integer, tensor, func)`` tuple."""
        hello_str, integer, tensor, func = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def process_request(self, server_state):
        """Check the received payload and build the reply.

        Asserts that the string and integer equal the module constants
        ``STR`` and ``INTEGER``, applies ``self.func`` to ``self.tensor``
        and returns a ``HelloResponse`` holding the same string and integer
        together with the transformed tensor. ``server_state`` is not used.
        """
        assert self.hello_str == STR
        assert self.integer == INTEGER
        transformed = self.func(self.tensor)
        return HelloResponse(self.hello_str, self.integer, transformed)


def start_server(server_id, ip_config, num_servers, num_clients, net_type, keep_alive):
    """Register the hello service and start a DGL distributed server.

    A ``ServerState`` is created with no kv-store, local graph or partition
    book (all ``None``) and the given ``keep_alive`` flag; the remaining
    arguments are forwarded unchanged to ``dgl.distributed.start_server``.
    """
    dist = dgl.distributed
    state = dist.ServerState(
        None, local_g=None, partition_book=None, keep_alive=keep_alive)
    dist.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    print("Start server {}".format(server_id))
    launch_kwargs = dict(
        server_id=server_id,
        ip_config=ip_config,
        num_servers=num_servers,
        num_clients=num_clients,
        server_state=state,
        net_type=net_type,
    )
    dist.start_server(**launch_kwargs)


def start_client(ip_config, num_servers, group_id, net_type):
    """Connect to the servers and exercise each RPC call style on every server.

    For each server id in the namebook read from ``ip_config``, the same
    ``HelloRequest`` is sent through, in order:

    1. ``send_request`` + ``recv_response``
    2. ``remote_call`` with a batch of 10 requests
    3. ``send_request_to_machine`` + ``recv_response``
    4. ``remote_call_to_machine`` with a batch of 10 requests

    Every response is asserted to echo ``STR`` and ``INTEGER`` and to carry
    a tensor equal to ``TENSOR``.
    """
    rpc = dgl.distributed

    def check(response):
        # Single place for the three checks applied to every response.
        assert response.hello_str == STR
        assert response.integer == INTEGER
        assert_array_equal(F.asnumpy(response.tensor), F.asnumpy(TENSOR))

    rpc.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
    rpc.connect_to_server(
        ip_config=ip_config,
        num_servers=num_servers,
        group_id=group_id,
        net_type=net_type,
    )
    request = HelloRequest(STR, INTEGER, TENSOR, tensor_func)
    namebook = rpc.read_ip_config(ip_config, num_servers)
    for server_id in namebook.keys():
        # test send and recv
        rpc.send_request(server_id, request)
        check(rpc.recv_response())

        # test remote_call
        batch = [(server_id, request) for _ in range(10)]
        for response in rpc.remote_call(batch):
            check(response)

        # test send_request_to_machine
        rpc.send_request_to_machine(server_id, request)
        check(rpc.recv_response())

        # test remote_call_to_machine
        batch = [(server_id, request) for _ in range(10)]
        for response in rpc.remote_call_to_machine(batch):
            check(response)


def main():
    """Run this script as an RPC server or client, selected by environment.

    Environment variables read:

    * ``DIST_DGL_TEST_IP_CONFIG`` (required): passed on as ``ip_config``.
    * ``DIST_DGL_TEST_NUM_SERVERS`` (required): integer, passed on as
      ``num_servers``.
    * ``DIST_DGL_TEST_NET_TYPE`` (optional): defaults to ``'tensorpipe'``.
    * ``DIST_DGL_TEST_ROLE`` (optional): ``'server'`` (the default) runs
      ``start_server``; any other value runs ``start_client``.
    * ``DIST_DGL_TEST_SERVER_ID`` and ``DIST_DGL_TEST_NUM_CLIENTS``
      (required for the server role): integers.
    * ``DIST_DGL_TEST_KEEP_ALIVE`` (server role): its mere presence, whatever
      the value, turns ``keep_alive`` on.
    * ``DIST_DGL_TEST_GROUP_ID`` (client role, optional): integer,
      defaults to ``0``.

    Raises:
        KeyError: a required variable is not set; the exception names it.
        ValueError: an integer variable holds a non-integer string.
    """
    # Required variables are indexed rather than fetched with .get(): a
    # missing one then fails immediately with a KeyError naming the variable,
    # instead of an opaque int(None) TypeError or a None ip_config leaking
    # into the RPC layer.
    ip_config = os.environ['DIST_DGL_TEST_IP_CONFIG']
    num_servers = int(os.environ['DIST_DGL_TEST_NUM_SERVERS'])
    net_type = os.environ.get('DIST_DGL_TEST_NET_TYPE', 'tensorpipe')
    role = os.environ.get('DIST_DGL_TEST_ROLE', 'server')
    if role == 'server':
        server_id = int(os.environ['DIST_DGL_TEST_SERVER_ID'])
        num_clients = int(os.environ['DIST_DGL_TEST_NUM_CLIENTS'])
        keep_alive = 'DIST_DGL_TEST_KEEP_ALIVE' in os.environ
        start_server(server_id, ip_config, num_servers,
                     num_clients, net_type, keep_alive)
    else:
        group_id = int(os.environ.get('DIST_DGL_TEST_GROUP_ID', '0'))
        start_client(ip_config, num_servers, group_id, net_type)


# Run the server/client entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()