Unverified Commit a67ba946 authored by Chao Ma, committed by GitHub

[RPC] Add send_request_to_machine() and remote_call_to_machine() (#1619)

* add send_request_to_machine()

* update

* update

* update

* update

* update

* update

* fix lint

* update
parent 50962c43
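
For orientation before the diff: the two new entry points let a client address a machine ID instead of a concrete server ID, and the RPC layer then picks one of that machine's server processes at random. Below is a rough usage sketch mirroring the test added at the end of this commit; `req` stands for an instance of some `Request` subclass that has already been registered via `register_service`, it is not defined here.

```python
import dgl

# Non-blocking: send one request to machine 0; the RPC layer randomly
# picks one of machine 0's server processes (simple load balancing).
dgl.distributed.send_request_to_machine(0, req)
res = dgl.distributed.recv_response()   # wait for the single reply

# Blocking, batched variant: each entry is a (machine_id, request) pair;
# the responses come back in the same order as the requests.
target_and_requests = [(0, req) for _ in range(10)]
res_list = dgl.distributed.remote_call_to_machine(target_and_requests)
```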
"""Define distributed kvstore""" """Define distributed kvstore"""
import os import os
import random
import numpy as np import numpy as np
from . import rpc from . import rpc
...@@ -993,10 +992,7 @@ class KVClient(object): ...@@ -993,10 +992,7 @@ class KVClient(object):
local_data = partial_data local_data = partial_data
else: # push data to remote server else: # push data to remote server
request = PushRequest(name, partial_id, partial_data) request = PushRequest(name, partial_id, partial_data)
# randomly select a server node in target machine for load-balance rpc.send_request_to_machine(machine_idx, request)
server_id = random.randint(machine_idx*self._group_count, \
(machine_idx+1)*self._group_count-1)
rpc.send_request(server_id, request)
start += count[idx] start += count[idx]
if local_id is not None: # local push if local_id is not None: # local push
self._push_handler(self._data_store, name, local_id, local_data) self._push_handler(self._data_store, name, local_id, local_data)
...@@ -1041,10 +1037,7 @@ class KVClient(object): ...@@ -1041,10 +1037,7 @@ class KVClient(object):
local_id = self._part_policy[name].to_local(partial_id) local_id = self._part_policy[name].to_local(partial_id)
else: # pull data from remote server else: # pull data from remote server
request = PullRequest(name, partial_id) request = PullRequest(name, partial_id)
# randomly select a server node in target machine for load-balance rpc.send_request_to_machine(machine_idx, request)
server_id = random.randint(machine_idx*self._group_count, \
(machine_idx+1)*self._group_count-1)
rpc.send_request(server_id, request)
pull_count += 1 pull_count += 1
start += count[idx] start += count[idx]
# recv response # recv response
......
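
In both the push and pull paths the load-balancing logic itself does not change, it simply moves out of `KVClient` and behind `rpc.send_request_to_machine()`. Conceptually the server is still picked with the same arithmetic; the sketch below uses the illustrative names `machine_id` and `servers_per_machine` (they are not identifiers from this patch).

```python
import random

def pick_server(machine_id, servers_per_machine):
    """Pick one server process on ``machine_id`` uniformly at random.
    Server IDs are laid out contiguously per machine, so machine m owns
    IDs [m * servers_per_machine, (m + 1) * servers_per_machine - 1]."""
    return random.randint(machine_id * servers_per_machine,
                          (machine_id + 1) * servers_per_machine - 1)

# e.g. with 2 servers per machine, machine 1 owns server IDs 2 and 3
assert pick_server(1, 2) in (2, 3)
```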
...@@ -2,6 +2,7 @@
server and clients."""
import abc
import pickle
+import random
from .._ffi.object import register_object, ObjectBase
from .._ffi.function import _init_api

...@@ -12,7 +13,8 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
    'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \
    'receiver_wait', 'add_receiver_addr', 'sender_connect', 'read_ip_config', \
    'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
-   'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call']
+   'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
+   'send_request_to_machine', 'remote_call_to_machine']

REQUEST_CLASS_TO_SERVICE_ID = {}
RESPONSE_CLASS_TO_SERVICE_ID = {}

...@@ -220,6 +222,16 @@ def get_num_server():
    """
    return _CAPI_DGLRPCGetNumServer()
+def set_num_server_per_machine(num_server):
+    """Set the total number of servers per machine.
+    """
+    _CAPI_DGLRPCSetNumServerPerMachine(num_server)
+def get_num_server_per_machine():
+    """Get the total number of servers per machine.
+    """
+    return _CAPI_DGLRPCGetNumServerPerMachine()
def incr_msg_seq():
    """Increment the message sequence number and return the old one.

...@@ -517,6 +529,35 @@ def send_request(target, request):
    msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
    send_rpc_message(msg, server_id)
+def send_request_to_machine(target, request):
+    """Send one request to the target machine, which will randomly
+    select a server node to process this request.
+    The operation is non-blocking -- it does not guarantee the payloads have
+    reached the target or even have left the sender process. However,
+    all the payloads (i.e., data and arrays) can be safely freed after this
+    function returns.
+    Parameters
+    ----------
+    target : int
+        ID of the target machine.
+    request : Request
+        The request to send.
+    Raises
+    ------
+    ConnectionError if there is any problem with the connection.
+    """
+    service_id = request.service_id
+    msg_seq = incr_msg_seq()
+    client_id = get_rank()
+    server_id = random.randint(target*get_num_server_per_machine(),
+                               (target+1)*get_num_server_per_machine()-1)
+    data, tensors = serialize_to_payload(request)
+    msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
+    send_rpc_message(msg, server_id)
def send_response(target, response):
    """Send one response to the target client.
...@@ -644,6 +685,69 @@ def remote_call(target_and_requests, timeout=0):
        Responses for each target-request pair. If the request does not have
        response, None is placed.
+    Raises
+    ------
+    ConnectionError if there is any problem with the connection.
+    """
+    # TODO(chao): handle timeout
+    all_res = [None] * len(target_and_requests)
+    msgseq2pos = {}
+    num_res = 0
+    myrank = get_rank()
+    for pos, (target, request) in enumerate(target_and_requests):
+        # send request
+        service_id = request.service_id
+        msg_seq = incr_msg_seq()
+        client_id = get_rank()
+        server_id = random.randint(target*get_num_server_per_machine(),
+                                   (target+1)*get_num_server_per_machine()-1)
+        data, tensors = serialize_to_payload(request)
+        msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
+        send_rpc_message(msg, server_id)
+        # check if the request has a response
+        res_cls = get_service_property(service_id)[1]
+        if res_cls is not None:
+            num_res += 1
+            msgseq2pos[msg_seq] = pos
+    while num_res != 0:
+        # recv response
+        msg = recv_rpc_message(timeout)
+        num_res -= 1
+        _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
+        if res_cls is None:
+            raise DGLError('Got response message from service ID {}, '
+                           'but no response class is registered.'.format(msg.service_id))
+        res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
+        if msg.client_id != myrank:
+            raise DGLError('Got response of request sent by client {}, '
+                           'different from my rank {}!'.format(msg.client_id, myrank))
+        # set response
+        all_res[msgseq2pos[msg.msg_seq]] = res
+    return all_res
+def remote_call_to_machine(target_and_requests, timeout=0):
+    """Invoke registered services on a remote machine (which will randomly
+    select a server to process each request) and collect responses.
+    The operation is blocking -- it returns when it receives all responses
+    or it times out.
+    If the target server state is available locally, it invokes local computation
+    to calculate the response.
+    Parameters
+    ----------
+    target_and_requests : list[(int, Request)]
+        A list of requests and the machine IDs they should be sent to.
+    timeout : int, optional
+        The timeout value in milliseconds. If zero, wait indefinitely.
+    Returns
+    -------
+    list[Response]
+        Responses for each target-request pair. If the request does not have
+        response, None is placed.
    Raises
    ------
    ConnectionError if there is any problem with the connection.
......
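
Both `remote_call` variants shown above match replies to requests by message sequence number rather than by arrival order: each outgoing request records `msgseq2pos[msg_seq] = pos`, and each incoming reply is written into `all_res` at that recorded position. A stripped-down model of that bookkeeping, with no networking and made-up sequence numbers, just to show why out-of-order replies still land in the right slots:

```python
# requests get increasing sequence numbers; replies may come back in any
# order but are written to the result slot recorded for their sequence
requests = ['a', 'b', 'c']
all_res = [None] * len(requests)
msgseq2pos = {}

for pos, payload in enumerate(requests):
    msg_seq = pos + 100            # stand-in for incr_msg_seq()
    msgseq2pos[msg_seq] = pos

# replies arrive out of order as (msg_seq, response) pairs
for msg_seq, res in [(102, 'C'), (100, 'A'), (101, 'B')]:
    all_res[msgseq2pos[msg_seq]] = res

assert all_res == ['A', 'B', 'C']
```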
...@@ -122,6 +122,7 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
        group_count.append(server_info[3])
        if server_info[0] > max_machine_id:
            max_machine_id = server_info[0]
+   rpc.set_num_server_per_machine(group_count[0])
    num_machines = max_machine_id+1
    rpc.set_num_machines(num_machines)
    machine_id = get_local_machine_id(server_namebook)
......
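
`connect_to_server` already collects one `group_count` entry per server from the parsed IP config; the patch takes the first entry and hands it to the RPC layer, implicitly assuming every machine runs the same number of server processes. A small sketch of that relationship (the numbers are illustrative, and it assumes a local DGL install; the calls only set a process-local value, so no servers need to be running):

```python
from dgl.distributed import rpc

# group_count holds the server-group size reported by each server entry
# in the IP config; the first entry configures the RPC layer.
group_count = [2, 2, 2, 2]        # e.g. 2 server processes on each machine
rpc.set_num_server_per_machine(group_count[0])
assert rpc.get_num_server_per_machine() == 2
```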
...@@ -159,6 +159,17 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
    *rv = RPCContext::ThreadLocal()->num_servers;
  });
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine")
+.set_body([] (DGLArgs args, DGLRetValue* rv) {
+    const int32_t num_servers = args[0];
+    *rv = RPCContext::ThreadLocal()->num_servers_per_machine = num_servers;
+  });
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine")
+.set_body([] (DGLArgs args, DGLRetValue* rv) {
+    *rv = RPCContext::ThreadLocal()->num_servers_per_machine;
+  });
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    *rv = (RPCContext::ThreadLocal()->msg_seq)++;
......
...@@ -56,6 +56,11 @@ struct RPCContext {
   */
  int32_t num_servers = 0;
+ /*!
+  * \brief Total number of servers per machine.
+  */
+ int32_t num_servers_per_machine = 0;
 /*!
  * \brief Sender communicator.
  */
......
...@@ -100,6 +100,21 @@ def start_client():
    for i in range(10):
        target_and_requests.append((0, req))
    res_list = dgl.distributed.remote_call(target_and_requests)
+   for res in res_list:
+       assert res.hello_str == STR
+       assert res.integer == INTEGER
+       assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
+   # test send_request_to_machine
+   dgl.distributed.send_request_to_machine(0, req)
+   res = dgl.distributed.recv_response()
+   assert res.hello_str == STR
+   assert res.integer == INTEGER
+   assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
+   # test remote_call_to_machine
+   target_and_requests = []
+   for i in range(10):
+       target_and_requests.append((0, req))
+   res_list = dgl.distributed.remote_call_to_machine(target_and_requests)
    for res in res_list:
        assert res.hello_str == STR
        assert res.integer == INTEGER

...@@ -153,8 +168,6 @@ def test_rpc():
    start_client()
if __name__ == '__main__':
-   test_rank()
-   test_msg_seq()
    test_serialize()
    test_rpc_msg()
    test_rpc()