# rpc_basic.py
1
import os
2

3
4
import backend as F

5
import dgl
6
from numpy.testing import assert_array_equal
7

8
# Scalar and string payloads placed in every HelloRequest. The server side
# (HelloRequest.process_request) and the client side both assert that these
# exact values arrive.
INTEGER = 2
STR = "hello world!"

# Service ID under which HelloRequest / HelloResponse are registered.
HELLO_SERVICE_ID = 901231
# 1000 x 1000 int64 zero tensor on CPU, sent with every request. Because it
# is all zeros, tensor_func (multiply by 2) leaves its values unchanged, which
# is why the client compares the returned tensor against TENSOR itself.
TENSOR = F.zeros((1000, 1000), F.int64, F.cpu())


def tensor_func(tensor):
    """Return ``tensor`` multiplied by 2.

    This is the callable shipped inside a HelloRequest and applied to the
    request's tensor on the receiving side.
    """
    doubled = tensor * 2
    return doubled


class HelloResponse(dgl.distributed.Response):
    """Response carrying a string, an integer and a tensor.

    The state exchanged through ``__getstate__`` / ``__setstate__`` is the
    3-tuple ``(hello_str, integer, tensor)``.
    """

    def __init__(self, hello_str, integer, tensor):
        self.hello_str, self.integer, self.tensor = hello_str, integer, tensor

    def __getstate__(self):
        # Field order here must match the unpacking order in __setstate__.
        state = (self.hello_str, self.integer, self.tensor)
        return state

    def __setstate__(self, state):
        hello_str, integer, tensor = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor


class HelloRequest(dgl.distributed.Request):
    """Request carrying a string, an integer, a tensor and a callable.

    The state exchanged through ``__getstate__`` / ``__setstate__`` is the
    4-tuple ``(hello_str, integer, tensor, func)``.
    """

    def __init__(self, hello_str, integer, tensor, func):
        self.hello_str, self.integer = hello_str, integer
        self.tensor, self.func = tensor, func

    def __getstate__(self):
        # Field order here must match the unpacking order in __setstate__.
        state = (self.hello_str, self.integer, self.tensor, self.func)
        return state

    def __setstate__(self, state):
        hello_str, integer, tensor, func = state
        self.hello_str = hello_str
        self.integer = integer
        self.tensor = tensor
        self.func = func

    def process_request(self, server_state):
        """Validate the payload and build the matching HelloResponse.

        Asserts that the string and integer equal the module constants STR
        and INTEGER, applies ``self.func`` to ``self.tensor`` and returns a
        HelloResponse holding the same string and integer plus the result.
        ``server_state`` is not used.
        """
        assert self.hello_str == STR
        assert self.integer == INTEGER
        return HelloResponse(
            self.hello_str, self.integer, self.func(self.tensor)
        )


52
53
54
def start_server(
    server_id, ip_config, num_servers, num_clients, net_type, keep_alive
):
    """Register the hello service and start a DGL distributed server.

    Builds a ``dgl.distributed.ServerState`` with no local graph and no
    partition book, registers HelloRequest / HelloResponse under
    HELLO_SERVICE_ID, prints a start-up line and then hands control to
    ``dgl.distributed.start_server``.

    Parameters
    ----------
    server_id
        Forwarded as ``server_id``; also printed in the start-up message.
    ip_config
        Forwarded as ``ip_config``.
    num_servers
        Forwarded as ``num_servers``.
    num_clients
        Forwarded as ``num_clients``.
    net_type
        Forwarded as ``net_type``.
    keep_alive
        Forwarded to ``ServerState`` as ``keep_alive``.
    """
    # NOTE(review): the first positional argument of ServerState is passed as
    # None; its meaning is not visible in this file -- confirm against DGL.
    server_state = dgl.distributed.ServerState(
        None, local_g=None, partition_book=None, keep_alive=keep_alive
    )
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse
    )
    print("Start server {}".format(server_id))
    dgl.distributed.start_server(
        server_id=server_id,
        ip_config=ip_config,
        num_servers=num_servers,
        num_clients=num_clients,
        server_state=server_state,
        net_type=net_type,
    )
70
71
72
73


