rpc_basic.py 4.26 KB
Newer Older
1
import os
2

3
4
import backend as F

5
import dgl
6
from numpy.testing import assert_array_equal
7

8
# Integer payload placed in every HelloRequest; the server asserts it on
# receipt and the client asserts it again on the echoed response.
INTEGER = 2
9
# String payload placed in every HelloRequest; checked the same way as INTEGER.
STR = "hello world!"
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Service id under which HelloRequest / HelloResponse are registered by both
# the server and the client.
HELLO_SERVICE_ID = 901231
# Tensor payload sent with every request: shape (1000, 1000), created with
# F.zeros using F.int64 on F.cpu().
TENSOR = F.zeros((1000, 1000), F.int64, F.cpu())


def tensor_func(tensor):
    """Return ``tensor`` multiplied by two.

    This is the callable shipped inside each ``HelloRequest`` and applied to
    the request's tensor in ``HelloRequest.process_request``.
    """
    doubled = tensor * 2
    return doubled


class HelloResponse(dgl.distributed.Response):
    """Response built by ``HelloRequest.process_request``.

    Carries three fields: ``hello_str``, ``integer`` and ``tensor``.  The
    state exposed through ``__getstate__`` / ``__setstate__`` is the 3-tuple
    ``(hello_str, integer, tensor)`` (presumably what the RPC layer
    serializes -- confirm against ``dgl.distributed``).
    """

    def __init__(self, hello_str, integer, tensor):
        self.hello_str, self.integer, self.tensor = (
            hello_str,
            integer,
            tensor,
        )

    def __getstate__(self):
        """Return the fields as a ``(hello_str, integer, tensor)`` tuple."""
        return (self.hello_str, self.integer, self.tensor)

    def __setstate__(self, state):
        """Restore the fields from a 3-item ``state`` sequence."""
        hello_str, integer, tensor = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor


class HelloRequest(dgl.distributed.Request):
    """Request carrying a string, an integer, a tensor and a callable.

    The state exposed through ``__getstate__`` / ``__setstate__`` is the
    4-tuple ``(hello_str, integer, tensor, func)``.
    """

    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str, self.integer = hello_str, integer
        self.tensor, self.func = tensor, func

    def __getstate__(self):
        """Return the fields as a ``(hello_str, integer, tensor, func)`` tuple."""
        return (self.hello_str, self.integer, self.tensor, self.func)

    def __setstate__(self, state):
        """Restore the fields from a 4-item ``state`` sequence."""
        hello_str, integer, tensor, func = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def process_request(self, server_state):
        """Validate the payload and build the reply.

        Asserts that the string and integer match the module constants
        ``STR`` and ``INTEGER``, applies ``self.func`` to ``self.tensor``
        and returns a ``HelloResponse`` holding the original string and
        integer plus the function's result.  ``server_state`` is not used.
        """
        assert self.hello_str == STR
        assert self.integer == INTEGER
        processed = self.func(self.tensor)
        return HelloResponse(self.hello_str, self.integer, processed)


52
def start_server(server_id, ip_config, num_servers, num_clients, keep_alive):
    """Register the hello service and start a DGL RPC server.

    Builds a ``ServerState`` with no graph and no partition book (only
    ``keep_alive`` is forwarded), registers ``HelloRequest`` /
    ``HelloResponse`` under ``HELLO_SERVICE_ID``, prints a start banner and
    hands control to ``dgl.distributed.start_server``.

    Parameters
    ----------
    server_id
        Id of this server; printed and forwarded to ``start_server``.
    ip_config
        Forwarded unchanged as ``ip_config``.
    num_servers
        Forwarded unchanged as ``num_servers``.
    num_clients
        Forwarded unchanged as ``num_clients``.
    keep_alive
        Forwarded to ``ServerState`` as ``keep_alive``.
    """
    state = dgl.distributed.ServerState(
        None,
        local_g=None,
        partition_book=None,
        keep_alive=keep_alive,
    )
    dgl.distributed.register_service(
        HELLO_SERVICE_ID,
        HelloRequest,
        HelloResponse,
    )
    banner = "Start server {}".format(server_id)
    print(banner)
    launch_kwargs = {
        "server_id": server_id,
        "ip_config": ip_config,
        "num_servers": num_servers,
        "num_clients": num_clients,
        "server_state": state,
    }
    dgl.distributed.start_server(**launch_kwargs)
67
68


69
def _check_hello_response(res):
    """Assert that ``res`` carries the payload the client sent.

    Checks ``hello_str`` and ``integer`` against ``STR`` and ``INTEGER`` and
    compares ``res.tensor`` element-wise with ``TENSOR``.  ``TENSOR`` is
    created with ``F.zeros``, so the doubling applied by ``tensor_func`` on
    the server side is expected to leave it equal to the original.
    """
    assert res.hello_str == STR
    assert res.integer == INTEGER
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))


