Unverified Commit 37be02a4 authored by Rhett Ying, committed by GitHub

[Feature] enable socket net_type for rpc (#3951)

* [Feature] enable socket net_type for rpc

* fix lint

* fix lint

* fix build issue on windows

* fix test failure on windows

* fix test failure

* fix cpp unit test failure

* net_type blocking max_try_times

* fix other comments

* fix lint

* fix comment

* fix lint

* fix cpp
parent c3baf433
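The user-visible effect of this commit is that the RPC transport is now selectable through a `net_type` argument that flows from the training script into `dgl.distributed.initialize` (and the library default moves to `'tensorpipe'`). A minimal sketch of how a launch script would expose and forward the option after this change; the script structure, argument names other than `--net_type`, and the `ip_config.txt` path are illustrative placeholders, not taken from the diff:

    import argparse

    import dgl

    parser = argparse.ArgumentParser(description='distributed training entry point')
    parser.add_argument('--ip_config', type=str, default='ip_config.txt',
                        help='file listing the machines in the cluster (placeholder name)')
    parser.add_argument('--net_type', type=str, default='socket',
                        help="backend net type, 'socket' or 'tensorpipe'")
    args = parser.parse_args()

    # Select the RPC backend before any other dgl.distributed call; since the
    # library default becomes 'tensorpipe', 'socket' must be requested explicitly.
    dgl.distributed.initialize(args.ip_config, net_type=args.net_type)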
@@ -259,7 +259,7 @@ def run(args, device, data):
 def main(args):
     print(socket.gethostname(), 'Initializing DGL dist')
-    dgl.distributed.initialize(args.ip_config)
+    dgl.distributed.initialize(args.ip_config, net_type=args.net_type)
     if not args.standalone:
         print(socket.gethostname(), 'Initializing DGL process group')
         th.distributed.init_process_group(backend=args.backend)
@@ -325,6 +325,8 @@ if __name__ == '__main__':
     parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
     parser.add_argument('--pad-data', default=False, action='store_true',
                         help='Pad train nid to the same length across machine, to ensure num of batches to be the same.')
+    parser.add_argument('--net_type', type=str, default='socket',
+                        help="backend net type, 'socket' or 'tensorpipe'")
     args = parser.parse_args()
     print(args)
...
@@ -174,7 +174,7 @@ class CustomPool:
 def initialize(ip_config, num_servers=1, num_workers=0,
-               max_queue_size=MAX_QUEUE_SIZE, net_type='socket',
+               max_queue_size=MAX_QUEUE_SIZE, net_type='tensorpipe',
                num_worker_threads=1):
     """Initialize DGL's distributed module
@@ -201,9 +201,9 @@ def initialize(ip_config, num_servers=1, num_workers=0,
         Note that the 20 GB is just an upper-bound and DGL uses zero-copy and
         it will not allocate 20GB memory at once.
     net_type : str, optional
-        Networking type. Currently the only valid option is ``'socket'``.
-        Default: ``'socket'``
+        Networking type. Valid options are: ``'socket'``, ``'tensorpipe'``.
+        Default: ``'tensorpipe'``
     num_worker_threads: int
         The number of threads in a worker process.
@@ -235,7 +235,8 @@ def initialize(ip_config, num_servers=1, num_workers=0,
                                    int(os.environ.get('DGL_NUM_CLIENT')),
                                    os.environ.get('DGL_CONF_PATH'),
                                    graph_format=formats,
-                                   keep_alive=keep_alive)
+                                   keep_alive=keep_alive,
+                                   net_type=net_type)
             serv.start()
             sys.exit()
     else:
...
@@ -311,10 +311,13 @@ class DistGraphServer(KVServer):
         The graph formats.
     keep_alive : bool
         Whether to keep server alive when clients exit
+    net_type : str
+        Backend rpc type: ``'socket'`` or ``'tensorpipe'``
     '''
     def __init__(self, server_id, ip_config, num_servers,
                  num_clients, part_config, disable_shared_mem=False,
-                 graph_format=('csc', 'coo'), keep_alive=False):
+                 graph_format=('csc', 'coo'), keep_alive=False,
+                 net_type='tensorpipe'):
         super(DistGraphServer, self).__init__(server_id=server_id,
                                               ip_config=ip_config,
                                               num_servers=num_servers,
@@ -322,6 +325,7 @@ class DistGraphServer(KVServer):
         self.ip_config = ip_config
         self.num_servers = num_servers
         self.keep_alive = keep_alive
+        self.net_type = net_type
         # Load graph partition data.
         if self.is_backup_server():
             # The backup server doesn't load the graph partition. It'll initialized afterwards.
@@ -376,7 +380,9 @@ class DistGraphServer(KVServer):
         start_server(server_id=self.server_id,
                      ip_config=self.ip_config,
                      num_servers=self.num_servers,
-                     num_clients=self.num_clients, server_state=server_state)
+                     num_clients=self.num_clients,
+                     server_state=server_state,
+                     net_type=self.net_type)
 class DistGraph:
     '''The class for accessing a distributed graph.
...
@@ -15,7 +15,7 @@ from .. import backend as F
 __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
            'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \
-           'receiver_wait', 'connect_receiver', 'read_ip_config', 'get_group_id', \
+           'wait_for_senders', 'connect_receiver', 'read_ip_config', 'get_group_id', \
            'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
           'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
           'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
@@ -112,7 +112,7 @@ def create_sender(max_queue_size, net_type):
     max_queue_size : int
         Maximal size (bytes) of network queue buffer.
     net_type : str
-        Networking type. Current options are: 'socket'.
+        Networking type. Current options are: 'socket', 'tensorpipe'.
     """
     max_thread_count = int(os.getenv('DGL_SOCKET_MAX_THREAD_COUNT', '0'))
     _CAPI_DGLRPCCreateSender(int(max_queue_size), net_type, max_thread_count)
@@ -125,7 +125,7 @@ def create_receiver(max_queue_size, net_type):
     max_queue_size : int
         Maximal size (bytes) of network queue buffer.
     net_type : str
-        Networking type. Current options are: 'socket'.
+        Networking type. Current options are: 'socket', 'tensorpipe'.
     """
     max_thread_count = int(os.getenv('DGL_SOCKET_MAX_THREAD_COUNT', '0'))
     _CAPI_DGLRPCCreateReceiver(int(max_queue_size), net_type, max_thread_count)
@@ -140,7 +140,7 @@ def finalize_receiver():
     """
     _CAPI_DGLRPCFinalizeReceiver()
-def receiver_wait(ip_addr, port, num_senders, blocking=True):
+def wait_for_senders(ip_addr, port, num_senders, blocking=True):
     """Wait all of the senders' connections.
     This api will be blocked until all the senders connect to the receiver.
@@ -156,7 +156,7 @@ def receiver_wait(ip_addr, port, num_senders, blocking=True):
     blocking : bool
         whether to wait blockingly
     """
-    _CAPI_DGLRPCReceiverWait(ip_addr, int(port), int(num_senders), blocking)
+    _CAPI_DGLRPCWaitForSenders(ip_addr, int(port), int(num_senders), blocking)
 def connect_receiver(ip_addr, port, recv_id, group_id=-1):
     """Connect to target receiver
@@ -175,6 +175,15 @@ def connect_receiver(ip_addr, port, recv_id, group_id=-1):
         raise DGLError("Invalid target id: {}".format(target_id))
     return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(target_id))
+def connect_receiver_finalize():
+    """Finalize the action to connect to receivers. Make sure that either all connections are
+    successfully established or connection fails.
+    When "socket" network backend is in use, the function issues actual requests to receiver
+    sockets to establish connections.
+    """
+    _CAPI_DGLRPCConnectReceiverFinalize()
 def set_rank(rank):
     """Set the rank of this process.
...
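Taken together, the renamed `wait_for_senders` and the new `connect_receiver_finalize` define the low-level handshake in this module: a sender registers every target with `connect_receiver`, then commits the connections with `connect_receiver_finalize` (for the socket backend this is where the TCP links are actually opened), while the receiving side calls `wait_for_senders`, blocking only when the socket backend is in use. A rough sketch of that sequence, assuming the `rpc` submodule is imported as in the client and server code patched below; the address, port, and queue size are hypothetical:

    import time

    from dgl.distributed import rpc

    NET_TYPE = 'socket'            # or 'tensorpipe'
    QUEUE_SIZE = 20 * 1024 * 1024  # placeholder queue size in bytes

    # --- sender process ---
    rpc.create_sender(QUEUE_SIZE, NET_TYPE)
    while not rpc.connect_receiver('127.0.0.1', 30050, 0):   # recv_id 0, hypothetical endpoint
        time.sleep(1)
    rpc.connect_receiver_finalize()  # socket backend establishes the real connections here

    # --- receiver process ---
    rpc.create_receiver(QUEUE_SIZE, NET_TYPE)
    rpc.wait_for_senders('127.0.0.1', 30050, 1,
                         blocking=NET_TYPE == 'socket')       # one expected sender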
@@ -103,7 +103,7 @@ def get_local_usable_addr(probe_addr):
 def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
-                      net_type='socket', group_id=0):
+                      net_type='tensorpipe', group_id=0):
     """Connect this client to server.
     Parameters
@@ -117,7 +117,7 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
         Note that the 20 GB is just an upper-bound and DGL uses zero-copy and
         it will not allocate 20GB memory at once.
     net_type : str
-        Networking type. Current options are: 'socket'.
+        Networking type. Current options are: 'socket', 'tensorpipe'.
     group_id : int
         Indicates which group this client belongs to. Clients that are
         booted together in each launch are gathered as a group and should
@@ -129,7 +129,8 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
     """
     assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
     assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size
-    assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'.' % net_type
+    assert net_type in ('socket', 'tensorpipe'), \
+        'net_type (%s) can only be \'socket\' or \'tensorpipe\'.' % net_type
     # Register some basic service
     rpc.register_service(rpc.CLIENT_REGISTER,
                          rpc.ClientRegisterRequest,
@@ -169,16 +170,19 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
         server_port = addr[2]
         while not rpc.connect_receiver(server_ip, server_port, server_id):
             time.sleep(3)
+    rpc.connect_receiver_finalize()
     # Get local usable IP address and port
     ip_addr = get_local_usable_addr(server_ip)
     client_ip, client_port = ip_addr.split(':')
-    # wait server connect back
-    rpc.receiver_wait(client_ip, client_port, num_servers, blocking=False)
-    print("Client [{}] waits on {}:{}".format(os.getpid(), client_ip, client_port))
     # Register client on server
     register_req = rpc.ClientRegisterRequest(ip_addr)
     for server_id in range(num_servers):
         rpc.send_request(server_id, register_req)
+    # wait server connect back
+    rpc.wait_for_senders(client_ip, client_port, num_servers,
+                         blocking=net_type == 'socket')
+    print("Client [{}] waits on {}:{}".format(
+        os.getpid(), client_ip, client_port))
     # recv client ID from server
     res = rpc.recv_response()
     rpc.set_rank(res.client_id)
@@ -222,7 +226,7 @@ def shutdown_servers(ip_config, num_servers):
     rpc.register_sig_handler()
     server_namebook = rpc.read_ip_config(ip_config, num_servers)
     num_servers = len(server_namebook)
-    rpc.create_sender(MAX_QUEUE_SIZE, 'socket')
+    rpc.create_sender(MAX_QUEUE_SIZE, 'tensorpipe')
     # Get connected with all server nodes
     for server_id, addr in server_namebook.items():
         server_ip = addr[1]
...
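At the client level the same switch is exposed through `connect_to_server`, whose default also moves to `'tensorpipe'`; a client that wants the socket transport now has to ask for it explicitly. A minimal sketch, with the ip_config path as a placeholder:

    import dgl

    # 'ip_config.txt' is a placeholder path to the cluster description file.
    dgl.distributed.connect_to_server(ip_config='ip_config.txt', num_servers=1,
                                      net_type='socket')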
"""Functions used by server.""" """Functions used by server."""
import time import time
import os
from ..base import DGLError from ..base import DGLError
from . import rpc from . import rpc
from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE from .constants import MAX_QUEUE_SIZE, SERVER_EXIT, SERVER_KEEP_ALIVE
def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
max_queue_size=MAX_QUEUE_SIZE, net_type='socket'): max_queue_size=MAX_QUEUE_SIZE, net_type='tensorpipe'):
"""Start DGL server, which will be shared with all the rpc services. """Start DGL server, which will be shared with all the rpc services.
This is a blocking function -- it returns only when the server shutdown. This is a blocking function -- it returns only when the server shutdown.
...@@ -31,14 +32,17 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -31,14 +32,17 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
Note that the 20 GB is just an upper-bound because DGL uses zero-copy and Note that the 20 GB is just an upper-bound because DGL uses zero-copy and
it will not allocate 20GB memory at once. it will not allocate 20GB memory at once.
net_type : str net_type : str
Networking type. Current options are: 'socket'. Networking type. Current options are: ``'socket'`` or ``'tensorpipe'``.
""" """
assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers
assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_clients assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_clients
assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size
assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'' % net_type assert net_type in ('socket', 'tensorpipe'), \
'net_type (%s) can only be \'socket\' or \'tensorpipe\'' % net_type
if server_state.keep_alive: if server_state.keep_alive:
assert net_type == 'tensorpipe', \
"net_type can only be 'tensorpipe' if 'keep_alive' is enabled."
print("As configured, this server will keep alive for multiple" print("As configured, this server will keep alive for multiple"
" client groups until force shutdown request is received.") " client groups until force shutdown request is received.")
# Register signal handler. # Register signal handler.
...@@ -68,8 +72,9 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -68,8 +72,9 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# Once all the senders connect to server, server will not # Once all the senders connect to server, server will not
# accept new sender's connection # accept new sender's connection
print( print(
"Server is waiting for connections non-blockingly on [{}:{}]...".format(ip_addr, port)) "Server is waiting for connections on [{}:{}]...".format(ip_addr, port))
rpc.receiver_wait(ip_addr, port, num_clients, blocking=False) rpc.wait_for_senders(ip_addr, port, num_clients,
blocking=net_type == 'socket')
rpc.set_num_client(num_clients) rpc.set_num_client(num_clients)
recv_clients = {} recv_clients = {}
while True: while True:
...@@ -83,11 +88,21 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ ...@@ -83,11 +88,21 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# a new client group is ready # a new client group is ready
ips.sort() ips.sort()
client_namebook = dict(enumerate(ips)) client_namebook = dict(enumerate(ips))
time.sleep(3) # wait for clients' receivers ready
max_try_times = int(os.environ.get('DGL_DIST_MAX_TRY_TIMES', 120))
for client_id, addr in client_namebook.items(): for client_id, addr in client_namebook.items():
client_ip, client_port = addr.split(':') client_ip, client_port = addr.split(':')
# TODO[Rhett]: server should not be blocked endlessly. try_times = 0
while not rpc.connect_receiver(client_ip, client_port, client_id, group_id): while not rpc.connect_receiver(client_ip, client_port, client_id, group_id):
try_times += 1
if try_times >= max_try_times:
raise DGLError("Failed to connect to receiver [{}:{}] after {} "
"retries. Please check availability of this target "
"receiver or change the max retry times via "
"'DGL_DIST_MAX_TRY_TIMES'.".format(
client_ip, client_port, max_try_times))
time.sleep(1) time.sleep(1)
rpc.connect_receiver_finalize()
if rpc.get_rank() == 0: # server_0 send all the IDs if rpc.get_rank() == 0: # server_0 send all the IDs
for client_id, _ in client_namebook.items(): for client_id, _ in client_namebook.items():
register_res = rpc.ClientRegisterResponse(client_id) register_res = rpc.ClientRegisterResponse(client_id)
......
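With the socket backend the server no longer retries a client's receiver forever: it gives up after `DGL_DIST_MAX_TRY_TIMES` attempts (default 120, roughly one per second) and raises a `DGLError`. If clients are slow to come up, the cap can be raised in the environment before the server process starts, for example:

    import os

    # Give the server up to ~10 minutes of one-second retries instead of the
    # default 120 before it aborts with DGLError.
    os.environ['DGL_DIST_MAX_TRY_TIMES'] = '600'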
@@ -250,19 +250,19 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
     int recv_id = args[3];
     network::Sender* sender = static_cast<network::Sender*>(chandle);
     std::string addr;
-    if (sender->Type() == "socket") {
+    if (sender->NetType() == "socket") {
       addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
     } else {
-      LOG(FATAL) << "Unknown communicator type: " << sender->Type();
+      LOG(FATAL) << "Unknown communicator type: " << sender->NetType();
     }
-    sender->AddReceiver(addr.c_str(), recv_id);
+    sender->ConnectReceiver(addr.c_str(), recv_id);
   });
 DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
 .set_body([] (DGLArgs args, DGLRetValue* rv) {
     CommunicatorHandle chandle = args[0];
     network::Sender* sender = static_cast<network::Sender*>(chandle);
-    if (sender->Connect() == false) {
+    if (sender->ConnectReceiverFinalize() == false) {
       LOG(FATAL) << "Sender connection failed.";
     }
   });
@@ -275,10 +275,10 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
     int num_sender = args[3];
     network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
     std::string addr;
-    if (receiver->Type() == "socket") {
+    if (receiver->NetType() == "socket") {
       addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
     } else {
-      LOG(FATAL) << "Unknown communicator type: " << receiver->Type();
+      LOG(FATAL) << "Unknown communicator type: " << receiver->NetType();
     }
     if (receiver->Wait(addr.c_str(), num_sender) == false) {
       LOG(FATAL) << "Wait sender socket failed.";
...
/*!
* Copyright (c) 2022 by Contributors
* \file net_type.h
* \brief Base communicator for DGL distributed training.
*/
#ifndef DGL_RPC_NET_TYPE_H_
#define DGL_RPC_NET_TYPE_H_
#include <string>
#include "rpc_msg.h"
namespace dgl {
namespace rpc {
struct RPCBase {
/*!
* \brief Finalize Receiver
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
virtual void Finalize() = 0;
/*!
* \brief Communicator type: 'socket', 'tensorpipe', etc
*/
virtual const std::string &NetType() const = 0;
};
struct RPCSender : RPCBase {
/*!
* \brief Connect to a receiver.
*
* When there are multiple receivers to be connected, application will call `ConnectReceiver`
* for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
* successfully established or some of them fail.
*
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
* \param recv_id receiver's ID
* \return True for success and False for fail
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
virtual bool ConnectReceiver(const std::string &addr, int recv_id) = 0;
/*!
* \brief Finalize the action to connect to receivers. Make sure that either
* all connections are successfully established or connection fails.
* \return True for success and False for fail
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
virtual bool ConnectReceiverFinalize() { return true; }
/*!
* \brief Send RPCMessage to specified Receiver.
* \param msg data message
* \param recv_id receiver's ID
*/
virtual void Send(const RPCMessage &msg, int recv_id) = 0;
};
struct RPCReceiver : RPCBase {
/*!
* \brief Wait for all the Senders to connect
* \param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
* \param num_sender total number of Senders
* \param blocking whether wait blockingly
* \return True for success and False for fail
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
virtual bool Wait(const std::string &addr, int num_sender,
bool blocking = true) = 0;
/*!
* \brief Recv RPCMessage from Sender. Actually removing data from queue.
* \param msg pointer of RPCmessage
*/
virtual void Recv(RPCMessage *msg) = 0;
};
} // namespace rpc
} // namespace dgl
#endif // DGL_RPC_NET_TYPE_H_
@@ -10,6 +10,7 @@
 #include <string>
+#include "../net_type.h"
 #include "msg_queue.h"
 namespace dgl {
@@ -23,7 +24,7 @@ namespace network {
  * networking libraries such TCP socket and MPI. One Sender can connect to
  * multiple receivers and it can send data to specified receiver via receiver's ID.
  */
-class Sender {
+class Sender : public rpc::RPCSender {
  public:
   /*!
    * \brief Sender constructor
@@ -40,23 +41,6 @@ class Sender {
   virtual ~Sender() {}
-  /*!
-   * \brief Add receiver's address and ID to the sender's namebook
-   * \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
-   * \param id receiver's ID
-   *
-   * AddReceiver() is not thread-safe and only one thread can invoke this API.
-   */
-  virtual void AddReceiver(const char* addr, int id) = 0;
-  /*!
-   * \brief Connect with all the Receivers
-   * \return True for success and False for fail
-   *
-   * Connect() is not thread-safe and only one thread can invoke this API.
-   */
-  virtual bool Connect() = 0;
   /*!
    * \brief Send data to specified Receiver.
    * \param msg data message
@@ -72,18 +56,6 @@
    */
   virtual STATUS Send(Message msg, int recv_id) = 0;
-  /*!
-   * \brief Finalize Sender
-   *
-   * Finalize() is not thread-safe and only one thread can invoke this API.
-   */
-  virtual void Finalize() = 0;
-  /*!
-   * \brief Communicator type: 'socket', 'mpi', etc.
-   */
-  virtual std::string Type() const = 0;
  protected:
   /*!
    * \brief Size of message queue
@@ -103,7 +75,7 @@ class Sender {
  * libraries such as TCP socket and MPI. One Receiver can connect with multiple Senders
  * and it can receive data from multiple Senders concurrently.
  */
-class Receiver {
+class Receiver : public rpc::RPCReceiver {
  public:
   /*!
    * \brief Receiver constructor
@@ -122,16 +94,6 @@ class Receiver {
   virtual ~Receiver() {}
-  /*!
-   * \brief Wait for all the Senders to connect
-   * \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
-   * \param num_sender total number of Senders
-   * \return True for success and False for fail
-   *
-   * Wait() is not thread-safe and only one thread can invoke this API.
-   */
-  virtual bool Wait(const char* addr, int num_sender) = 0;
   /*!
    * \brief Recv data from Sender
    * \param msg pointer of data message
@@ -158,18 +120,6 @@
    */
   virtual STATUS RecvFrom(Message* msg, int send_id) = 0;
-  /*!
-   * \brief Finalize Receiver
-   *
-   * Finalize() is not thread-safe and only one thread can invoke this API.
-   */
-  virtual void Finalize() = 0;
-  /*!
-   * \brief Communicator type: 'socket', 'mpi', etc
-   */
-  virtual std::string Type() const = 0;
  protected:
   /*!
    * \brief Size of message queue
...
@@ -27,8 +27,7 @@ namespace network {
 /////////////////////////////////////// SocketSender ///////////////////////////////////////////
-void SocketSender::AddReceiver(const char* addr, int recv_id) {
-  CHECK_NOTNULL(addr);
+bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {
   if (recv_id < 0) {
     LOG(FATAL) << "recv_id cannot be a negative number.";
   }
@@ -36,25 +35,27 @@ void SocketSender::AddReceiver(const char* addr, int recv_id) {
   std::vector<std::string> ip_and_port;
   SplitStringUsing(addr, "//", &substring);
   // Check address format
-  if (substring[0] != "socket:" || substring.size() != 2) {
+  if (substring[0] != "tcp:" || substring.size() != 2) {
     LOG(FATAL) << "Incorrect address format:" << addr
                << " Please provide right address format, "
-               << "e.g, 'socket://127.0.0.1:50051'. ";
+               << "e.g, 'tcp://127.0.0.1:50051'. ";
   }
   // Get IP and port
   SplitStringUsing(substring[1], ":", &ip_and_port);
   if (ip_and_port.size() != 2) {
     LOG(FATAL) << "Incorrect address format:" << addr
                << " Please provide right address format, "
-               << "e.g, 'socket://127.0.0.1:50051'. ";
+               << "e.g, 'tcp://127.0.0.1:50051'. ";
   }
   IPAddr address;
   address.ip = ip_and_port[0];
   address.port = std::stoi(ip_and_port[1]);
   receiver_addrs_[recv_id] = address;
+  return true;
 }
-bool SocketSender::Connect() {
+bool SocketSender::ConnectReceiverFinalize() {
   // Create N sockets for Receiver
   int receiver_count = static_cast<int>(receiver_addrs_.size());
   if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
@@ -79,11 +80,7 @@ bool SocketSender::Connect() {
           LOG(INFO) << "Try to connect to: " << ip << ":" << port;
         }
         try_count++;
-#ifdef _WIN32
-        Sleep(5);
-#else   // !_WIN32
-        sleep(5);
-#endif  // _WIN32
+        std::this_thread::sleep_for(std::chrono::seconds(5));
       }
     }
     if (bo == false) {
@@ -103,6 +100,34 @@ bool SocketSender::Connect() {
   return true;
 }
+void SocketSender::Send(const rpc::RPCMessage& msg, int recv_id) {
+  std::shared_ptr<std::string> zerocopy_blob(new std::string());
+  StreamWithBuffer zc_write_strm(zerocopy_blob.get(), true);
+  zc_write_strm.Write(msg);
+  int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
+  zerocopy_blob->append(reinterpret_cast<char*>(&nonempty_ndarray_count),
+                        sizeof(int32_t));
+  Message rpc_meta_msg;
+  rpc_meta_msg.data = const_cast<char*>(zerocopy_blob->data());
+  rpc_meta_msg.size = zerocopy_blob->size();
+  rpc_meta_msg.deallocator = [zerocopy_blob](Message*) {};
+  CHECK_EQ(Send(
+    rpc_meta_msg, recv_id), ADD_SUCCESS);
+  // send real ndarray data
+  for (auto ptr : zc_write_strm.buffer_list()) {
+    Message ndarray_data_msg;
+    ndarray_data_msg.data = reinterpret_cast<char*>(ptr.data);
+    if (ptr.size == 0) {
+      LOG(FATAL) << "Cannot send a empty NDArray.";
+    }
+    ndarray_data_msg.size = ptr.size;
+    NDArray tensor = ptr.tensor;
+    ndarray_data_msg.deallocator = [tensor](Message*) {};
+    CHECK_EQ(Send(
+      ndarray_data_msg, recv_id), ADD_SUCCESS);
+  }
+}
 STATUS SocketSender::Send(Message msg, int recv_id) {
   CHECK_NOTNULL(msg.data);
   CHECK_GT(msg.size, 0);
@@ -119,11 +144,7 @@ void SocketSender::Finalize() {
     // wait until queue is empty
     auto& mq = msg_queue_[i];
     while (mq->Empty() == false) {
-#ifdef _WIN32
-      // just loop
-#else   // !_WIN32
-      usleep(1000);
-#endif  // _WIN32
+      std::this_thread::sleep_for(std::chrono::seconds(1));
     }
     // All queues have only one producer, which is main thread, so
     // the producerID argument here should be zero.
@@ -185,25 +206,24 @@ void SocketSender::SendLoop(
 }
 /////////////////////////////////////// SocketReceiver ///////////////////////////////////////////
-bool SocketReceiver::Wait(const char* addr, int num_sender) {
-  CHECK_NOTNULL(addr);
+bool SocketReceiver::Wait(const std::string &addr, int num_sender, bool blocking) {
   CHECK_GT(num_sender, 0);
+  CHECK_EQ(blocking, true);
   std::vector<std::string> substring;
   std::vector<std::string> ip_and_port;
   SplitStringUsing(addr, "//", &substring);
   // Check address format
-  if (substring[0] != "socket:" || substring.size() != 2) {
+  if (substring[0] != "tcp:" || substring.size() != 2) {
     LOG(FATAL) << "Incorrect address format:" << addr
                << " Please provide right address format, "
-               << "e.g, 'socket://127.0.0.1:50051'. ";
+               << "e.g, 'tcp://127.0.0.1:50051'. ";
   }
   // Get IP and port
   SplitStringUsing(substring[1], ":", &ip_and_port);
   if (ip_and_port.size() != 2) {
     LOG(FATAL) << "Incorrect address format:" << addr
                << " Please provide right address format, "
-               << "e.g, 'socket://127.0.0.1:50051'. ";
+               << "e.g, 'tcp://127.0.0.1:50051'. ";
   }
   std::string ip = ip_and_port[0];
   int port = stoi(ip_and_port[1]);
@@ -255,6 +275,26 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
   return true;
 }
+void SocketReceiver::Recv(rpc::RPCMessage* msg) {
+  Message rpc_meta_msg;
+  int send_id;
+  CHECK_EQ(Recv(
+    &rpc_meta_msg, &send_id), REMOVE_SUCCESS);
+  char* count_ptr = rpc_meta_msg.data+rpc_meta_msg.size-sizeof(int32_t);
+  int32_t nonempty_ndarray_count = *(reinterpret_cast<int32_t*>(count_ptr));
+  // Recv real ndarray data
+  std::vector<void*> buffer_list(nonempty_ndarray_count);
+  for (int i = 0; i < nonempty_ndarray_count; ++i) {
+    Message ndarray_data_msg;
+    CHECK_EQ(RecvFrom(
+      &ndarray_data_msg, send_id), REMOVE_SUCCESS);
+    buffer_list[i] = ndarray_data_msg.data;
+  }
+  StreamWithBuffer zc_read_strm(rpc_meta_msg.data, rpc_meta_msg.size-sizeof(int32_t), buffer_list);
+  zc_read_strm.Read(msg);
+  rpc_meta_msg.deallocator(&rpc_meta_msg);
+}
 STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
   // queue_sem_ is a semaphore indicating how many elements in multiple
   // message queues.
@@ -288,11 +328,7 @@ void SocketReceiver::Finalize() {
   for (auto& mq : msg_queue_) {
     // wait until queue is empty
     while (mq.second->Empty() == false) {
-#ifdef _WIN32
-      // just loop
-#else   // !_WIN32
-      usleep(1000);
-#endif  // _WIN32
+      std::this_thread::sleep_for(std::chrono::seconds(1));
    }
    mq.second->SignalFinished(mq.first);
  }
...
@@ -49,21 +49,48 @@ class SocketSender : public Sender {
     : Sender(queue_size, max_thread_count) {}
   /*!
-   * \brief Add receiver's address and ID to the sender's namebook
-   * \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
-   * \param id receiver's ID
+   * \brief Connect to a receiver.
+   *
+   * When there are multiple receivers to be connected, application will call `ConnectReceiver`
+   * for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
+   * successfully established or some of them fail.
+   *
+   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
+   * \param recv_id receiver's ID
+   * \return True for success and False for fail
    *
-   * AddReceiver() is not thread-safe and only one thread can invoke this API.
+   * The function is *not* thread-safe; only one thread can invoke this API.
    */
-  void AddReceiver(const char* addr, int recv_id);
+  bool ConnectReceiver(const std::string& addr, int recv_id) override;
   /*!
-   * \brief Connect with all the Receivers
+   * \brief Finalize the action to connect to receivers. Make sure that either
+   * all connections are successfully established or connection fails.
    * \return True for success and False for fail
    *
-   * Connect() is not thread-safe and only one thread can invoke this API.
+   * The function is *not* thread-safe; only one thread can invoke this API.
    */
-  bool Connect();
+  bool ConnectReceiverFinalize() override;
+  /*!
+   * \brief Send RPCMessage to specified Receiver.
+   * \param msg data message
+   * \param recv_id receiver's ID
+   */
+  void Send(const rpc::RPCMessage& msg, int recv_id) override;
+  /*!
+   * \brief Finalize TPSender
+   */
+  void Finalize() override;
+  /*!
+   * \brief Communicator type: 'socket'
+   */
+  const std::string &NetType() const override {
+    static const std::string net_type = "socket";
+    return net_type;
+  }
   /*!
    * \brief Send data to specified Receiver. Actually pushing message to message queue.
@@ -78,19 +105,7 @@ class SocketSender : public Sender {
    * (4) Messages sent to the same receiver are guaranteed to be received in the same order.
    * There is no guarantee for messages sent to different receivers.
    */
-  STATUS Send(Message msg, int recv_id);
+  STATUS Send(Message msg, int recv_id) override;
-  /*!
-   * \brief Finalize SocketSender
-   *
-   * Finalize() is not thread-safe and only one thread can invoke this API.
-   */
-  void Finalize();
-  /*!
-   * \brief Communicator type: 'socket'
-   */
-  inline std::string Type() const { return std::string("socket"); }
  private:
   /*!
@@ -145,13 +160,21 @@ class SocketReceiver : public Receiver {
   /*!
    * \brief Wait for all the Senders to connect
-   * \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
+   * \param addr Networking address, e.g., 'tcp://127.0.0.1:50051', 'mpi://0'
    * \param num_sender total number of Senders
+   * \param blocking whether wait blockingly
    * \return True for success and False for fail
    *
    * Wait() is not thread-safe and only one thread can invoke this API.
    */
-  bool Wait(const char* addr, int num_sender);
+  bool Wait(const std::string &addr, int num_sender,
+            bool blocking = true) override;
+  /*!
+   * \brief Recv RPCMessage from Sender. Actually removing data from queue.
+   * \param msg pointer of RPCmessage
+   */
+  void Recv(rpc::RPCMessage* msg) override;
   /*!
    * \brief Recv data from Sender. Actually removing data from msg_queue.
@@ -164,7 +187,7 @@ class SocketReceiver : public Receiver {
    * (2) The Recv() API is thread-safe.
    * (3) Memory allocated by communicator but will not own it after the function returns.
    */
-  STATUS Recv(Message* msg, int* send_id);
+  STATUS Recv(Message* msg, int* send_id) override;
   /*!
    * \brief Recv data from a specified Sender. Actually removing data from msg_queue.
@@ -177,19 +200,22 @@ class SocketReceiver : public Receiver {
    * (2) The RecvFrom() API is thread-safe.
    * (3) Memory allocated by communicator but will not own it after the function returns.
    */
-  STATUS RecvFrom(Message* msg, int send_id);
+  STATUS RecvFrom(Message* msg, int send_id) override;
   /*!
    * \brief Finalize SocketReceiver
    *
    * Finalize() is not thread-safe and only one thread can invoke this API.
   */
-  void Finalize();
+  void Finalize() override;
   /*!
    * \brief Communicator type: 'socket'
   */
-  inline std::string Type() const { return std::string("socket"); }
+  const std::string &NetType() const override {
+    static const std::string net_type = "socket";
+    return net_type;
+  }
  private:
   struct RecvContext {
...
@@ -114,18 +114,38 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
 .set_body([](DGLArgs args, DGLRetValue* rv) {
     int64_t msg_queue_size = args[0];
     std::string type = args[1];
-    InitGlobalTpContext();
-    RPCContext::getInstance()->sender =
-      std::make_shared<TPSender>(RPCContext::getInstance()->ctx);
+    int max_thread_count = args[2];
+    if (type == "tensorpipe") {
+      InitGlobalTpContext();
+      RPCContext::getInstance()->sender.reset(
+        new TPSender(RPCContext::getInstance()->ctx));
+    } else if (type == "socket") {
+      RPCContext::getInstance()->sender.reset(
+        new network::SocketSender(msg_queue_size, max_thread_count));
+    } else {
+      LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
+    }
+    LOG(INFO) << "Sender with NetType~"
+              << RPCContext::getInstance()->sender->NetType() << " is created.";
   });
 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
 .set_body([](DGLArgs args, DGLRetValue* rv) {
     int64_t msg_queue_size = args[0];
     std::string type = args[1];
-    InitGlobalTpContext();
-    RPCContext::getInstance()->receiver =
-      std::make_shared<TPReceiver>(RPCContext::getInstance()->ctx);
+    int max_thread_count = args[2];
+    if (type == "tensorpipe") {
+      InitGlobalTpContext();
+      RPCContext::getInstance()->receiver.reset(
+        new TPReceiver(RPCContext::getInstance()->ctx));
+    } else if (type == "socket") {
+      RPCContext::getInstance()->receiver.reset(
+        new network::SocketReceiver(msg_queue_size, max_thread_count));
+    } else {
+      LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
+    }
+    LOG(INFO) << "Receiver with NetType~"
+              << RPCContext::getInstance()->receiver->NetType() << " is created.";
   });
 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
@@ -138,7 +158,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
     RPCContext::getInstance()->receiver->Finalize();
   });
-DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait")
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCWaitForSenders")
 .set_body([](DGLArgs args, DGLRetValue* rv) {
     std::string ip = args[0];
     int port = args[1];
@@ -161,6 +181,11 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
     *rv = RPCContext::getInstance()->sender->ConnectReceiver(addr, recv_id);
   });
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
+.set_body([](DGLArgs args, DGLRetValue* rv) {
+    RPCContext::getInstance()->sender->ConnectReceiverFinalize();
+  });
 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
 .set_body([](DGLArgs args, DGLRetValue* rv) {
     const int32_t rank = args[0];
...
@@ -19,7 +19,9 @@
 #include <unordered_map>
 #include "./rpc_msg.h"
-#include "./tensorpipe/tp_communicator.h"
+#include "net_type.h"
+#include "network/socket_communicator.h"
+#include "tensorpipe/tp_communicator.h"
 #include "./network/common.h"
 #include "./server_state.h"
@@ -79,12 +81,12 @@ struct RPCContext {
   /*!
    * \brief Sender communicator.
   */
-  std::shared_ptr<TPSender> sender;
+  std::shared_ptr<RPCSender> sender;
   /*!
    * \brief Receiver communicator.
   */
-  std::shared_ptr<TPReceiver> receiver;
+  std::shared_ptr<RPCReceiver> receiver;
   /*!
    * \brief Tensorpipe global context
@@ -194,8 +196,4 @@ RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout = 0);
 }  // namespace rpc
 }  // namespace dgl
-namespace dmlc {
-DMLC_DECLARE_TRAITS(has_saveload, dgl::rpc::RPCMessage, true);
-}  // namespace dmlc
 #endif  // DGL_RPC_RPC_H_
@@ -7,6 +7,8 @@
 #define DGL_RPC_RPC_MSG_H_
 #include <dgl/runtime/object.h>
+#include <dgl/runtime/ndarray.h>
+#include <dgl/zerocopy_serializer.h>
 #include <string>
 #include <vector>
@@ -70,4 +72,9 @@ DGL_DEFINE_OBJECT_REF(RPCMessageRef, RPCMessage);
 }  // namespace rpc
 }  // namespace dgl
+namespace dmlc {
+DMLC_DECLARE_TRAITS(has_saveload, dgl::rpc::RPCMessage, true);
+}  // namespace dmlc
 #endif  // DGL_RPC_RPC_MSG_H_
@@ -17,12 +17,11 @@
 #include <vector>
 #include <atomic>
 #include "./queue.h"
+#include "../net_type.h"
 namespace dgl {
 namespace rpc {
-class RPCMessage;
 typedef Queue<RPCMessage> RPCMessageQueue;
 /*!
@@ -30,7 +29,7 @@ typedef Queue<RPCMessage> RPCMessageQueue;
  *
  * TPSender is the communicator implemented by tcp socket.
  */
-class TPSender {
+class TPSender : public RPCSender {
  public:
   /*!
    * \brief Sender constructor
@@ -47,30 +46,39 @@ class TPSender {
   ~TPSender() { Finalize(); }
   /*!
-   * \brief Connect to receiver with address and ID
+   * \brief Connect to a receiver.
+   *
+   * When there are multiple receivers to be connected, application will call `ConnectReceiver`
+   * for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
+   * successfully established or some of them fail.
+   *
    * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
    * \param recv_id receiver's ID
    * \return True for success and False for fail
    *
-   * ConnectReceiver() is not thread-safe and only one thread can invoke this API.
+   * The function is *not* thread-safe; only one thread can invoke this API.
   */
-  bool ConnectReceiver(const std::string& addr, int recv_id);
+  bool ConnectReceiver(const std::string& addr, int recv_id) override;
   /*!
    * \brief Send RPCMessage to specified Receiver.
-   * \param msg data message \param recv_id receiver's ID
+   * \param msg data message
+   * \param recv_id receiver's ID
   */
-  void Send(const RPCMessage& msg, int recv_id);
+  void Send(const RPCMessage& msg, int recv_id) override;
   /*!
    * \brief Finalize TPSender
   */
-  void Finalize();
+  void Finalize() override;
   /*!
    * \brief Communicator type: 'tp'
   */
-  inline std::string Type() const { return std::string("tp"); }
+  const std::string &NetType() const override {
+    static const std::string net_type = "tensorpipe";
+    return net_type;
+  }
  private:
   /*!
@@ -95,7 +103,7 @@ class TPSender {
  *
  * Tensorpipe Receiver is the communicator implemented by tcp socket.
  */
-class TPReceiver {
+class TPReceiver : public RPCReceiver {
  public:
   /*!
    * \brief Receiver constructor
@@ -121,33 +129,29 @@ class TPReceiver {
    *
    * Wait() is not thread-safe and only one thread can invoke this API.
   */
-  bool Wait(const std::string &addr, int num_sender, bool blocking = true);
+  bool Wait(const std::string &addr, int num_sender,
+            bool blocking = true) override;
   /*!
    * \brief Recv RPCMessage from Sender. Actually removing data from queue.
    * \param msg pointer of RPCmessage
-   * \param send_id which sender current msg comes from
-   * \return Status code
-   *
-   * (1) The Recv() API is blocking, which will not
-   * return until getting data from message queue.
-   * (2) The Recv() API is thread-safe.
-   * (3) Memory allocated by communicator but will not own it after the function
-   * returns.
   */
-  void Recv(RPCMessage* msg);
+  void Recv(RPCMessage* msg) override;
   /*!
    * \brief Finalize SocketReceiver
    *
    * Finalize() is not thread-safe and only one thread can invoke this API.
   */
-  void Finalize();
+  void Finalize() override;
   /*!
    * \brief Communicator type: 'tp' (tensorpipe)
   */
-  inline std::string Type() const { return std::string("tp"); }
+  const std::string &NetType() const override {
+    static const std::string net_type = "tensorpipe";
+    return net_type;
+  }
   /*!
    * \brief Issue a receive request on pipe, and push the result into queue
...
@@ -34,9 +34,9 @@ const int kNumReceiver = 3;
 const int kNumMessage = 10;
 const char* ip_addr[] = {
-  "socket://127.0.0.1:50091",
-  "socket://127.0.0.1:50092",
-  "socket://127.0.0.1:50093"
+  "tcp://127.0.0.1:50091",
+  "tcp://127.0.0.1:50092",
+  "tcp://127.0.0.1:50093"
 };
 static void start_client();
@@ -64,9 +64,9 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
 void start_client() {
   SocketSender sender(kQueueSize, kThreadNum);
   for (int i = 0; i < kNumReceiver; ++i) {
-    sender.AddReceiver(ip_addr[i], i);
+    sender.ConnectReceiver(ip_addr[i], i);
   }
-  sender.Connect();
+  sender.ConnectReceiverFinalize();
   for (int i = 0; i < kNumMessage; ++i) {
     for (int n = 0; n < kNumReceiver; ++n) {
       char* str_data = new char[9];
@@ -140,7 +140,7 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
   srand((unsigned)time(NULL));
   int port = (rand() % (5000-3000+1))+ 3000;
-  std::string ip_addr = "socket://127.0.0.1:" + std::to_string(port);
+  std::string ip_addr = "tcp://127.0.0.1:" + std::to_string(port);
   std::ofstream out("addr.txt");
   out << ip_addr;
   out.close();
@@ -170,8 +170,8 @@ static void start_client() {
                     std::istreambuf_iterator<char>());
   t.close();
   SocketSender sender(kQueueSize, kThreadNum);
-  sender.AddReceiver(ip_addr.c_str(), 0);
-  sender.Connect();
+  sender.ConnectReceiver(ip_addr.c_str(), 0);
+  sender.ConnectReceiverFinalize();
   char* str_data = new char[9];
   memcpy(str_data, "123456789", 9);
   Message msg = {str_data, 9};
...
@@ -83,7 +83,7 @@ class HelloRequest(dgl.distributed.Request):
         res = HelloResponse(self.hello_str, self.integer, new_tensor)
         return res
-def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1):
+def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_servers=1, net_type='tensorpipe'):
     print("Sleep 1 seconds to test client re-connect.")
     time.sleep(1)
     server_state = dgl.distributed.ServerState(
@@ -95,12 +95,13 @@ def start_server(num_clients, ip_config, server_id=0, keep_alive=False, num_serv
         ip_config=ip_config,
         num_servers=num_servers,
         num_clients=num_clients,
-        server_state=server_state)
+        server_state=server_state,
+        net_type=net_type)
-def start_client(ip_config, group_id=0, num_servers=1):
+def start_client(ip_config, group_id=0, num_servers=1, net_type='tensorpipe'):
     dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
     dgl.distributed.connect_to_server(
-        ip_config=ip_config, num_servers=num_servers, group_id=group_id)
+        ip_config=ip_config, num_servers=num_servers, group_id=group_id, net_type=net_type)
     req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
     # test send and recv
     dgl.distributed.send_request(0, req)
@@ -183,16 +184,18 @@ def test_rpc():
     pclient.join()
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
-def test_multi_client():
+@pytest.mark.parametrize("net_type", ['socket', 'tensorpipe'])
+def test_multi_client(net_type):
     reset_envs()
     os.environ['DGL_DIST_MODE'] = 'distributed'
-    generate_ip_config("rpc_ip_config_mul_client.txt", 1, 1)
+    ip_config = "rpc_ip_config_mul_client.txt"
+    generate_ip_config(ip_config, 1, 1)
     ctx = mp.get_context('spawn')
     num_clients = 20
-    pserver = ctx.Process(target=start_server, args=(num_clients, "rpc_ip_config_mul_client.txt"))
+    pserver = ctx.Process(target=start_server, args=(num_clients, ip_config, 0, False, 1, net_type))
     pclient_list = []
     for i in range(num_clients):
-        pclient = ctx.Process(target=start_client, args=("rpc_ip_config_mul_client.txt",))
+        pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1, net_type))
         pclient_list.append(pclient)
     pserver.start()
     for i in range(num_clients):
@@ -280,5 +283,6 @@ if __name__ == '__main__':
     test_serialize()
     test_rpc_msg()
     test_rpc()
-    test_multi_client()
+    test_multi_client('socket')
+    test_multi_client('tesnsorpipe')
     test_multi_thread_rpc()