def start_client(ip_config, num_servers, group_id, net_type):
    """Connect to the servers and exercise the RPC entry points.

    Registers HelloRequest / HelloResponse under HELLO_SERVICE_ID, connects
    via ``dgl.distributed.connect_to_server`` and then, for every server ID
    returned by ``dgl.distributed.read_ip_config``, checks in order:

    1. ``send_request`` followed by ``recv_response``
    2. ``remote_call`` with 10 identical requests
    3. ``send_request_to_machine`` followed by ``recv_response``
    4. ``remote_call_to_machine`` with 10 identical requests

    Every response is validated against the module constants; a mismatch
    raises ``AssertionError``.

    Parameters
    ----------
    ip_config
        Forwarded to ``connect_to_server`` and ``read_ip_config``.
    num_servers
        Forwarded to ``connect_to_server`` and ``read_ip_config``.
    group_id
        Forwarded to ``connect_to_server`` as ``group_id``.
    net_type
        Forwarded to ``connect_to_server`` as ``net_type``.
    """
    dgl.distributed.register_service(
        HELLO_SERVICE_ID, HelloRequest, HelloResponse
    )
    dgl.distributed.connect_to_server(
        ip_config=ip_config,
        num_servers=num_servers,
        group_id=group_id,
        net_type=net_type,
    )
    req = HelloRequest(STR, INTEGER, TENSOR, tensor_func)
    server_namebook = dgl.distributed.read_ip_config(ip_config, num_servers)
    for server_id in server_namebook.keys():
        # test send and recv
        dgl.distributed.send_request(server_id, req)
        _assert_hello_response(dgl.distributed.recv_response())
        # test remote_call
        target_and_requests = [(server_id, req) for _ in range(10)]
        for res in dgl.distributed.remote_call(target_and_requests):
            _assert_hello_response(res)
        # test send_request_to_machine
        dgl.distributed.send_request_to_machine(server_id, req)
        _assert_hello_response(dgl.distributed.recv_response())
        # test remote_call_to_machine
        target_and_requests = [(server_id, req) for _ in range(10)]
        for res in dgl.distributed.remote_call_to_machine(target_and_requests):
            _assert_hello_response(res)


def _assert_hello_response(res):
    """Assert that ``res`` holds STR, INTEGER and a tensor equal to TENSOR."""
    assert res.hello_str == STR
    assert res.integer == INTEGER
    # TENSOR is all zeros, so the doubled tensor in the response must still
    # compare equal to it.
    assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))


def main():
    """Run as server or client according to environment variables.

    Environment variables read:

    - ``DIST_DGL_TEST_IP_CONFIG``: passed on as ``ip_config``.
    - ``DIST_DGL_TEST_NUM_SERVERS``: converted with ``int``.
    - ``DIST_DGL_TEST_NET_TYPE``: defaults to ``"tensorpipe"``.
    - ``DIST_DGL_TEST_ROLE``: defaults to ``"server"``; any other value
      selects the client path.
    - Server only: ``DIST_DGL_TEST_SERVER_ID`` and
      ``DIST_DGL_TEST_NUM_CLIENTS`` (both converted with ``int``), and
      ``DIST_DGL_TEST_KEEP_ALIVE`` (keep-alive is enabled when the variable
      is present, regardless of its value).
    - Client only: ``DIST_DGL_TEST_GROUP_ID``, defaults to ``"0"`` and is
      converted with ``int``.

    Raises
    ------
    TypeError
        If an integer variable without a default is unset, because
        ``int(None)`` is evaluated.
    ValueError
        If an integer variable is set to a non-integer string.
    """
    ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG")
    num_servers = int(os.environ.get("DIST_DGL_TEST_NUM_SERVERS"))
    net_type = os.environ.get("DIST_DGL_TEST_NET_TYPE", "tensorpipe")
    if os.environ.get("DIST_DGL_TEST_ROLE", "server") == "server":
        server_id = int(os.environ.get("DIST_DGL_TEST_SERVER_ID"))
        num_clients = int(os.environ.get("DIST_DGL_TEST_NUM_CLIENTS"))
        # Presence of the variable is what matters, not its value.
        keep_alive = "DIST_DGL_TEST_KEEP_ALIVE" in os.environ
        start_server(
            server_id, ip_config, num_servers, num_clients, net_type, keep_alive
        )
    else:
        group_id = int(os.environ.get("DIST_DGL_TEST_GROUP_ID", "0"))
        start_client(ip_config, num_servers, group_id, net_type)


133
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()