def start_client(ip_config, num_servers, group_id):
    """Connect to the servers and exercise the RPC send/receive paths.

    Registers ``HelloRequest`` / ``HelloResponse`` under
    ``HELLO_SERVICE_ID``, connects via ``dgl.distributed.connect_to_server``
    and then, for every server id returned by
    ``dgl.distributed.read_ip_config``, checks four paths in order:

    1. ``send_request`` followed by ``recv_response``
    2. ``remote_call`` with a batch of 10 identical requests
    3. ``send_request_to_machine`` followed by ``recv_response``
    4. ``remote_call_to_machine`` with a batch of 10 identical requests

    Every response is validated with ``_check_hello_response``; a mismatch
    raises ``AssertionError``.

    Parameters
    ----------
    ip_config
        Forwarded to ``connect_to_server`` and ``read_ip_config``.
    num_servers
        Forwarded to ``connect_to_server`` and ``read_ip_config``.
    group_id
        Forwarded to ``connect_to_server`` as ``group_id``.
    """
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse
    )
    dgl.distributed.connect_to_server(
        ip_config=ip_config,
        num_servers=num_servers,
        group_id=group_id,
    )
    req = HelloRequest(STR, INTEGER, TENSOR, tensor_func)
    server_namebook = dgl.distributed.read_ip_config(ip_config, num_servers)
    batch_size = 10
    for server_id in server_namebook.keys():
        # test send and recv
        dgl.distributed.send_request(server_id, req)
        _check_hello_response(dgl.distributed.recv_response())

        # test remote_call
        target_and_requests = [(server_id, req) for _ in range(batch_size)]
        for res in dgl.distributed.remote_call(target_and_requests):
            _check_hello_response(res)

        # test send_request_to_machine
        # NOTE(review): the server id is passed as the machine target here,
        # exactly as in the original test -- confirm that is intended.
        dgl.distributed.send_request_to_machine(server_id, req)
        _check_hello_response(dgl.distributed.recv_response())

        # test remote_call_to_machine
        target_and_requests = [(server_id, req) for _ in range(batch_size)]
        for res in dgl.distributed.remote_call_to_machine(target_and_requests):
            _check_hello_response(res)


def _required_int_env(name):
    """Return environment variable ``name`` parsed as an ``int``.

    Raises
    ------
    KeyError
        If the variable is not set; the message names the variable so a
        misconfigured launch is easy to diagnose.
    ValueError
        If the value is set but ``int`` cannot parse it.
    """
    raw = os.environ.get(name)
    if raw is None:
        raise KeyError(
            "required environment variable {} is not set".format(name)
        )
    return int(raw)


def main():
    """Run this process as an RPC test server or client, driven by env vars.

    Environment variables read:

    * ``DIST_DGL_TEST_IP_CONFIG`` -- passed through as ``ip_config``
      (``None`` when unset).
    * ``DIST_DGL_TEST_NUM_SERVERS`` -- required integer.
    * ``DIST_DGL_TEST_ROLE`` -- ``"server"`` (the default) selects the server
      path; any other value selects the client path.
    * Server path: ``DIST_DGL_TEST_SERVER_ID`` and
      ``DIST_DGL_TEST_NUM_CLIENTS`` (required integers);
      ``keep_alive`` is true when ``DIST_DGL_TEST_KEEP_ALIVE`` is present
      in the environment, whatever its value.
    * Client path: ``DIST_DGL_TEST_GROUP_ID`` (integer, defaults to ``0``).

    Raises
    ------
    KeyError
        If a required integer variable is missing.
    """
    ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG")
    num_servers = _required_int_env("DIST_DGL_TEST_NUM_SERVERS")
    if os.environ.get("DIST_DGL_TEST_ROLE", "server") == "server":
        server_id = _required_int_env("DIST_DGL_TEST_SERVER_ID")
        num_clients = _required_int_env("DIST_DGL_TEST_NUM_CLIENTS")
        # Presence alone enables keep-alive; the value is never inspected.
        keep_alive = "DIST_DGL_TEST_KEEP_ALIVE" in os.environ
        start_server(server_id, ip_config, num_servers, num_clients, keep_alive)
    else:
        group_id = int(os.environ.get("DIST_DGL_TEST_GROUP_ID", "0"))
        start_client(ip_config, num_servers, group_id)
124
125


126
if __name__ == "__main__":
    # Run the RPC test only when this file is executed as a script.
    main()